diff --git a/sdk/servicebus/service-bus/src/sender.ts b/sdk/servicebus/service-bus/src/sender.ts index 2bf64f5aa4cf..9854493ff553 100644 --- a/sdk/servicebus/service-bus/src/sender.ts +++ b/sdk/servicebus/service-bus/src/sender.ts @@ -3,11 +3,12 @@ import Long from "long"; import { MessageSender } from "./core/messageSender"; -import { ServiceBusMessage, isServiceBusMessage } from "./serviceBusMessage"; +import { ServiceBusMessage } from "./serviceBusMessage"; import { ConnectionContext } from "./connectionContext"; import { getSenderClosedErrorMsg, throwErrorIfConnectionClosed, + throwIfNotValidServiceBusMessage, throwTypeErrorIfParameterMissing, throwTypeErrorIfParameterNotLong } from "./util/errors"; @@ -190,16 +191,18 @@ export class ServiceBusSenderImpl implements ServiceBusSender { // link message span contexts let spanContextsToLink: SpanContext[] = []; - if (isServiceBusMessage(messages)) { - messages = [messages]; - } + let batch: ServiceBusMessageBatch; - if (Array.isArray(messages)) { + if (isServiceBusMessageBatch(messages)) { + spanContextsToLink = messages._messageSpanContexts; + batch = messages; + } else { + if (!Array.isArray(messages)) { + messages = [messages]; + } batch = await this.createMessageBatch(options); for (const message of messages) { - if (!isServiceBusMessage(message)) { - throw new TypeError(invalidTypeErrMsg); - } + throwIfNotValidServiceBusMessage(message, invalidTypeErrMsg); if (!batch.tryAddMessage(message, { parentSpan: getParentSpan(options?.tracingOptions) })) { // this is too big - throw an error const error = new MessagingError( @@ -209,11 +212,6 @@ export class ServiceBusSenderImpl implements ServiceBusSender { throw error; } } - } else if (isServiceBusMessageBatch(messages)) { - spanContextsToLink = messages._messageSpanContexts; - batch = messages; - } else { - throw new TypeError(invalidTypeErrMsg); } const sendSpan = createSendSpan( @@ -258,11 +256,10 @@ export class ServiceBusSenderImpl implements ServiceBusSender { const messagesToSchedule = Array.isArray(messages) ? messages : [messages]; for (const message of messagesToSchedule) { - if (!isServiceBusMessage(message)) { - throw new TypeError( - "Provided value for 'messages' must be of type ServiceBusMessage or an array of type ServiceBusMessage." - ); - } + throwIfNotValidServiceBusMessage( + message, + "Provided value for 'messages' must be of type ServiceBusMessage or an array of type ServiceBusMessage." + ); } const scheduleMessageOperationPromise = async () => { diff --git a/sdk/servicebus/service-bus/src/serviceBusMessageBatch.ts b/sdk/servicebus/service-bus/src/serviceBusMessageBatch.ts index e41bb888f3cb..2cc528ea3d37 100644 --- a/sdk/servicebus/service-bus/src/serviceBusMessageBatch.ts +++ b/sdk/servicebus/service-bus/src/serviceBusMessageBatch.ts @@ -4,10 +4,9 @@ import { ServiceBusMessage, toRheaMessage, - isServiceBusMessage, getMessagePropertyTypeMismatchError } from "./serviceBusMessage"; -import { throwTypeErrorIfParameterMissing } from "./util/errors"; +import { throwIfNotValidServiceBusMessage, throwTypeErrorIfParameterMissing } from "./util/errors"; import { ConnectionContext } from "./connectionContext"; import { MessageAnnotations, @@ -246,9 +245,10 @@ export class ServiceBusMessageBatchImpl implements ServiceBusMessageBatch { */ public tryAddMessage(message: ServiceBusMessage, options: TryAddOptions = {}): boolean { throwTypeErrorIfParameterMissing(this._context.connectionId, "message", message); - if (!isServiceBusMessage(message)) { - throw new TypeError("Provided value for 'message' must be of type ServiceBusMessage."); - } + throwIfNotValidServiceBusMessage( + message, + "Provided value for 'message' must be of type ServiceBusMessage." + ); // check if the event has already been instrumented const previouslyInstrumented = Boolean( diff --git a/sdk/servicebus/service-bus/src/util/errors.ts b/sdk/servicebus/service-bus/src/util/errors.ts index f7bb29c336b4..955bf0dacd6d 100644 --- a/sdk/servicebus/service-bus/src/util/errors.ts +++ b/sdk/servicebus/service-bus/src/util/errors.ts @@ -4,7 +4,7 @@ import { logger, receiverLogger } from "../log"; import Long from "long"; import { ConnectionContext } from "../connectionContext"; -import { ServiceBusReceivedMessage } from "../serviceBusMessage"; +import { isServiceBusMessage, ServiceBusReceivedMessage } from "../serviceBusMessage"; import { ReceiveMode } from "../models"; /** @@ -249,3 +249,27 @@ export function throwErrorIfInvalidOperationOnMessage( throw error; } } + +/** + * Error message for when the ServiceBusMessage provided by the user has different values + * for partitionKey and sessionId. + * @internal + * @throw + */ +export const PartitionKeySessionIdMismatchError = + "The fields 'partitionKey' and 'sessionId' cannot have different values."; +/** + * Throws error if the given object is not a valid ServiceBusMessage + * @internal + * @ignore + * @param msg The object that needs to be validated as a ServiceBusMessage + * @param errorMessageForWrongType The error message to use when given object is not a ServiceBusMessage + */ +export function throwIfNotValidServiceBusMessage(msg: any, errorMessageForWrongType: string): void { + if (!isServiceBusMessage(msg)) { + throw new TypeError(errorMessageForWrongType); + } + if (msg.partitionKey && msg.sessionId && msg.partitionKey !== msg.sessionId) { + throw new TypeError(PartitionKeySessionIdMismatchError); + } +} diff --git a/sdk/servicebus/service-bus/test/internal/sender.spec.ts b/sdk/servicebus/service-bus/test/internal/sender.spec.ts index bb13459df1ab..22d229cadbbf 100644 --- a/sdk/servicebus/service-bus/test/internal/sender.spec.ts +++ b/sdk/servicebus/service-bus/test/internal/sender.spec.ts @@ -7,6 +7,7 @@ import { ConnectionContext } from "../../src/connectionContext"; import { ServiceBusMessage } from "../../src"; import { isServiceBusMessageBatch, ServiceBusSenderImpl } from "../../src/sender"; import { createConnectionContextForTests } from "./unittestUtils"; +import { PartitionKeySessionIdMismatchError } from "../../src/util/errors"; const assert = chai.assert; @@ -29,18 +30,38 @@ describe("sender unit tests", () => { return new ServiceBusMessageBatchImpl(fakeContext, 100); }; - ["hello", {}, 123, null, undefined, ["hello"]].forEach((invalidValue) => { + const partitionKeySessionIdMismatchMsg = { + body: "boooo", + sessionId: "my-sessionId", + partitionKey: "my-partitionKey" + }; + const badMessages = [ + "hello", + {}, + 123, + null, + undefined, + ["hello"], + partitionKeySessionIdMismatchMsg + ]; + + badMessages.forEach((invalidValue) => { it(`don't allow Sender.sendMessages(${invalidValue})`, async () => { let expectedErrorMsg = "Provided value for 'messages' must be of type ServiceBusMessage, ServiceBusMessageBatch or an array of type ServiceBusMessage."; if (invalidValue === null || invalidValue === undefined) { expectedErrorMsg = `Missing parameter "messages"`; } + if (invalidValue === partitionKeySessionIdMismatchMsg) { + expectedErrorMsg = PartitionKeySessionIdMismatchError; + } + try { await sender.sendMessages( // @ts-expect-error invalidValue ); + assert.fail("You should not be seeing this."); } catch (err) { assert.equal(err.name, "TypeError"); assert.equal(err.message, expectedErrorMsg); @@ -48,18 +69,23 @@ describe("sender unit tests", () => { }); }); - ["hello", {}, null, undefined].forEach((invalidValue) => { + badMessages.forEach((invalidValue) => { it(`don't allow tryAdd(${invalidValue})`, async () => { const batch = await sender.createMessageBatch(); let expectedErrorMsg = "Provided value for 'message' must be of type ServiceBusMessage."; if (invalidValue === null || invalidValue === undefined) { expectedErrorMsg = `Missing parameter "message"`; } + if (invalidValue === partitionKeySessionIdMismatchMsg) { + expectedErrorMsg = PartitionKeySessionIdMismatchError; + } + try { batch.tryAddMessage( // @ts-expect-error invalidValue ); + assert.fail("You should not be seeing this."); } catch (err) { assert.equal(err.name, "TypeError"); assert.equal(err.message, expectedErrorMsg); @@ -67,13 +93,16 @@ describe("sender unit tests", () => { }); }); - ["hello", {}, null, undefined, ["hello"]].forEach((invalidValue) => { + badMessages.forEach((invalidValue) => { it(`don't allow Sender.scheduleMessages(${invalidValue})`, async () => { let expectedErrorMsg = "Provided value for 'messages' must be of type ServiceBusMessage or an array of type ServiceBusMessage."; if (invalidValue === null || invalidValue === undefined) { expectedErrorMsg = `Missing parameter "messages"`; } + if (invalidValue === partitionKeySessionIdMismatchMsg) { + expectedErrorMsg = PartitionKeySessionIdMismatchError; + } try { await sender.scheduleMessages( @@ -81,6 +110,7 @@ describe("sender unit tests", () => { invalidValue, new Date() ); + assert.fail("You should not be seeing this."); } catch (err) { assert.equal(err.name, "TypeError"); assert.equal(err.message, expectedErrorMsg);