Skip to content

Commit c1d1f4a

Browse files
authored
fix: add strict client side response validation for gRPC chunked resumable uploads (#2527)
* Rename JsonResumableSessionFailureScenario to ResumableSessionFailureScenario. The failure scenarios themselves are not json specific, and the methods which are json specific can have grpc overloads * Add more tests to validate GapicUnbufferedChunkedResumableWritableByteChannel is able to properly detect and handle various success responses from GCS which are not success for the client.
1 parent 09f7191 commit c1d1f4a

11 files changed

+969
-122
lines changed

google-cloud-storage/src/main/java/com/google/cloud/storage/GapicUnbufferedChunkedResumableWritableByteChannel.java

+90-68
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import com.google.api.gax.retrying.ResultRetryAlgorithm;
2424
import com.google.api.gax.rpc.ApiStreamObserver;
2525
import com.google.api.gax.rpc.ClientStreamingCallable;
26+
import com.google.api.gax.rpc.OutOfRangeException;
2627
import com.google.cloud.storage.ChunkSegmenter.ChunkSegment;
2728
import com.google.cloud.storage.Conversions.Decoder;
2829
import com.google.cloud.storage.Crc32cValue.Crc32cLengthKnown;
@@ -41,10 +42,9 @@
4142
import java.util.ArrayList;
4243
import java.util.List;
4344
import java.util.concurrent.ExecutionException;
44-
import java.util.function.Consumer;
45-
import java.util.function.LongConsumer;
4645
import java.util.function.Supplier;
4746
import org.checkerframework.checker.nullness.qual.NonNull;
47+
import org.checkerframework.checker.nullness.qual.Nullable;
4848

4949
final class GapicUnbufferedChunkedResumableWritableByteChannel
5050
implements UnbufferedWritableByteChannel {
@@ -58,30 +58,26 @@ final class GapicUnbufferedChunkedResumableWritableByteChannel
5858
private final RetryingDependencies deps;
5959
private final ResultRetryAlgorithm<?> alg;
6060
private final Supplier<GrpcCallContext> baseContextSupplier;
61-
private final LongConsumer sizeCallback;
62-
private final Consumer<WriteObjectResponse> completeCallback;
6361

64-
private boolean open = true;
62+
private volatile boolean open = true;
6563
private boolean finished = false;
6664

6765
GapicUnbufferedChunkedResumableWritableByteChannel(
6866
SettableApiFuture<WriteObjectResponse> resultFuture,
6967
@NonNull ChunkSegmenter chunkSegmenter,
7068
ClientStreamingCallable<WriteObjectRequest, WriteObjectResponse> write,
71-
ResumableWrite requestFactory,
69+
WriteCtx<ResumableWrite> writeCtx,
7270
RetryingDependencies deps,
7371
ResultRetryAlgorithm<?> alg,
7472
Supplier<GrpcCallContext> baseContextSupplier) {
7573
this.resultFuture = resultFuture;
7674
this.chunkSegmenter = chunkSegmenter;
7775
this.write = write;
78-
this.bucketName = requestFactory.bucketName();
79-
this.writeCtx = new WriteCtx<>(requestFactory);
76+
this.bucketName = writeCtx.getRequestFactory().bucketName();
77+
this.writeCtx = writeCtx;
8078
this.deps = deps;
8179
this.alg = alg;
8280
this.baseContextSupplier = baseContextSupplier;
83-
this.sizeCallback = writeCtx.getConfirmedBytes()::set;
84-
this.completeCallback = resultFuture::set;
8581
}
8682

8783
@Override
@@ -106,7 +102,7 @@ public void close() throws IOException {
106102
if (open && !finished) {
107103
WriteObjectRequest message = finishMessage(true);
108104
try {
109-
flush(ImmutableList.of(message));
105+
flush(ImmutableList.of(message), null, true);
110106
finished = true;
111107
} catch (RuntimeException e) {
112108
resultFuture.setException(e);
@@ -122,12 +118,13 @@ private long internalWrite(ByteBuffer[] srcs, int srcsOffset, int srcsLength, bo
122118
throw new ClosedChannelException();
123119
}
124120

121+
long begin = writeCtx.getConfirmedBytes().get();
122+
RewindableContent content = RewindableContent.of(srcs, srcsOffset, srcsLength);
125123
ChunkSegment[] data = chunkSegmenter.segmentBuffers(srcs, srcsOffset, srcsLength);
126124

127125
List<WriteObjectRequest> messages = new ArrayList<>();
128126

129127
boolean first = true;
130-
int bytesConsumed = 0;
131128
for (ChunkSegment datum : data) {
132129
Crc32cLengthKnown crc32c = datum.getCrc32c();
133130
ByteString b = datum.getB();
@@ -144,8 +141,13 @@ private long internalWrite(ByteBuffer[] srcs, int srcsOffset, int srcsLength, bo
144141
WriteObjectRequest.Builder builder =
145142
writeCtx
146143
.newRequestBuilder()
144+
.clearWriteObjectSpec()
145+
.clearObjectChecksums()
147146
.setWriteOffset(offset)
148147
.setChecksummedData(checksummedData.build());
148+
if (!first) {
149+
builder.clearUploadId();
150+
}
149151
if (!datum.isOnlyFullBlocks()) {
150152
builder.setFinishWrite(true);
151153
if (cumulative != null) {
@@ -155,23 +157,25 @@ private long internalWrite(ByteBuffer[] srcs, int srcsOffset, int srcsLength, bo
155157
finished = true;
156158
}
157159

158-
WriteObjectRequest build = possiblyPairDownRequest(builder, first).build();
160+
WriteObjectRequest build = builder.build();
159161
first = false;
160162
messages.add(build);
161-
bytesConsumed += contentSize;
162163
}
163164
if (finalize && !finished) {
164165
messages.add(finishMessage(first));
165166
finished = true;
166167
}
167168

168169
try {
169-
flush(messages);
170+
flush(messages, content, finalize);
170171
} catch (RuntimeException e) {
171172
resultFuture.setException(e);
172173
throw e;
173174
}
174175

176+
long end = writeCtx.getConfirmedBytes().get();
177+
178+
long bytesConsumed = end - begin;
175179
return bytesConsumed;
176180
}
177181

@@ -182,14 +186,20 @@ private WriteObjectRequest finishMessage(boolean first) {
182186

183187
WriteObjectRequest.Builder b =
184188
writeCtx.newRequestBuilder().setFinishWrite(true).setWriteOffset(offset);
189+
if (!first) {
190+
b.clearUploadId();
191+
}
185192
if (crc32cValue != null) {
186193
b.setObjectChecksums(ObjectChecksums.newBuilder().setCrc32C(crc32cValue.getValue()).build());
187194
}
188-
WriteObjectRequest message = possiblyPairDownRequest(b, first).build();
195+
WriteObjectRequest message = b.build();
189196
return message;
190197
}
191198

192-
private void flush(@NonNull List<WriteObjectRequest> segments) {
199+
private void flush(
200+
@NonNull List<WriteObjectRequest> segments,
201+
@Nullable RewindableContent content,
202+
boolean finalizing) {
193203
GrpcCallContext internalContext = contextWithBucketName(bucketName, baseContextSupplier.get());
194204
ClientStreamingCallable<WriteObjectRequest, WriteObjectResponse> callable =
195205
write.withDefaultCallContext(internalContext);
@@ -198,7 +208,7 @@ private void flush(@NonNull List<WriteObjectRequest> segments) {
198208
deps,
199209
alg,
200210
() -> {
201-
Observer observer = new Observer(sizeCallback, completeCallback);
211+
Observer observer = new Observer(content, finalizing);
202212
ApiStreamObserver<WriteObjectRequest> write = callable.clientStreamingCall(observer);
203213

204214
for (WriteObjectRequest message : segments) {
@@ -211,81 +221,93 @@ private void flush(@NonNull List<WriteObjectRequest> segments) {
211221
Decoder.identity());
212222
}
213223

214-
/**
215-
* Several fields of a WriteObjectRequest are only allowed on the "first" message sent to gcs,
216-
* this utility method centralizes the logic necessary to clear those fields for use by subsequent
217-
* messages.
218-
*/
219-
private static WriteObjectRequest.Builder possiblyPairDownRequest(
220-
WriteObjectRequest.Builder b, boolean firstMessageOfStream) {
221-
if (firstMessageOfStream && b.getWriteOffset() == 0) {
222-
return b;
223-
}
224-
225-
if (!firstMessageOfStream) {
226-
b.clearUploadId();
227-
}
228-
229-
if (b.getWriteOffset() > 0) {
230-
b.clearWriteObjectSpec();
231-
}
232-
233-
if (b.getWriteOffset() > 0 && !b.getFinishWrite()) {
234-
b.clearObjectChecksums();
235-
}
236-
return b;
237-
}
238-
239224
@VisibleForTesting
240225
WriteCtx<?> getWriteCtx() {
241226
return writeCtx;
242227
}
243228

244-
static class Observer implements ApiStreamObserver<WriteObjectResponse> {
229+
class Observer implements ApiStreamObserver<WriteObjectResponse> {
245230

246-
private final LongConsumer sizeCallback;
247-
private final Consumer<WriteObjectResponse> completeCallback;
231+
private final RewindableContent content;
232+
private final boolean finalizing;
248233

249234
private final SettableApiFuture<Void> invocationHandle;
250235
private volatile WriteObjectResponse last;
251236

252-
Observer(LongConsumer sizeCallback, Consumer<WriteObjectResponse> completeCallback) {
253-
this.sizeCallback = sizeCallback;
254-
this.completeCallback = completeCallback;
237+
Observer(@Nullable RewindableContent content, boolean finalizing) {
238+
this.content = content;
239+
this.finalizing = finalizing;
255240
this.invocationHandle = SettableApiFuture.create();
256241
}
257242

258243
@Override
259244
public void onNext(WriteObjectResponse value) {
260-
// incremental update
261-
if (value.hasPersistedSize()) {
262-
sizeCallback.accept(value.getPersistedSize());
263-
} else if (value.hasResource()) {
264-
sizeCallback.accept(value.getResource().getSize());
265-
}
266245
last = value;
267246
}
268247

269-
/**
270-
* observed exceptions so far
271-
*
272-
* <ol>
273-
* <li>{@link com.google.api.gax.rpc.OutOfRangeException}
274-
* <li>{@link com.google.api.gax.rpc.AlreadyExistsException}
275-
* <li>{@link io.grpc.StatusRuntimeException}
276-
* </ol>
277-
*/
278248
@Override
279249
public void onError(Throwable t) {
280-
invocationHandle.setException(t);
250+
if (t instanceof OutOfRangeException) {
251+
OutOfRangeException oore = (OutOfRangeException) t;
252+
open = false;
253+
invocationHandle.setException(
254+
ResumableSessionFailureScenario.SCENARIO_5.toStorageException());
255+
} else {
256+
invocationHandle.setException(t);
257+
}
281258
}
282259

283260
@Override
284261
public void onCompleted() {
285-
if (last != null && last.hasResource()) {
286-
completeCallback.accept(last);
262+
try {
263+
if (last == null) {
264+
throw new StorageException(
265+
0, "onComplete without preceding onNext, unable to determine success.");
266+
} else if (!finalizing && last.hasPersistedSize()) { // incremental
267+
long totalSentBytes = writeCtx.getTotalSentBytes().get();
268+
long persistedSize = last.getPersistedSize();
269+
270+
if (totalSentBytes == persistedSize) {
271+
writeCtx.getConfirmedBytes().set(persistedSize);
272+
} else if (persistedSize < totalSentBytes) {
273+
long delta = totalSentBytes - persistedSize;
274+
// rewind our content and any state that my have run ahead of the actual ack'd bytes
275+
content.rewindTo(delta);
276+
writeCtx.getTotalSentBytes().set(persistedSize);
277+
writeCtx.getConfirmedBytes().set(persistedSize);
278+
} else {
279+
throw ResumableSessionFailureScenario.SCENARIO_7.toStorageException();
280+
}
281+
} else if (finalizing && last.hasResource()) {
282+
long totalSentBytes = writeCtx.getTotalSentBytes().get();
283+
long finalSize = last.getResource().getSize();
284+
if (totalSentBytes == finalSize) {
285+
writeCtx.getConfirmedBytes().set(finalSize);
286+
resultFuture.set(last);
287+
} else if (finalSize < totalSentBytes) {
288+
throw ResumableSessionFailureScenario.SCENARIO_4_1.toStorageException();
289+
} else {
290+
throw ResumableSessionFailureScenario.SCENARIO_4_2.toStorageException();
291+
}
292+
} else if (!finalizing && last.hasResource()) {
293+
throw ResumableSessionFailureScenario.SCENARIO_1.toStorageException();
294+
} else if (finalizing && last.hasPersistedSize()) {
295+
long totalSentBytes = writeCtx.getTotalSentBytes().get();
296+
long persistedSize = last.getPersistedSize();
297+
if (persistedSize < totalSentBytes) {
298+
throw ResumableSessionFailureScenario.SCENARIO_3.toStorageException();
299+
} else {
300+
throw ResumableSessionFailureScenario.SCENARIO_2.toStorageException();
301+
}
302+
} else {
303+
throw ResumableSessionFailureScenario.SCENARIO_0.toStorageException();
304+
}
305+
} catch (Throwable se) {
306+
open = false;
307+
invocationHandle.setException(se);
308+
} finally {
309+
invocationHandle.set(null);
287310
}
288-
invocationHandle.set(null);
289311
}
290312

291313
void await() {

google-cloud-storage/src/main/java/com/google/cloud/storage/GapicWritableByteChannelSessionBuilder.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -299,13 +299,13 @@ UnbufferedWritableByteChannelSession<WriteObjectResponse> build() {
299299
result,
300300
getChunkSegmenter(),
301301
write,
302-
ResumableWrite.identity(start),
302+
new WriteCtx<>(start),
303303
deps,
304304
alg,
305305
Retrying::newCallContext);
306306
} else {
307307
return new GapicUnbufferedFinalizeOnCloseResumableWritableByteChannel(
308-
result, getChunkSegmenter(), write, ResumableWrite.identity(start));
308+
result, getChunkSegmenter(), write, start);
309309
}
310310
})
311311
.andThen(StorageByteChannels.writable()::createSynchronized));
@@ -340,13 +340,13 @@ BufferedWritableByteChannelSession<WriteObjectResponse> build() {
340340
result,
341341
getChunkSegmenter(),
342342
write,
343-
ResumableWrite.identity(start),
343+
new WriteCtx<>(start),
344344
deps,
345345
alg,
346346
Retrying::newCallContext);
347347
} else {
348348
return new GapicUnbufferedFinalizeOnCloseResumableWritableByteChannel(
349-
result, getChunkSegmenter(), write, ResumableWrite.identity(start));
349+
result, getChunkSegmenter(), write, start);
350350
}
351351
})
352352
.andThen(c -> new DefaultBufferedWritableByteChannel(bufferHandle, c))

0 commit comments

Comments
 (0)