From 46625729b6fd62d8f133c3fb2d8ee00eb64ee8e9 Mon Sep 17 00:00:00 2001 From: BenWhitehead Date: Fri, 16 Dec 2022 14:04:14 -0500 Subject: [PATCH] fix: update Grpc Write implementation to allow specifying expected md5 (#1815) Remove Hasher.Constant. StartResumableWriteRequest has been updated to allow specifying `object_checksums` when creating the session. Add several new positive and negative integration test for md5 verification --- .../clirr-ignored-differences.xml | 5 ++ .../storage/GapicUploadSessionBuilder.java | 11 ++- .../google/cloud/storage/GrpcStorageImpl.java | 32 ++++--- .../java/com/google/cloud/storage/Hasher.java | 33 -------- .../google/cloud/storage/ResumableWrite.java | 11 ++- .../com/google/cloud/storage/Storage.java | 2 +- .../com/google/cloud/storage/UnifiedOpts.java | 9 ++ .../cloud/storage/WriteFlushStrategy.java | 34 ++++++-- ...apicUnbufferedWritableByteChannelTest.java | 2 +- .../com/google/cloud/storage/TestUtils.java | 11 +++ .../it/ITObjectChecksumSupportTest.java | 83 ++++++++++++++++--- 11 files changed, 157 insertions(+), 76 deletions(-) diff --git a/google-cloud-storage/clirr-ignored-differences.xml b/google-cloud-storage/clirr-ignored-differences.xml index e166707532..7182f2a962 100644 --- a/google-cloud-storage/clirr-ignored-differences.xml +++ b/google-cloud-storage/clirr-ignored-differences.xml @@ -2,4 +2,9 @@ + + 8001 + com/google/cloud/storage/Hasher$ConstantConcatValueHasher + + diff --git a/google-cloud-storage/src/main/java/com/google/cloud/storage/GapicUploadSessionBuilder.java b/google-cloud-storage/src/main/java/com/google/cloud/storage/GapicUploadSessionBuilder.java index 988ab936cd..7b4c6a949a 100644 --- a/google-cloud-storage/src/main/java/com/google/cloud/storage/GapicUploadSessionBuilder.java +++ b/google-cloud-storage/src/main/java/com/google/cloud/storage/GapicUploadSessionBuilder.java @@ -25,6 +25,7 @@ import com.google.storage.v2.StartResumableWriteResponse; import com.google.storage.v2.WriteObjectRequest; import com.google.storage.v2.WriteObjectResponse; +import java.util.function.Function; final class GapicUploadSessionBuilder { @@ -49,8 +50,16 @@ ApiFuture resumableWrite( if (writeObjectRequest.hasCommonObjectRequestParams()) { b.setCommonObjectRequestParams(writeObjectRequest.getCommonObjectRequestParams()); } + if (writeObjectRequest.hasObjectChecksums()) { + b.setObjectChecksums(writeObjectRequest.getObjectChecksums()); + } StartResumableWriteRequest req = b.build(); + Function f = + uploadId -> + writeObjectRequest.toBuilder().clearWriteObjectSpec().setUploadId(uploadId).build(); return ApiFutures.transform( - x.futureCall(req), (resp) -> new ResumableWrite(req, resp), MoreExecutors.directExecutor()); + x.futureCall(req), + (resp) -> new ResumableWrite(req, resp, f), + MoreExecutors.directExecutor()); } } diff --git a/google-cloud-storage/src/main/java/com/google/cloud/storage/GrpcStorageImpl.java b/google-cloud-storage/src/main/java/com/google/cloud/storage/GrpcStorageImpl.java index bad2cccb72..10f60230c2 100644 --- a/google-cloud-storage/src/main/java/com/google/cloud/storage/GrpcStorageImpl.java +++ b/google-cloud-storage/src/main/java/com/google/cloud/storage/GrpcStorageImpl.java @@ -95,6 +95,7 @@ import com.google.storage.v2.LockBucketRetentionPolicyRequest; import com.google.storage.v2.Object; import com.google.storage.v2.ObjectAccessControl; +import com.google.storage.v2.ObjectChecksums; import com.google.storage.v2.ProjectName; import com.google.storage.v2.ReadObjectRequest; import com.google.storage.v2.RewriteObjectRequest; @@ -221,6 +222,7 @@ public Blob create( GrpcCallContext grpcCallContext = opts.grpcMetadataMapper().apply(GrpcCallContext.createDefault()); WriteObjectRequest req = getWriteObjectRequest(blobInfo, opts); + Hasher hasher = getHasherForRequest(req, Hasher.enabled()); return Retrying.run( getOptions(), retryAlgorithmManager.getFor(req), @@ -231,7 +233,7 @@ public Blob create( .byteChannel( storageClient.writeObjectCallable().withDefaultCallContext(grpcCallContext)) .setByteStringStrategy(ByteStringStrategy.noCopy()) - .setHasher(Hasher.enabled()) + .setHasher(hasher) .direct() .unbuffered() .setRequest(req) @@ -273,10 +275,7 @@ public Blob createFrom(BlobInfo blobInfo, Path path, int bufferSize, BlobWriteOp opts.grpcMetadataMapper().apply(GrpcCallContext.createDefault()); WriteObjectRequest req = getWriteObjectRequest(blobInfo, opts); - Hasher hasher = Hasher.enabled(); - if (req.hasObjectChecksums() && req.getObjectChecksums().hasCrc32C()) { - hasher = Hasher.constant(req.getObjectChecksums().getCrc32C()); - } + Hasher hasher = getHasherForRequest(req, Hasher.enabled()); GapicWritableByteChannelSessionBuilder channelSessionBuilder = ResumableMedia.gapic() .write() @@ -346,10 +345,7 @@ public Blob createFrom( ApiFuture start = startResumableWrite(grpcCallContext, req); - Hasher hasher = Hasher.enabled(); - if (req.hasObjectChecksums() && req.getObjectChecksums().hasCrc32C()) { - hasher = Hasher.constant(req.getObjectChecksums().getCrc32C()); - } + Hasher hasher = getHasherForRequest(req, Hasher.enabled()); BufferedWritableByteChannelSession session = ResumableMedia.gapic() .write() @@ -736,10 +732,7 @@ public GrpcBlobWriteChannel writer(BlobInfo blobInfo, BlobWriteOption... options GrpcCallContext grpcCallContext = opts.grpcMetadataMapper().apply(GrpcCallContext.createDefault()); WriteObjectRequest req = getWriteObjectRequest(blobInfo, opts); - Hasher hasher = Hasher.noop(); - if (req.hasObjectChecksums() && req.getObjectChecksums().hasCrc32C()) { - hasher = Hasher.constant(req.getObjectChecksums().getCrc32C()); - } + Hasher hasher = getHasherForRequest(req, Hasher.enabled()); return new GrpcBlobWriteChannel( storageClient.writeObjectCallable(), getOptions(), @@ -1789,4 +1782,17 @@ private Object updateObject(UpdateObjectRequest req) { () -> storageClient.updateObjectCallable().call(req, grpcCallContext), Decoder.identity()); } + + private static Hasher getHasherForRequest(WriteObjectRequest req, Hasher defaultHasher) { + if (!req.hasObjectChecksums()) { + return defaultHasher; + } else { + ObjectChecksums checksums = req.getObjectChecksums(); + if (!checksums.hasCrc32C() && checksums.getMd5Hash().isEmpty()) { + return defaultHasher; + } else { + return Hasher.noop(); + } + } + } } diff --git a/google-cloud-storage/src/main/java/com/google/cloud/storage/Hasher.java b/google-cloud-storage/src/main/java/com/google/cloud/storage/Hasher.java index 5c29ce26e4..06fca0413f 100644 --- a/google-cloud-storage/src/main/java/com/google/cloud/storage/Hasher.java +++ b/google-cloud-storage/src/main/java/com/google/cloud/storage/Hasher.java @@ -47,17 +47,6 @@ static Hasher enabled() { return GuavaHasher.INSTANCE; } - /** - * Create a Hasher which will always yield the specified value when {@link - * #nullSafeConcat(Crc32cLengthKnown, Crc32cLengthKnown)} is invoked. - */ - // Not perfect, and not a great approach for a public API. However, this is the most pragmatic way - // right now to wire an externally defined value all the way down to the last write message of a - // resumable upload session. - static Hasher constant(int crc32c) { - return new ConstantConcatValueHasher(Crc32cValue.of(crc32c, -1)); - } - @Immutable class NoOpHasher implements Hasher { private static final NoOpHasher INSTANCE = new NoOpHasher(); @@ -112,26 +101,4 @@ public Crc32cLengthKnown nullSafeConcat(Crc32cLengthKnown r1, Crc32cLengthKnown } } } - - @Immutable - class ConstantConcatValueHasher implements Hasher { - private final Crc32cLengthKnown value; - - private ConstantConcatValueHasher(Crc32cLengthKnown value) { - this.value = value; - } - - @Override - public @Nullable Crc32cLengthKnown hash(ByteBuffer b) { - return null; - } - - @Override - public void validate(Crc32cValue expected, Supplier b) {} - - @Override - public @Nullable Crc32cLengthKnown nullSafeConcat(Crc32cLengthKnown r1, Crc32cLengthKnown r2) { - return value; - } - } } diff --git a/google-cloud-storage/src/main/java/com/google/cloud/storage/ResumableWrite.java b/google-cloud-storage/src/main/java/com/google/cloud/storage/ResumableWrite.java index 47996e7bab..75921032de 100644 --- a/google-cloud-storage/src/main/java/com/google/cloud/storage/ResumableWrite.java +++ b/google-cloud-storage/src/main/java/com/google/cloud/storage/ResumableWrite.java @@ -33,14 +33,13 @@ final class ResumableWrite implements WriteObjectRequestBuilderFactory { private final WriteObjectRequest writeRequest; - public ResumableWrite(StartResumableWriteRequest req, StartResumableWriteResponse res) { + public ResumableWrite( + StartResumableWriteRequest req, + StartResumableWriteResponse res, + Function f) { this.req = req; this.res = res; - WriteObjectRequest.Builder b = WriteObjectRequest.newBuilder().setUploadId(res.getUploadId()); - if (req.hasCommonObjectRequestParams()) { - b.setCommonObjectRequestParams(req.getCommonObjectRequestParams()); - } - this.writeRequest = b.build(); + this.writeRequest = f.apply(res.getUploadId()); } public StartResumableWriteRequest getReq() { diff --git a/google-cloud-storage/src/main/java/com/google/cloud/storage/Storage.java b/google-cloud-storage/src/main/java/com/google/cloud/storage/Storage.java index 67f1121051..d281d3a332 100644 --- a/google-cloud-storage/src/main/java/com/google/cloud/storage/Storage.java +++ b/google-cloud-storage/src/main/java/com/google/cloud/storage/Storage.java @@ -774,7 +774,7 @@ public static BlobWriteOption metagenerationNotMatch() { * @deprecated Please compute and use a crc32c checksum instead. {@link #crc32cMatch()} */ @Deprecated - @TransportCompatibility(Transport.HTTP) + @TransportCompatibility({Transport.HTTP, Transport.GRPC}) public static BlobWriteOption md5Match() { return new BlobWriteOption(UnifiedOpts.md5MatchExtractor()); } diff --git a/google-cloud-storage/src/main/java/com/google/cloud/storage/UnifiedOpts.java b/google-cloud-storage/src/main/java/com/google/cloud/storage/UnifiedOpts.java index 3df773d08a..6eecc74fd4 100644 --- a/google-cloud-storage/src/main/java/com/google/cloud/storage/UnifiedOpts.java +++ b/google-cloud-storage/src/main/java/com/google/cloud/storage/UnifiedOpts.java @@ -1096,6 +1096,15 @@ public boolean equals(Object o) { return Objects.equals(val, md5Match.val); } + @Override + public Mapper writeObject() { + return b -> { + b.getObjectChecksumsBuilder() + .setMd5Hash(ByteString.copyFrom(BaseEncoding.base64().decode(val))); + return b; + }; + } + @Override public int hashCode() { return Objects.hash(val); diff --git a/google-cloud-storage/src/main/java/com/google/cloud/storage/WriteFlushStrategy.java b/google-cloud-storage/src/main/java/com/google/cloud/storage/WriteFlushStrategy.java index 0007ec684c..7deb004342 100644 --- a/google-cloud-storage/src/main/java/com/google/cloud/storage/WriteFlushStrategy.java +++ b/google-cloud-storage/src/main/java/com/google/cloud/storage/WriteFlushStrategy.java @@ -89,6 +89,28 @@ private static GrpcCallContext contextWithBucketName(String bucketName) { return ret; } + /** + * Several fields of a WriteObjectRequest are only allowed on the "first" message sent to gcs, + * this utility method centralizes the logic necessary to clear those fields for use by subsequent + * messages. + */ + private static WriteObjectRequest possiblyPairDownRequest( + WriteObjectRequest message, boolean firstMessageOfStream) { + if (firstMessageOfStream && message.getWriteOffset() == 0) { + return message; + } + + WriteObjectRequest.Builder b = message.toBuilder(); + if (!firstMessageOfStream) { + b.clearUploadId(); + } + + if (message.getWriteOffset() > 0) { + b.clearWriteObjectSpec().clearObjectChecksums(); + } + return b.build(); + } + @FunctionalInterface interface FlusherFactory { /** @@ -144,9 +166,7 @@ public void flush(@NonNull List segments) { boolean first = true; for (WriteObjectRequest message : segments) { - if (!first) { - message = message.toBuilder().clearUploadId().clearWriteObjectSpec().build(); - } + message = possiblyPairDownRequest(message, first); write.onNext(message); first = false; @@ -188,9 +208,7 @@ private FsyncOnClose( public void flush(@NonNull List segments) { ensureOpen(); for (WriteObjectRequest message : segments) { - if (!first) { - message = message.toBuilder().clearUploadId().clearWriteObjectSpec().build(); - } + message = possiblyPairDownRequest(message, first); stream.onNext(message); first = false; @@ -201,9 +219,7 @@ public void flush(@NonNull List segments) { public void close(@Nullable WriteObjectRequest message) { ensureOpen(); if (message != null) { - if (!first) { - message = message.toBuilder().clearUploadId().clearWriteObjectSpec().build(); - } + message = possiblyPairDownRequest(message, first); stream.onNext(message); } stream.onCompleted(); diff --git a/google-cloud-storage/src/test/java/com/google/cloud/storage/GapicUnbufferedWritableByteChannelTest.java b/google-cloud-storage/src/test/java/com/google/cloud/storage/GapicUnbufferedWritableByteChannelTest.java index 80331e2366..306b867d21 100644 --- a/google-cloud-storage/src/test/java/com/google/cloud/storage/GapicUnbufferedWritableByteChannelTest.java +++ b/google-cloud-storage/src/test/java/com/google/cloud/storage/GapicUnbufferedWritableByteChannelTest.java @@ -105,7 +105,7 @@ public final class GapicUnbufferedWritableByteChannelTest { WriteObjectResponse.newBuilder().setResource(obj.toBuilder().setSize(40)).build(); private static final WriteObjectRequestBuilderFactory reqFactory = - new ResumableWrite(startReq, startResp); + new ResumableWrite(startReq, startResp, TestUtils.onlyUploadId()); @Test public void directUpload() throws IOException, InterruptedException, ExecutionException { diff --git a/google-cloud-storage/src/test/java/com/google/cloud/storage/TestUtils.java b/google-cloud-storage/src/test/java/com/google/cloud/storage/TestUtils.java index 1810060c75..7e5ac8c778 100644 --- a/google-cloud-storage/src/test/java/com/google/cloud/storage/TestUtils.java +++ b/google-cloud-storage/src/test/java/com/google/cloud/storage/TestUtils.java @@ -37,6 +37,7 @@ import com.google.protobuf.ByteString; import com.google.rpc.DebugInfo; import com.google.storage.v2.ChecksummedData; +import com.google.storage.v2.WriteObjectRequest; import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; import java.io.ByteArrayOutputStream; @@ -49,6 +50,7 @@ import java.util.stream.IntStream; import java.util.stream.Stream; import java.util.zip.GZIPOutputStream; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; public final class TestUtils { @@ -179,4 +181,13 @@ public boolean shouldRetry(Throwable previousThrowable, Object previousResponse) } } } + + /** + * Return a function which when provided an {@code uploadId} will create a {@link + * WriteObjectRequest} with that {@code uploadId} + */ + @NonNull + public static Function onlyUploadId() { + return uId -> WriteObjectRequest.newBuilder().setUploadId(uId).build(); + } } diff --git a/google-cloud-storage/src/test/java/com/google/cloud/storage/it/ITObjectChecksumSupportTest.java b/google-cloud-storage/src/test/java/com/google/cloud/storage/it/ITObjectChecksumSupportTest.java index 129deebd38..5538626d25 100644 --- a/google-cloud-storage/src/test/java/com/google/cloud/storage/it/ITObjectChecksumSupportTest.java +++ b/google-cloud-storage/src/test/java/com/google/cloud/storage/it/ITObjectChecksumSupportTest.java @@ -18,7 +18,6 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; -import static org.junit.Assert.fail; import com.google.cloud.WriteChannel; import com.google.cloud.storage.Blob; @@ -61,6 +60,8 @@ public final class ITObjectChecksumSupportTest { @Inject public Storage storage; @Inject public BucketInfo bucket; + @Inject public Transport transport; + @Parameter public ChecksummedTestContent content; public static final class ChecksummedTestContentProvider implements ParametersProvider { @@ -158,18 +159,76 @@ public void testCrc32cValidated_writer_expectSuccess() throws IOException { } @Test - // Error Handling for GRPC not complete b/247621346 - @CrossRun.Exclude(transports = Transport.GRPC) - public void testCreateBlobMd5Fail() { + public void testMd5Validated_createFrom_expectFailure() { + String blobName = testName.getMethodName(); + BlobId blobId = BlobId.of(bucket.getName(), blobName); + BlobInfo blobInfo = BlobInfo.newBuilder(blobId).setMd5(content.getMd5Base64()).build(); + + byte[] bytes = content.concat('x'); + StorageException expected = + assertThrows( + StorageException.class, + () -> + storage.createFrom( + blobInfo, + new ByteArrayInputStream(bytes), + BlobWriteOption.doesNotExist(), + BlobWriteOption.md5Match())); + assertThat(expected.getCode()).isEqualTo(400); + } + + @Test + public void testMd5Validated_createFrom_expectSuccess() throws IOException { + String blobName = testName.getMethodName(); + BlobId blobId = BlobId.of(bucket.getName(), blobName); + BlobInfo blobInfo = BlobInfo.newBuilder(blobId).setMd5(content.getMd5Base64()).build(); + + byte[] bytes = content.getBytes(); + Blob blob = + storage.createFrom( + blobInfo, + new ByteArrayInputStream(bytes), + BlobWriteOption.doesNotExist(), + BlobWriteOption.md5Match()); + assertThat(blob.getMd5()).isEqualTo(content.getMd5Base64()); + } + + @Test + public void testMd5Validated_writer_expectFailure() { + String blobName = testName.getMethodName(); + BlobId blobId = BlobId.of(bucket.getName(), blobName); + BlobInfo blobInfo = BlobInfo.newBuilder(blobId).setMd5(content.getMd5Base64()).build(); + + byte[] bytes = content.concat('x'); + StorageException expected = + assertThrows( + StorageException.class, + () -> { + try (ReadableByteChannel src = Channels.newChannel(new ByteArrayInputStream(bytes)); + WriteChannel dst = + storage.writer( + blobInfo, BlobWriteOption.doesNotExist(), BlobWriteOption.md5Match())) { + ByteStreams.copy(src, dst); + } + }); + assertThat(expected.getCode()).isEqualTo(400); + } + + @Test + public void testMd5Validated_writer_expectSuccess() throws IOException { String blobName = testName.getMethodName(); - BlobInfo blob = - BlobInfo.newBuilder(bucket, blobName).setMd5("O1R4G1HJSDUISJjoIYmVhQ==").build(); - ByteArrayInputStream stream = content.bytesAsInputStream(); - try { - storage.create(blob, stream, Storage.BlobWriteOption.md5Match()); - fail("StorageException was expected"); - } catch (StorageException ex) { - // expected + BlobId blobId = BlobId.of(bucket.getName(), blobName); + BlobInfo blobInfo = BlobInfo.newBuilder(blobId).setMd5(content.getMd5Base64()).build(); + + byte[] bytes = content.getBytes(); + + try (ReadableByteChannel src = Channels.newChannel(new ByteArrayInputStream(bytes)); + WriteChannel dst = + storage.writer(blobInfo, BlobWriteOption.doesNotExist(), BlobWriteOption.md5Match())) { + ByteStreams.copy(src, dst); } + + Blob blob = storage.get(blobId); + assertThat(blob.getMd5()).isEqualTo(content.getMd5Base64()); } }