diff --git a/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java b/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java index 3a88f7b102..68717e6ff4 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java +++ b/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java @@ -206,9 +206,8 @@ boolean isResponseExpected() { if (responseExpected) { return true; } else { - if (sequences instanceof ValidatableSplittablePayload) { - ValidatableSplittablePayload validatableSplittablePayload = (ValidatableSplittablePayload) sequences; - SplittablePayload payload = validatableSplittablePayload.getSplittablePayload(); + if (sequences instanceof SplittablePayload) { + SplittablePayload payload = (SplittablePayload) sequences; return payload.isOrdered() && payload.hasAnotherSplit(); } else if (sequences instanceof DualSplittablePayloads) { return assertNotNull(dualSplittablePayloadsRequireResponse); @@ -235,13 +234,12 @@ protected EncodingMetadata encodeMessageBodyWithMetadata(final ByteBufferBsonOut ArrayList extraElements = getExtraElements(operationContext); int commandDocumentSizeInBytes = writeDocument(command, bsonOutput, commandFieldNameValidator, true); - if (sequences instanceof ValidatableSplittablePayload) { + if (sequences instanceof SplittablePayload) { appendElementsToDocument(bsonOutput, commandStartPosition, extraElements); - ValidatableSplittablePayload validatableSplittablePayload = (ValidatableSplittablePayload) sequences; - SplittablePayload payload = validatableSplittablePayload.getSplittablePayload(); + SplittablePayload payload = (SplittablePayload) sequences; writeOpMsgSectionWithPayloadType1(bsonOutput, payload.getPayloadName(), () -> { writePayload( - new BsonBinaryWriter(bsonOutput, validatableSplittablePayload.getFieldNameValidator()), + new BsonBinaryWriter(bsonOutput, payload.getFieldNameValidator()), bsonOutput, getSettings(), messageStartPosition, payload, getSettings().getMaxDocumentSize() ); return null; diff --git a/driver-core/src/main/com/mongodb/internal/connection/SplittablePayload.java b/driver-core/src/main/com/mongodb/internal/connection/SplittablePayload.java index d628a39238..00c355cc66 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/SplittablePayload.java +++ b/driver-core/src/main/com/mongodb/internal/connection/SplittablePayload.java @@ -26,6 +26,7 @@ import org.bson.BsonObjectId; import org.bson.BsonValue; import org.bson.BsonWriter; +import org.bson.FieldNameValidator; import org.bson.codecs.BsonValueCodecProvider; import org.bson.codecs.Codec; import org.bson.codecs.Encoder; @@ -54,8 +55,9 @@ * *

This class is not part of the public API and may be removed or changed at any time

*/ -public final class SplittablePayload { +public final class SplittablePayload extends OpMsgSequences { private static final CodecRegistry REGISTRY = fromProviders(new BsonValueCodecProvider()); + private final FieldNameValidator fieldNameValidator; private final WriteRequestEncoder writeRequestEncoder = new WriteRequestEncoder(); private final Type payloadType; private final List writeRequestWithIndexes; @@ -94,10 +96,16 @@ public enum Type { * @param payloadType the payload type * @param writeRequestWithIndexes the writeRequests */ - public SplittablePayload(final Type payloadType, final List writeRequestWithIndexes, final boolean ordered) { + public SplittablePayload(final Type payloadType, final List writeRequestWithIndexes, final boolean ordered, + final FieldNameValidator fieldNameValidator) { this.payloadType = notNull("batchType", payloadType); this.writeRequestWithIndexes = notNull("writeRequests", writeRequestWithIndexes); this.ordered = ordered; + this.fieldNameValidator = fieldNameValidator; + } + + public FieldNameValidator getFieldNameValidator() { + return fieldNameValidator; } /** @@ -175,7 +183,7 @@ boolean isOrdered() { public SplittablePayload getNextSplit() { isTrue("hasAnotherSplit", hasAnotherSplit()); List nextPayLoad = writeRequestWithIndexes.subList(position, writeRequestWithIndexes.size()); - return new SplittablePayload(payloadType, nextPayLoad, ordered); + return new SplittablePayload(payloadType, nextPayLoad, ordered, fieldNameValidator); } /** diff --git a/driver-core/src/main/com/mongodb/internal/operation/BulkWriteBatch.java b/driver-core/src/main/com/mongodb/internal/operation/BulkWriteBatch.java index 8da0f13e31..2b237ac2ee 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/BulkWriteBatch.java +++ b/driver-core/src/main/com/mongodb/internal/operation/BulkWriteBatch.java @@ -63,8 +63,8 @@ import static com.mongodb.internal.bulk.WriteRequest.Type.INSERT; import static com.mongodb.internal.bulk.WriteRequest.Type.REPLACE; import static com.mongodb.internal.bulk.WriteRequest.Type.UPDATE; -import static com.mongodb.internal.operation.DocumentHelper.putIfNotNull; import static com.mongodb.internal.operation.CommandOperationHelper.commandWriteConcern; +import static com.mongodb.internal.operation.DocumentHelper.putIfNotNull; import static com.mongodb.internal.operation.OperationHelper.LOGGER; import static com.mongodb.internal.operation.OperationHelper.isRetryableWrite; import static com.mongodb.internal.operation.WriteConcernHelper.createWriteConcernError; @@ -154,7 +154,7 @@ private BulkWriteBatch(final MongoNamespace namespace, final ConnectionDescripti this.indexMap = indexMap; this.unprocessed = unprocessedItems; - this.payload = new SplittablePayload(getPayloadType(batchType), payloadItems, ordered); + this.payload = new SplittablePayload(getPayloadType(batchType), payloadItems, ordered, getFieldNameValidator()); this.operationContext = operationContext; this.comment = comment; this.variables = variables; diff --git a/driver-core/src/main/com/mongodb/internal/operation/MixedBulkWriteOperation.java b/driver-core/src/main/com/mongodb/internal/operation/MixedBulkWriteOperation.java index bed397243a..06d392bceb 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/MixedBulkWriteOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/MixedBulkWriteOperation.java @@ -39,7 +39,6 @@ import com.mongodb.internal.connection.MongoWriteConcernWithResponseException; import com.mongodb.internal.connection.OperationContext; import com.mongodb.internal.connection.ProtocolHelper; -import com.mongodb.internal.connection.ValidatableSplittablePayload; import com.mongodb.internal.operation.retry.AttachmentKeys; import com.mongodb.internal.session.SessionContext; import com.mongodb.internal.validator.NoOpFieldNameValidator; @@ -422,8 +421,7 @@ private BsonDocument executeCommand( final Connection connection, final BulkWriteBatch batch) { return connection.command(namespace.getDatabaseName(), batch.getCommand(), NoOpFieldNameValidator.INSTANCE, null, batch.getDecoder(), - operationContext, shouldExpectResponse(batch, effectiveWriteConcern), - new ValidatableSplittablePayload(batch.getPayload(), batch.getFieldNameValidator())); + operationContext, shouldExpectResponse(batch, effectiveWriteConcern), batch.getPayload()); } private void executeCommandAsync( @@ -433,8 +431,7 @@ private void executeCommandAsync( final BulkWriteBatch batch, final SingleResultCallback callback) { connection.commandAsync(namespace.getDatabaseName(), batch.getCommand(), NoOpFieldNameValidator.INSTANCE, null, batch.getDecoder(), - operationContext, shouldExpectResponse(batch, effectiveWriteConcern), - new ValidatableSplittablePayload(batch.getPayload(), batch.getFieldNameValidator()), callback); + operationContext, shouldExpectResponse(batch, effectiveWriteConcern), batch.getPayload(), callback); } private boolean shouldExpectResponse(final BulkWriteBatch batch, final WriteConcern effectiveWriteConcern) { diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy index 9ec7d35ea8..27b88e7283 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy @@ -154,7 +154,7 @@ class CommandMessageSpecification extends Specification { MessageSettings.builder().maxWireVersion(maxWireVersion).build(), true, payload == null ? OpMsgSequences.EmptyOpMsgSequences.INSTANCE - : new ValidatableSplittablePayload(payload, NoOpFieldNameValidator.INSTANCE), + : payload, ClusterConnectionMode.MULTIPLE, null) def output = new ByteBufferBsonOutput(new SimpleBufferProvider()) message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, @@ -177,7 +177,8 @@ class CommandMessageSpecification extends Specification { new BsonDocument('insert', new BsonString('coll')), new SplittablePayload(INSERT, [new BsonDocument('_id', new BsonInt32(1)), new BsonDocument('_id', new BsonInt32(2))] - .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true), + .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, + NoOpFieldNameValidator.INSTANCE), ], [ LATEST_WIRE_VERSION, @@ -198,9 +199,9 @@ class CommandMessageSpecification extends Specification { new BsonDocument('_id', new BsonInt32(3)).append('c', new BsonBinary(new byte[450])), new BsonDocument('_id', new BsonInt32(4)).append('b', new BsonBinary(new byte[441])), new BsonDocument('_id', new BsonInt32(5)).append('c', new BsonBinary(new byte[451]))] - .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true) + .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, fieldNameValidator) def message = new CommandMessage(namespace, insertCommand, fieldNameValidator, ReadPreference.primary(), messageSettings, - false, new ValidatableSplittablePayload(payload, fieldNameValidator), ClusterConnectionMode.MULTIPLE, null) + false, payload, ClusterConnectionMode.MULTIPLE, null) def output = new ByteBufferBsonOutput(new SimpleBufferProvider()) def sessionContext = Stub(SessionContext) { getReadConcern() >> ReadConcern.DEFAULT @@ -224,7 +225,7 @@ class CommandMessageSpecification extends Specification { when: payload = payload.getNextSplit() message = new CommandMessage(namespace, insertCommand, fieldNameValidator, ReadPreference.primary(), messageSettings, - false, new ValidatableSplittablePayload(payload, fieldNameValidator), ClusterConnectionMode.MULTIPLE, null) + false, payload, ClusterConnectionMode.MULTIPLE, null) output.truncateToPosition(0) message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext, Stub(TimeoutContext), null)) byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray())) @@ -242,7 +243,7 @@ class CommandMessageSpecification extends Specification { when: payload = payload.getNextSplit() message = new CommandMessage(namespace, insertCommand, fieldNameValidator, ReadPreference.primary(), messageSettings, - false, new ValidatableSplittablePayload(payload, fieldNameValidator), ClusterConnectionMode.MULTIPLE, null) + false, payload, ClusterConnectionMode.MULTIPLE, null) output.truncateToPosition(0) message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext, Stub(TimeoutContext), null)) byteBuf = new ByteBufNIO(ByteBuffer.wrap(output.toByteArray())) @@ -260,7 +261,7 @@ class CommandMessageSpecification extends Specification { when: payload = payload.getNextSplit() message = new CommandMessage(namespace, insertCommand, fieldNameValidator, ReadPreference.primary(), messageSettings, - false, new ValidatableSplittablePayload(payload, fieldNameValidator), ClusterConnectionMode.MULTIPLE, null) + false, payload, ClusterConnectionMode.MULTIPLE, null) output.truncateToPosition(0) message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext, @@ -288,9 +289,9 @@ class CommandMessageSpecification extends Specification { def payload = new SplittablePayload(INSERT, [new BsonDocument('a', new BsonBinary(new byte[900])), new BsonDocument('b', new BsonBinary(new byte[450])), new BsonDocument('c', new BsonBinary(new byte[450]))] - .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true) + .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, fieldNameValidator) def message = new CommandMessage(namespace, command, fieldNameValidator, ReadPreference.primary(), messageSettings, - false, new ValidatableSplittablePayload(payload, fieldNameValidator), ClusterConnectionMode.MULTIPLE, null) + false, payload, ClusterConnectionMode.MULTIPLE, null) def output = new ByteBufferBsonOutput(new SimpleBufferProvider()) def sessionContext = Stub(SessionContext) { getReadConcern() >> ReadConcern.DEFAULT @@ -315,7 +316,7 @@ class CommandMessageSpecification extends Specification { when: payload = payload.getNextSplit() message = new CommandMessage(namespace, command, fieldNameValidator, ReadPreference.primary(), messageSettings, - false, new ValidatableSplittablePayload(payload, fieldNameValidator), ClusterConnectionMode.MULTIPLE, null) + false, payload, ClusterConnectionMode.MULTIPLE, null) output.truncateToPosition(0) message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext, Stub(TimeoutContext), null)) @@ -339,9 +340,9 @@ class CommandMessageSpecification extends Specification { def messageSettings = MessageSettings.builder().maxDocumentSize(900) .maxWireVersion(LATEST_WIRE_VERSION).build() def payload = new SplittablePayload(INSERT, [new BsonDocument('a', new BsonBinary(new byte[900]))] - .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true) + .withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, fieldNameValidator) def message = new CommandMessage(namespace, command, fieldNameValidator, ReadPreference.primary(), messageSettings, - false, new ValidatableSplittablePayload(payload, fieldNameValidator), ClusterConnectionMode.MULTIPLE, null) + false, payload, ClusterConnectionMode.MULTIPLE, null) def output = new ByteBufferBsonOutput(new SimpleBufferProvider()) def sessionContext = Stub(SessionContext) { getReadConcern() >> ReadConcern.DEFAULT @@ -362,9 +363,9 @@ class CommandMessageSpecification extends Specification { given: def messageSettings = MessageSettings.builder().serverType(ServerType.SHARD_ROUTER) .maxWireVersion(FOUR_DOT_ZERO_WIRE_VERSION).build() - def payload = new SplittablePayload(INSERT, [new BsonDocument('a', new BsonInt32(1))], true) + def payload = new SplittablePayload(INSERT, [new BsonDocument('a', new BsonInt32(1))], true, fieldNameValidator) def message = new CommandMessage(namespace, command, fieldNameValidator, ReadPreference.primary(), messageSettings, - false, new ValidatableSplittablePayload(payload, fieldNameValidator), ClusterConnectionMode.MULTIPLE, null) + false, payload, ClusterConnectionMode.MULTIPLE, null) def output = new ByteBufferBsonOutput(new SimpleBufferProvider()) def sessionContext = Stub(SessionContext) { getReadConcern() >> ReadConcern.DEFAULT diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/CryptConnection.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/CryptConnection.java index 4d7d41c97e..576cdae227 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/CryptConnection.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/CryptConnection.java @@ -28,7 +28,6 @@ import com.mongodb.internal.connection.OperationContext; import com.mongodb.internal.connection.SplittablePayload; import com.mongodb.internal.connection.SplittablePayloadBsonWriter; -import com.mongodb.internal.connection.ValidatableSplittablePayload; import com.mongodb.internal.time.Timeout; import com.mongodb.internal.validator.MappedFieldNameValidator; import com.mongodb.lang.Nullable; @@ -113,10 +112,9 @@ public void commandAsync(final String database, final BsonDocument command, try { SplittablePayload payload = null; FieldNameValidator payloadFieldNameValidator = null; - if (sequences instanceof ValidatableSplittablePayload) { - ValidatableSplittablePayload validatableSplittablePayload = (ValidatableSplittablePayload) sequences; - payload = validatableSplittablePayload.getSplittablePayload(); - payloadFieldNameValidator = validatableSplittablePayload.getFieldNameValidator(); + if (sequences instanceof SplittablePayload) { + payload = (SplittablePayload) sequences; + payloadFieldNameValidator = payload.getFieldNameValidator(); } BasicOutputBuffer bsonOutput = new BasicOutputBuffer(); BsonBinaryWriter bsonBinaryWriter = new BsonBinaryWriter( diff --git a/driver-sync/src/main/com/mongodb/client/internal/CryptConnection.java b/driver-sync/src/main/com/mongodb/client/internal/CryptConnection.java index 5ff45ae723..cdb42dde8c 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/CryptConnection.java +++ b/driver-sync/src/main/com/mongodb/client/internal/CryptConnection.java @@ -26,7 +26,6 @@ import com.mongodb.internal.connection.OperationContext; import com.mongodb.internal.connection.SplittablePayload; import com.mongodb.internal.connection.SplittablePayloadBsonWriter; -import com.mongodb.internal.connection.ValidatableSplittablePayload; import com.mongodb.internal.time.Timeout; import com.mongodb.internal.validator.MappedFieldNameValidator; import com.mongodb.lang.Nullable; @@ -101,10 +100,9 @@ public T command(final String database, final BsonDocument command, final Fi SplittablePayload payload = null; FieldNameValidator payloadFieldNameValidator = null; - if (sequences instanceof ValidatableSplittablePayload) { - ValidatableSplittablePayload validatableSplittablePayload = (ValidatableSplittablePayload) sequences; - payload = validatableSplittablePayload.getSplittablePayload(); - payloadFieldNameValidator = validatableSplittablePayload.getFieldNameValidator(); + if (sequences instanceof SplittablePayload) { + payload = (SplittablePayload) sequences; + payloadFieldNameValidator = payload.getFieldNameValidator(); } BasicOutputBuffer bsonOutput = new BasicOutputBuffer(); BsonBinaryWriter bsonBinaryWriter = new BsonBinaryWriter(new BsonWriterSettings(), diff --git a/driver-sync/src/test/unit/com/mongodb/client/internal/CryptConnectionSpecification.groovy b/driver-sync/src/test/unit/com/mongodb/client/internal/CryptConnectionSpecification.groovy index 5e81d36a96..b4f33aafea 100644 --- a/driver-sync/src/test/unit/com/mongodb/client/internal/CryptConnectionSpecification.groovy +++ b/driver-sync/src/test/unit/com/mongodb/client/internal/CryptConnectionSpecification.groovy @@ -29,7 +29,6 @@ import com.mongodb.internal.bulk.WriteRequestWithIndex import com.mongodb.internal.connection.Connection import com.mongodb.internal.connection.OpMsgSequences import com.mongodb.internal.connection.SplittablePayload -import com.mongodb.internal.connection.ValidatableSplittablePayload import com.mongodb.internal.time.Timeout import com.mongodb.internal.validator.NoOpFieldNameValidator import org.bson.BsonArray @@ -117,7 +116,7 @@ class CryptConnectionSpecification extends Specification { def payload = new SplittablePayload(INSERT, [ new BsonDocumentWrapper(new Document('_id', 1).append('ssid', '555-55-5555').append('b', bytes), codec), new BsonDocumentWrapper(new Document('_id', 2).append('ssid', '666-66-6666').append('b', bytes), codec) - ].withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true) + ].withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, NoOpFieldNameValidator.INSTANCE) def encryptedCommand = toRaw(new BsonDocument('insert', new BsonString('test')).append('documents', new BsonArray( [ new BsonDocument('_id', new BsonInt32(1)) @@ -136,7 +135,7 @@ class CryptConnectionSpecification extends Specification { def response = cryptConnection.command('db', new BsonDocumentWrapper(new Document('insert', 'test'), codec), NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), new BsonDocumentCodec(), - operationContext, true, new ValidatableSplittablePayload(payload, NoOpFieldNameValidator.INSTANCE)) + operationContext, true, payload) then: _ * wrappedConnection.getDescription() >> { @@ -173,8 +172,8 @@ class CryptConnectionSpecification extends Specification { def payload = new SplittablePayload(INSERT, [ new BsonDocumentWrapper(new Document('_id', 1), codec), new BsonDocumentWrapper(new Document('_id', 2), codec), - new BsonDocumentWrapper(new Document('_id', 3), codec) - ].withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true) + new BsonDocumentWrapper(new Document('_id', 3), codec,) + ].withIndex().collect { doc, i -> new WriteRequestWithIndex(new InsertRequest(doc), i) }, true, NoOpFieldNameValidator.INSTANCE) def encryptedCommand = toRaw(new BsonDocument('insert', new BsonString('test')).append('documents', new BsonArray( [ new BsonDocument('_id', new BsonInt32(1)), @@ -193,7 +192,7 @@ class CryptConnectionSpecification extends Specification { def response = cryptConnection.command('db', new BsonDocumentWrapper(new Document('insert', 'test'), codec), NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), new BsonDocumentCodec(), operationContext, true, - new ValidatableSplittablePayload(payload, NoOpFieldNameValidator.INSTANCE)) + payload) then: _ * wrappedConnection.getDescription() >> {