From bccad5bbfc3e454c8b54527116900cc8555be3aa Mon Sep 17 00:00:00 2001 From: Kristen O'Leary Date: Fri, 16 Dec 2022 12:43:30 -0500 Subject: [PATCH] feat: add x-goog-request-params to header --- .../google/datastore/v1/client/Datastore.java | 26 +++++++---- .../google/datastore/v1/client/RemoteRpc.java | 19 ++++++-- .../datastore/v1/client/RemoteRpcTest.java | 45 ++++++++++++++++--- 3 files changed, 74 insertions(+), 16 deletions(-) diff --git a/datastore-v1-proto-client/src/main/java/com/google/datastore/v1/client/Datastore.java b/datastore-v1-proto-client/src/main/java/com/google/datastore/v1/client/Datastore.java index 09101c94b..6b886aef0 100644 --- a/datastore-v1-proto-client/src/main/java/com/google/datastore/v1/client/Datastore.java +++ b/datastore-v1-proto-client/src/main/java/com/google/datastore/v1/client/Datastore.java @@ -67,7 +67,8 @@ private DatastoreException invalidResponseException(String method, IOException e } public AllocateIdsResponse allocateIds(AllocateIdsRequest request) throws DatastoreException { - try (InputStream is = remoteRpc.call("allocateIds", request)) { + try (InputStream is = + remoteRpc.call("allocateIds", request, request.getProjectId(), request.getDatabaseId())) { return AllocateIdsResponse.parseFrom(is); } catch (IOException exception) { throw invalidResponseException("allocateIds", exception); @@ -76,7 +77,9 @@ public AllocateIdsResponse allocateIds(AllocateIdsRequest request) throws Datast public BeginTransactionResponse beginTransaction(BeginTransactionRequest request) throws DatastoreException { - try (InputStream is = remoteRpc.call("beginTransaction", request)) { + try (InputStream is = + remoteRpc.call( + "beginTransaction", request, request.getProjectId(), request.getDatabaseId())) { return BeginTransactionResponse.parseFrom(is); } catch (IOException exception) { throw invalidResponseException("beginTransaction", exception); @@ -84,7 +87,8 @@ public BeginTransactionResponse beginTransaction(BeginTransactionRequest request } public CommitResponse commit(CommitRequest request) throws DatastoreException { - try (InputStream is = remoteRpc.call("commit", request)) { + try (InputStream is = + remoteRpc.call("commit", request, request.getProjectId(), request.getDatabaseId())) { return CommitResponse.parseFrom(is); } catch (IOException exception) { throw invalidResponseException("commit", exception); @@ -92,7 +96,8 @@ public CommitResponse commit(CommitRequest request) throws DatastoreException { } public LookupResponse lookup(LookupRequest request) throws DatastoreException { - try (InputStream is = remoteRpc.call("lookup", request)) { + try (InputStream is = + remoteRpc.call("lookup", request, request.getProjectId(), request.getDatabaseId())) { return LookupResponse.parseFrom(is); } catch (IOException exception) { throw invalidResponseException("lookup", exception); @@ -100,7 +105,8 @@ public LookupResponse lookup(LookupRequest request) throws DatastoreException { } public ReserveIdsResponse reserveIds(ReserveIdsRequest request) throws DatastoreException { - try (InputStream is = remoteRpc.call("reserveIds", request)) { + try (InputStream is = + remoteRpc.call("reserveIds", request, request.getProjectId(), request.getDatabaseId())) { return ReserveIdsResponse.parseFrom(is); } catch (IOException exception) { throw invalidResponseException("reserveIds", exception); @@ -108,7 +114,8 @@ public ReserveIdsResponse reserveIds(ReserveIdsRequest request) throws Datastore } public RollbackResponse rollback(RollbackRequest request) throws DatastoreException { - try (InputStream is = remoteRpc.call("rollback", request)) { + try (InputStream is = + remoteRpc.call("rollback", request, request.getProjectId(), request.getDatabaseId())) { return RollbackResponse.parseFrom(is); } catch (IOException exception) { throw invalidResponseException("rollback", exception); @@ -116,7 +123,8 @@ public RollbackResponse rollback(RollbackRequest request) throws DatastoreExcept } public RunQueryResponse runQuery(RunQueryRequest request) throws DatastoreException { - try (InputStream is = remoteRpc.call("runQuery", request)) { + try (InputStream is = + remoteRpc.call("runQuery", request, request.getProjectId(), request.getDatabaseId())) { return RunQueryResponse.parseFrom(is); } catch (IOException exception) { throw invalidResponseException("runQuery", exception); @@ -125,7 +133,9 @@ public RunQueryResponse runQuery(RunQueryRequest request) throws DatastoreExcept public RunAggregationQueryResponse runAggregationQuery(RunAggregationQueryRequest request) throws DatastoreException { - try (InputStream is = remoteRpc.call("runAggregationQuery", request)) { + try (InputStream is = + remoteRpc.call( + "runAggregationQuery", request, request.getProjectId(), request.getDatabaseId())) { return RunAggregationQueryResponse.parseFrom(is); } catch (IOException exception) { throw invalidResponseException("runAggregationQuery", exception); diff --git a/datastore-v1-proto-client/src/main/java/com/google/datastore/v1/client/RemoteRpc.java b/datastore-v1-proto-client/src/main/java/com/google/datastore/v1/client/RemoteRpc.java index 321eea72a..b0b47c505 100644 --- a/datastore-v1-proto-client/src/main/java/com/google/datastore/v1/client/RemoteRpc.java +++ b/datastore-v1-proto-client/src/main/java/com/google/datastore/v1/client/RemoteRpc.java @@ -24,6 +24,7 @@ import com.google.api.client.http.protobuf.ProtoHttpContent; import com.google.api.client.util.IOUtils; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; import com.google.protobuf.MessageLite; import com.google.rpc.Code; import com.google.rpc.Status; @@ -46,6 +47,8 @@ class RemoteRpc { @VisibleForTesting static final String API_FORMAT_VERSION_HEADER = "X-Goog-Api-Format-Version"; private static final String API_FORMAT_VERSION = "2"; + @VisibleForTesting static final String X_GOOG_REQUEST_PARAMS_HEADER = "x-goog-request-params"; + private final HttpRequestFactory client; private final HttpRequestInitializer initializer; private final String url; @@ -74,7 +77,9 @@ class RemoteRpc { * * @throws DatastoreException if the RPC fails. */ - public InputStream call(String methodName, MessageLite request) throws DatastoreException { + public InputStream call( + String methodName, MessageLite request, String projectId, String databaseId) + throws DatastoreException { logger.fine("remote datastore call " + methodName); long startTime = System.currentTimeMillis(); @@ -84,7 +89,7 @@ public InputStream call(String methodName, MessageLite request) throws Datastore rpcCount.incrementAndGet(); ProtoHttpContent payload = new ProtoHttpContent(request); HttpRequest httpRequest = client.buildPostRequest(resolveURL(methodName), payload); - setHeaders(request, httpRequest); + setHeaders(request, httpRequest, projectId, databaseId); // Don't throw an HTTPResponseException on error. It converts the response to a String and // throws away the original, whereas we need the raw bytes to parse it as a proto. httpRequest.setThrowExceptionOnExecuteError(false); @@ -123,8 +128,16 @@ public InputStream call(String methodName, MessageLite request) throws Datastore } @VisibleForTesting - void setHeaders(MessageLite request, HttpRequest httpRequest) { + void setHeaders( + MessageLite request, HttpRequest httpRequest, String projectId, String databaseId) { httpRequest.getHeaders().put(API_FORMAT_VERSION_HEADER, API_FORMAT_VERSION); + StringBuilder builder = new StringBuilder("project_id="); + builder.append(projectId); + if (!Strings.isNullOrEmpty(databaseId)) { + builder.append("&database_id="); + builder.append(databaseId); + } + httpRequest.getHeaders().put(X_GOOG_REQUEST_PARAMS_HEADER, builder.toString()); if (enableE2EChecksum && request != null) { String checksum = EndToEndChecksumHandler.computeChecksum(request.toByteArray()); if (checksum != null) { diff --git a/datastore-v1-proto-client/src/test/java/com/google/datastore/v1/client/RemoteRpcTest.java b/datastore-v1-proto-client/src/test/java/com/google/datastore/v1/client/RemoteRpcTest.java index ebcb12396..281e92f04 100644 --- a/datastore-v1-proto-client/src/test/java/com/google/datastore/v1/client/RemoteRpcTest.java +++ b/datastore-v1-proto-client/src/test/java/com/google/datastore/v1/client/RemoteRpcTest.java @@ -146,7 +146,8 @@ public void testGzip() throws IOException, DatastoreException { new InjectedTestValues(gzip(response), new byte[1], true); RemoteRpc rpc = newRemoteRpc(injectedTestValues); - InputStream is = rpc.call("beginTransaction", BeginTransactionResponse.getDefaultInstance()); + InputStream is = + rpc.call("beginTransaction", BeginTransactionResponse.getDefaultInstance(), "", ""); BeginTransactionResponse parsedResponse = BeginTransactionResponse.parseFrom(is); is.close(); @@ -159,14 +160,15 @@ public void testGzip() throws IOException, DatastoreException { public void testHttpHeaders_expectE2eChecksumHeader() throws IOException { // Enable E2E-Checksum system env variable RemoteRpc.setSystemEnvE2EChecksum(true); + String projectId = "project-id"; MessageLite request = - RollbackRequest.newBuilder().setTransaction(ByteString.copyFromUtf8("project-id")).build(); + RollbackRequest.newBuilder().setTransaction(ByteString.copyFromUtf8(projectId)).build(); RemoteRpc rpc = newRemoteRpc( new InjectedTestValues(gzip(newBeginTransactionResponse()), new byte[1], true)); HttpRequest httpRequest = rpc.getClient().buildPostRequest(rpc.resolveURL("blah"), new ProtoHttpContent(request)); - rpc.setHeaders(request, httpRequest); + rpc.setHeaders(request, httpRequest, projectId, ""); assertNotNull( httpRequest.getHeaders().getFirstHeaderStringValue(RemoteRpc.API_FORMAT_VERSION_HEADER)); // Expect to find e2e-checksum header @@ -181,14 +183,15 @@ public void testHttpHeaders_expectE2eChecksumHeader() throws IOException { public void testHttpHeaders_doNotExpectE2eChecksumHeader() throws IOException { // disable E2E-Checksum system env variable RemoteRpc.setSystemEnvE2EChecksum(false); + String projectId = "project-id"; MessageLite request = - RollbackRequest.newBuilder().setTransaction(ByteString.copyFromUtf8("project-id")).build(); + RollbackRequest.newBuilder().setTransaction(ByteString.copyFromUtf8(projectId)).build(); RemoteRpc rpc = newRemoteRpc( new InjectedTestValues(gzip(newBeginTransactionResponse()), new byte[1], true)); HttpRequest httpRequest = rpc.getClient().buildPostRequest(rpc.resolveURL("blah"), new ProtoHttpContent(request)); - rpc.setHeaders(request, httpRequest); + rpc.setHeaders(request, httpRequest, projectId, ""); assertNotNull( httpRequest.getHeaders().getFirstHeaderStringValue(RemoteRpc.API_FORMAT_VERSION_HEADER)); // Do not expect to find e2e-checksum header @@ -198,6 +201,38 @@ public void testHttpHeaders_doNotExpectE2eChecksumHeader() throws IOException { .getFirstHeaderStringValue(EndToEndChecksumHandler.HTTP_REQUEST_CHECKSUM_HEADER)); } + @Test + public void testHttpHeaders_prefixHeader() throws IOException { + String projectId = "my-project"; + String databaseId = "my-db"; + MessageLite request = + RollbackRequest.newBuilder() + .setTransaction(ByteString.copyFromUtf8(projectId)) + .setDatabaseId(databaseId) + .build(); + RemoteRpc rpc = + newRemoteRpc( + new InjectedTestValues(gzip(newBeginTransactionResponse()), new byte[1], true)); + HttpRequest httpRequest = + rpc.getClient().buildPostRequest(rpc.resolveURL("blah"), new ProtoHttpContent(request)); + rpc.setHeaders(request, httpRequest, projectId, databaseId); + assertEquals( + "project_id=my-project&database_id=my-db", + httpRequest.getHeaders().get(RemoteRpc.X_GOOG_REQUEST_PARAMS_HEADER)); + + MessageLite request2 = + RollbackRequest.newBuilder().setTransaction(ByteString.copyFromUtf8(projectId)).build(); + RemoteRpc rpc2 = + newRemoteRpc( + new InjectedTestValues(gzip(newBeginTransactionResponse()), new byte[1], true)); + HttpRequest httpRequest2 = + rpc2.getClient().buildPostRequest(rpc2.resolveURL("blah"), new ProtoHttpContent(request2)); + rpc2.setHeaders(request, httpRequest2, projectId, ""); + assertEquals( + "project_id=my-project", + httpRequest2.getHeaders().get(RemoteRpc.X_GOOG_REQUEST_PARAMS_HEADER)); + } + private static BeginTransactionResponse newBeginTransactionResponse() { return BeginTransactionResponse.newBuilder() .setTransaction(ByteString.copyFromUtf8("blah-blah-blah"))