Skip to content

Commit

Permalink
Make SplittablePayload extend OpMsgSequence
Browse files Browse the repository at this point in the history
Justification: This brings SplittablePayload closer in design to
DualSplittablePayloads, reducing a potential source of confusion
for future readers. We could go further in this direction, but
this is a start.
  • Loading branch information
jyemin committed Sep 15, 2024
1 parent 7d7f3eb commit a2805a6
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,8 @@ boolean isResponseExpected() {
if (responseExpected) {
return true;
} else {
if (sequences instanceof ValidatableSplittablePayload) {
ValidatableSplittablePayload validatableSplittablePayload = (ValidatableSplittablePayload) sequences;
SplittablePayload payload = validatableSplittablePayload.getSplittablePayload();
if (sequences instanceof SplittablePayload) {
SplittablePayload payload = (SplittablePayload) sequences;
return payload.isOrdered() && payload.hasAnotherSplit();
} else if (sequences instanceof DualSplittablePayloads) {
return assertNotNull(dualSplittablePayloadsRequireResponse);
Expand All @@ -235,13 +234,12 @@ protected EncodingMetadata encodeMessageBodyWithMetadata(final ByteBufferBsonOut
ArrayList<BsonElement> extraElements = getExtraElements(operationContext);

int commandDocumentSizeInBytes = writeDocument(command, bsonOutput, commandFieldNameValidator, true);
if (sequences instanceof ValidatableSplittablePayload) {
if (sequences instanceof SplittablePayload) {
appendElementsToDocument(bsonOutput, commandStartPosition, extraElements);
ValidatableSplittablePayload validatableSplittablePayload = (ValidatableSplittablePayload) sequences;
SplittablePayload payload = validatableSplittablePayload.getSplittablePayload();
SplittablePayload payload = (SplittablePayload) sequences;
writeOpMsgSectionWithPayloadType1(bsonOutput, payload.getPayloadName(), () -> {
writePayload(
new BsonBinaryWriter(bsonOutput, validatableSplittablePayload.getFieldNameValidator()),
new BsonBinaryWriter(bsonOutput, payload.getFieldNameValidator()),
bsonOutput, getSettings(), messageStartPosition, payload, getSettings().getMaxDocumentSize()
);
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.bson.BsonObjectId;
import org.bson.BsonValue;
import org.bson.BsonWriter;
import org.bson.FieldNameValidator;
import org.bson.codecs.BsonValueCodecProvider;
import org.bson.codecs.Codec;
import org.bson.codecs.Encoder;
Expand Down Expand Up @@ -54,8 +55,9 @@
*
* <p>This class is not part of the public API and may be removed or changed at any time</p>
*/
public final class SplittablePayload {
public final class SplittablePayload extends OpMsgSequences {
private static final CodecRegistry REGISTRY = fromProviders(new BsonValueCodecProvider());
private final FieldNameValidator fieldNameValidator;
private final WriteRequestEncoder writeRequestEncoder = new WriteRequestEncoder();
private final Type payloadType;
private final List<WriteRequestWithIndex> writeRequestWithIndexes;
Expand Down Expand Up @@ -94,10 +96,16 @@ public enum Type {
* @param payloadType the payload type
* @param writeRequestWithIndexes the writeRequests
*/
public SplittablePayload(final Type payloadType, final List<WriteRequestWithIndex> writeRequestWithIndexes, final boolean ordered) {
public SplittablePayload(final Type payloadType, final List<WriteRequestWithIndex> writeRequestWithIndexes, final boolean ordered,
final FieldNameValidator fieldNameValidator) {
this.payloadType = notNull("batchType", payloadType);
this.writeRequestWithIndexes = notNull("writeRequests", writeRequestWithIndexes);
this.ordered = ordered;
this.fieldNameValidator = fieldNameValidator;
}

public FieldNameValidator getFieldNameValidator() {
return fieldNameValidator;
}

/**
Expand Down Expand Up @@ -175,7 +183,7 @@ boolean isOrdered() {
public SplittablePayload getNextSplit() {
isTrue("hasAnotherSplit", hasAnotherSplit());
List<WriteRequestWithIndex> nextPayLoad = writeRequestWithIndexes.subList(position, writeRequestWithIndexes.size());
return new SplittablePayload(payloadType, nextPayLoad, ordered);
return new SplittablePayload(payloadType, nextPayLoad, ordered, fieldNameValidator);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@
import static com.mongodb.internal.bulk.WriteRequest.Type.INSERT;
import static com.mongodb.internal.bulk.WriteRequest.Type.REPLACE;
import static com.mongodb.internal.bulk.WriteRequest.Type.UPDATE;
import static com.mongodb.internal.operation.DocumentHelper.putIfNotNull;
import static com.mongodb.internal.operation.CommandOperationHelper.commandWriteConcern;
import static com.mongodb.internal.operation.DocumentHelper.putIfNotNull;
import static com.mongodb.internal.operation.OperationHelper.LOGGER;
import static com.mongodb.internal.operation.OperationHelper.isRetryableWrite;
import static com.mongodb.internal.operation.WriteConcernHelper.createWriteConcernError;
Expand Down Expand Up @@ -154,7 +154,7 @@ private BulkWriteBatch(final MongoNamespace namespace, final ConnectionDescripti

this.indexMap = indexMap;
this.unprocessed = unprocessedItems;
this.payload = new SplittablePayload(getPayloadType(batchType), payloadItems, ordered);
this.payload = new SplittablePayload(getPayloadType(batchType), payloadItems, ordered, getFieldNameValidator());
this.operationContext = operationContext;
this.comment = comment;
this.variables = variables;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import com.mongodb.internal.connection.MongoWriteConcernWithResponseException;
import com.mongodb.internal.connection.OperationContext;
import com.mongodb.internal.connection.ProtocolHelper;
import com.mongodb.internal.connection.ValidatableSplittablePayload;
import com.mongodb.internal.operation.retry.AttachmentKeys;
import com.mongodb.internal.session.SessionContext;
import com.mongodb.internal.validator.NoOpFieldNameValidator;
Expand Down Expand Up @@ -422,8 +421,7 @@ private BsonDocument executeCommand(
final Connection connection,
final BulkWriteBatch batch) {
return connection.command(namespace.getDatabaseName(), batch.getCommand(), NoOpFieldNameValidator.INSTANCE, null, batch.getDecoder(),
operationContext, shouldExpectResponse(batch, effectiveWriteConcern),
new ValidatableSplittablePayload(batch.getPayload(), batch.getFieldNameValidator()));
operationContext, shouldExpectResponse(batch, effectiveWriteConcern), batch.getPayload());
}

private void executeCommandAsync(
Expand All @@ -433,8 +431,7 @@ private void executeCommandAsync(
final BulkWriteBatch batch,
final SingleResultCallback<BsonDocument> callback) {
connection.commandAsync(namespace.getDatabaseName(), batch.getCommand(), NoOpFieldNameValidator.INSTANCE, null, batch.getDecoder(),
operationContext, shouldExpectResponse(batch, effectiveWriteConcern),
new ValidatableSplittablePayload(batch.getPayload(), batch.getFieldNameValidator()), callback);
operationContext, shouldExpectResponse(batch, effectiveWriteConcern), batch.getPayload(), callback);
}

private boolean shouldExpectResponse(final BulkWriteBatch batch, final WriteConcern effectiveWriteConcern) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class CommandMessageSpecification extends Specification {
MessageSettings.builder().maxWireVersion(maxWireVersion).build(), true,
payload == null
? OpMsgSequences.EmptyOpMsgSequences.INSTANCE
: new ValidatableSplittablePayload(payload, NoOpFieldNameValidator.INSTANCE),
: payload,
ClusterConnectionMode.MULTIPLE, null)
def output = new ByteBufferBsonOutput(new SimpleBufferProvider())
message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE,
Expand All @@ -177,7 +177,8 @@ class CommandMessageSpecification extends Specification {
new BsonDocument('insert', new BsonString('coll')),
new SplittablePayload(INSERT, [new BsonDocument('_id', new BsonInt32(1)),
new BsonDocument('_id', new BsonInt32(2))]
.withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true),
.withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true,
NoOpFieldNameValidator.INSTANCE),
],
[
LATEST_WIRE_VERSION,
Expand All @@ -198,9 +199,9 @@ class CommandMessageSpecification extends Specification {
new BsonDocument('_id', new BsonInt32(3)).append('c', new BsonBinary(new byte[450])),
new BsonDocument('_id', new BsonInt32(4)).append('b', new BsonBinary(new byte[441])),
new BsonDocument('_id', new BsonInt32(5)).append('c', new BsonBinary(new byte[451]))]
.withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true)
.withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, fieldNameValidator)
def message = new CommandMessage(namespace, insertCommand, fieldNameValidator, ReadPreference.primary(), messageSettings,
false, new ValidatableSplittablePayload(payload, fieldNameValidator), ClusterConnectionMode.MULTIPLE, null)
false, payload, ClusterConnectionMode.MULTIPLE, null)
def output = new ByteBufferBsonOutput(new SimpleBufferProvider())
def sessionContext = Stub(SessionContext) {
getReadConcern() >> ReadConcern.DEFAULT
Expand All @@ -224,7 +225,7 @@ class CommandMessageSpecification extends Specification {
when:
payload = payload.getNextSplit()
message = new CommandMessage(namespace, insertCommand, fieldNameValidator, ReadPreference.primary(), messageSettings,
false, new ValidatableSplittablePayload(payload, fieldNameValidator), ClusterConnectionMode.MULTIPLE, null)
false, payload, ClusterConnectionMode.MULTIPLE, null)
output.truncateToPosition(0)
message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext, Stub(TimeoutContext), null))
byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray()))
Expand All @@ -242,7 +243,7 @@ class CommandMessageSpecification extends Specification {
when:
payload = payload.getNextSplit()
message = new CommandMessage(namespace, insertCommand, fieldNameValidator, ReadPreference.primary(), messageSettings,
false, new ValidatableSplittablePayload(payload, fieldNameValidator), ClusterConnectionMode.MULTIPLE, null)
false, payload, ClusterConnectionMode.MULTIPLE, null)
output.truncateToPosition(0)
message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext, Stub(TimeoutContext), null))
byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray()))
Expand All @@ -260,7 +261,7 @@ class CommandMessageSpecification extends Specification {
when:
payload = payload.getNextSplit()
message = new CommandMessage(namespace, insertCommand, fieldNameValidator, ReadPreference.primary(), messageSettings,
false, new ValidatableSplittablePayload(payload, fieldNameValidator), ClusterConnectionMode.MULTIPLE, null)
false, payload, ClusterConnectionMode.MULTIPLE, null)
output.truncateToPosition(0)
message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE,
sessionContext,
Expand Down Expand Up @@ -288,9 +289,9 @@ class CommandMessageSpecification extends Specification {
def payload = new SplittablePayload(INSERT, [new BsonDocument('a', new BsonBinary(new byte[900])),
new BsonDocument('b', new BsonBinary(new byte[450])),
new BsonDocument('c', new BsonBinary(new byte[450]))]
.withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true)
.withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, fieldNameValidator)
def message = new CommandMessage(namespace, command, fieldNameValidator, ReadPreference.primary(), messageSettings,
false, new ValidatableSplittablePayload(payload, fieldNameValidator), ClusterConnectionMode.MULTIPLE, null)
false, payload, ClusterConnectionMode.MULTIPLE, null)
def output = new ByteBufferBsonOutput(new SimpleBufferProvider())
def sessionContext = Stub(SessionContext) {
getReadConcern() >> ReadConcern.DEFAULT
Expand All @@ -315,7 +316,7 @@ class CommandMessageSpecification extends Specification {
when:
payload = payload.getNextSplit()
message = new CommandMessage(namespace, command, fieldNameValidator, ReadPreference.primary(), messageSettings,
false, new ValidatableSplittablePayload(payload, fieldNameValidator), ClusterConnectionMode.MULTIPLE, null)
false, payload, ClusterConnectionMode.MULTIPLE, null)
output.truncateToPosition(0)
message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext,
Stub(TimeoutContext), null))
Expand All @@ -339,9 +340,9 @@ class CommandMessageSpecification extends Specification {
def messageSettings = MessageSettings.builder().maxDocumentSize(900)
.maxWireVersion(LATEST_WIRE_VERSION).build()
def payload = new SplittablePayload(INSERT, [new BsonDocument('a', new BsonBinary(new byte[900]))]
.withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true)
.withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, fieldNameValidator)
def message = new CommandMessage(namespace, command, fieldNameValidator, ReadPreference.primary(), messageSettings,
false, new ValidatableSplittablePayload(payload, fieldNameValidator), ClusterConnectionMode.MULTIPLE, null)
false, payload, ClusterConnectionMode.MULTIPLE, null)
def output = new ByteBufferBsonOutput(new SimpleBufferProvider())
def sessionContext = Stub(SessionContext) {
getReadConcern() >> ReadConcern.DEFAULT
Expand All @@ -362,9 +363,9 @@ class CommandMessageSpecification extends Specification {
given:
def messageSettings = MessageSettings.builder().serverType(ServerType.SHARD_ROUTER)
.maxWireVersion(FOUR_DOT_ZERO_WIRE_VERSION).build()
def payload = new SplittablePayload(INSERT, [new BsonDocument('a', new BsonInt32(1))], true)
def payload = new SplittablePayload(INSERT, [new BsonDocument('a', new BsonInt32(1))], true, fieldNameValidator)
def message = new CommandMessage(namespace, command, fieldNameValidator, ReadPreference.primary(), messageSettings,
false, new ValidatableSplittablePayload(payload, fieldNameValidator), ClusterConnectionMode.MULTIPLE, null)
false, payload, ClusterConnectionMode.MULTIPLE, null)
def output = new ByteBufferBsonOutput(new SimpleBufferProvider())
def sessionContext = Stub(SessionContext) {
getReadConcern() >> ReadConcern.DEFAULT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import com.mongodb.internal.connection.OperationContext;
import com.mongodb.internal.connection.SplittablePayload;
import com.mongodb.internal.connection.SplittablePayloadBsonWriter;
import com.mongodb.internal.connection.ValidatableSplittablePayload;
import com.mongodb.internal.time.Timeout;
import com.mongodb.internal.validator.MappedFieldNameValidator;
import com.mongodb.lang.Nullable;
Expand Down Expand Up @@ -113,10 +112,9 @@ public <T> void commandAsync(final String database, final BsonDocument command,
try {
SplittablePayload payload = null;
FieldNameValidator payloadFieldNameValidator = null;
if (sequences instanceof ValidatableSplittablePayload) {
ValidatableSplittablePayload validatableSplittablePayload = (ValidatableSplittablePayload) sequences;
payload = validatableSplittablePayload.getSplittablePayload();
payloadFieldNameValidator = validatableSplittablePayload.getFieldNameValidator();
if (sequences instanceof SplittablePayload) {
payload = (SplittablePayload) sequences;
payloadFieldNameValidator = payload.getFieldNameValidator();
}
BasicOutputBuffer bsonOutput = new BasicOutputBuffer();
BsonBinaryWriter bsonBinaryWriter = new BsonBinaryWriter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import com.mongodb.internal.connection.OperationContext;
import com.mongodb.internal.connection.SplittablePayload;
import com.mongodb.internal.connection.SplittablePayloadBsonWriter;
import com.mongodb.internal.connection.ValidatableSplittablePayload;
import com.mongodb.internal.time.Timeout;
import com.mongodb.internal.validator.MappedFieldNameValidator;
import com.mongodb.lang.Nullable;
Expand Down Expand Up @@ -101,10 +100,9 @@ public <T> T command(final String database, final BsonDocument command, final Fi

SplittablePayload payload = null;
FieldNameValidator payloadFieldNameValidator = null;
if (sequences instanceof ValidatableSplittablePayload) {
ValidatableSplittablePayload validatableSplittablePayload = (ValidatableSplittablePayload) sequences;
payload = validatableSplittablePayload.getSplittablePayload();
payloadFieldNameValidator = validatableSplittablePayload.getFieldNameValidator();
if (sequences instanceof SplittablePayload) {
payload = (SplittablePayload) sequences;
payloadFieldNameValidator = payload.getFieldNameValidator();
}
BasicOutputBuffer bsonOutput = new BasicOutputBuffer();
BsonBinaryWriter bsonBinaryWriter = new BsonBinaryWriter(new BsonWriterSettings(),
Expand Down
Loading

0 comments on commit a2805a6

Please sign in to comment.