diff --git a/packages/client/lib/tests/test-scenario/connection-handoff.e2e.ts b/packages/client/lib/tests/test-scenario/connection-handoff.e2e.ts index 27e7975691..3c71dff7ec 100644 --- a/packages/client/lib/tests/test-scenario/connection-handoff.e2e.ts +++ b/packages/client/lib/tests/test-scenario/connection-handoff.e2e.ts @@ -6,17 +6,56 @@ import { getEnvConfig, RedisConnectionConfig, } from "./test-scenario.util"; -import { createClient } from "../../.."; +import { createClient, RedisClientOptions } from "../../.."; import { before } from "mocha"; -import { spy } from "sinon"; +import Sinon, { SinonSpy, spy, stub } from "sinon"; import assert from "node:assert"; -import net from "node:net"; + +/** + * Creates a spy on a duplicated client method + * @param client - The Redis client instance + * @param funcName - The name of the method to spy on + * @returns Object containing the promise that resolves with the spy and restore function + */ +const spyOnTemporaryClientInstanceMethod = ( + client: ReturnType>, + methodName: string +) => { + const { promise, resolve } = ( + Promise as typeof Promise & { + withResolvers: () => { + promise: Promise<{ spy: SinonSpy; restore: () => void }>; + resolve: (value: any) => void; + }; + } + ).withResolvers(); + + const originalDuplicate = client.duplicate.bind(client); + + const duplicateStub: Sinon.SinonStub = stub( + // Temporary clients (in the context of hitless upgrade) + // are created by calling the duplicate method on the client. + Object.getPrototypeOf(client), + "duplicate" + ).callsFake((opts) => { + const tmpClient = originalDuplicate(opts); + resolve({ + spy: spy(tmpClient, methodName), + restore: duplicateStub.restore, + }); + + return tmpClient; + }); + + return { + getSpy: () => promise, + }; +}; describe("Connection Handoff", () => { let clientConfig: RedisConnectionConfig; let client: ReturnType>; let faultInjectorClient: FaultInjectorClient; - let connectSpy = spy(net, "createConnection"); before(() => { const envConfig = getEnvConfig(); @@ -28,62 +67,110 @@ describe("Connection Handoff", () => { clientConfig = getDatabaseConfig(redisConfig); }); - beforeEach(async () => { - connectSpy.resetHistory(); - - client = await createTestClient(clientConfig); - - await client.flushAll(); - }); - - afterEach(() => { + afterEach(async () => { if (client && client.isOpen) { + await client.flushAll(); client.destroy(); } }); - describe("New Connection Establishment", () => { - it("should establish new connection", async () => { - assert.equal(connectSpy.callCount, 1); - - const { action_id: lowTimeoutBindAndMigrateActionId } = - await faultInjectorClient.migrateAndBindAction({ - bdbId: clientConfig.bdbId, - clusterIndex: 0, - }); - - await faultInjectorClient.waitForAction(lowTimeoutBindAndMigrateActionId); - - assert.equal(connectSpy.callCount, 2); - }); + describe("New Connection Establishment & Traffic Resumption", () => { + const cases: Array<{ + name: string; + clientOptions: Partial; + }> = [ + { + name: "default options", + clientOptions: {}, + }, + { + name: "external-ip", + clientOptions: { + maintMovingEndpointType: "external-ip", + }, + }, + { + name: "external-fqdn", + clientOptions: { + maintMovingEndpointType: "external-fqdn", + }, + }, + { + name: "auto", + clientOptions: { + maintMovingEndpointType: "auto", + }, + }, + { + name: "none", + clientOptions: { + maintMovingEndpointType: "none", + }, + }, + ]; + + for (const { name, clientOptions } of cases) { + it.only(`should establish new connection and resume traffic afterwards - ${name}`, async () => { + client = await createTestClient(clientConfig, clientOptions); + + const spyObject = spyOnTemporaryClientInstanceMethod(client, "connect"); + + // PART 1 Establish initial connection + const { action_id: lowTimeoutBindAndMigrateActionId } = + await faultInjectorClient.migrateAndBindAction({ + bdbId: clientConfig.bdbId, + clusterIndex: 0, + }); + + await faultInjectorClient.waitForAction( + lowTimeoutBindAndMigrateActionId + ); + + const spyResult = await spyObject.getSpy(); + + assert.strictEqual(spyResult.spy.callCount, 1); + + // PART 2 Verify traffic resumption + const currentTime = Date.now().toString(); + await client.set("key", currentTime); + const result = await client.get("key"); + + assert.strictEqual(result, currentTime); + + spyResult.restore(); + }); + } }); describe("TLS Connection Handoff", () => { - it("TODO receiveMessagesWithTLSEnabledTest", async () => { + it.skip("TODO receiveMessagesWithTLSEnabledTest", async () => { // }); - it("TODO connectionHandoffWithStaticInternalNameTest", async () => { + it.skip("TODO connectionHandoffWithStaticInternalNameTest", async () => { // }); - it("TODO connectionHandoffWithStaticExternalNameTest", async () => { + it.skip("TODO connectionHandoffWithStaticExternalNameTest", async () => { // }); }); - describe("Traffic Resumption", () => { - it("Traffic resumed after handoff", async () => { - const { action_id } = await faultInjectorClient.migrateAndBindAction({ - bdbId: clientConfig.bdbId, - clusterIndex: 0, - }); + describe("Connection Cleanup", () => { + it("should shut down old connection", async () => { + const spyObject = spyOnTemporaryClientInstanceMethod(client, "destroy"); + + const { action_id: lowTimeoutBindAndMigrateActionId } = + await faultInjectorClient.migrateAndBindAction({ + bdbId: clientConfig.bdbId, + clusterIndex: 0, + }); + + await faultInjectorClient.waitForAction(lowTimeoutBindAndMigrateActionId); - await faultInjectorClient.waitForAction(action_id); + const spyResult = await spyObject.getSpy(); - const currentTime = Date.now().toString(); - await client.set("key", currentTime); - const result = await client.get("key"); + assert.equal(spyResult.spy.callCount, 1); - assert.strictEqual(result, currentTime); + spyResult.restore(); }); }); });