diff --git a/src/client.ts b/src/client.ts index 897f6f4f..936d9f6f 100644 --- a/src/client.ts +++ b/src/client.ts @@ -611,7 +611,12 @@ export class Client { if (!chosenNode) { throw new Error(`Stream was not found on any node`) } - const cachedConnection = ConnectionPool.getUsableCachedConnection(purpose, streamName, chosenNode.host) + const cachedConnection = ConnectionPool.getUsableCachedConnection( + purpose, + streamName, + this.connection.vhost, + chosenNode.host + ) if (cachedConnection) return cachedConnection const newConnection = await this.getConnectionOnChosenNode( @@ -622,7 +627,7 @@ export class Client { connectionClosedListener ) - ConnectionPool.cacheConnection(purpose, streamName, newConnection.hostname, newConnection) + ConnectionPool.cacheConnection(purpose, streamName, this.connection.vhost, newConnection.hostname, newConnection) return newConnection } diff --git a/src/connection.ts b/src/connection.ts index ef1457ef..890b4c11 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -62,6 +62,7 @@ export type ConnectionInfo = { port: number id: string ready: boolean + vhost: string readable?: boolean writable?: boolean localPort?: number @@ -78,6 +79,7 @@ type ListenerEntry = { export class Connection { public readonly hostname: string + public readonly vhost: string public readonly leader: boolean public readonly streamName: string | undefined private socket: Socket @@ -109,6 +111,7 @@ export class Connection { private readonly logger: Logger ) { this.hostname = params.hostname + this.vhost = params.vhost this.leader = params.leader ?? false this.streamName = params.streamName if (params.frameMax) this.frameMax = params.frameMax @@ -374,6 +377,7 @@ export class Connection { writable: this.socket.writable, localPort: this.socket.localPort, ready: this.ready, + vhost: this.vhost, } } @@ -484,7 +488,7 @@ export class Connection { } private virtualHostIsNotValid(virtualHost: string) { - if (!virtualHost || virtualHost.split("/").length !== 2) { + if (!virtualHost) { return true } diff --git a/src/connection_pool.ts b/src/connection_pool.ts index ff269362..aa0125d7 100644 --- a/src/connection_pool.ts +++ b/src/connection_pool.ts @@ -8,20 +8,26 @@ export class ConnectionPool { private static consumerConnectionProxies = new Map() private static publisherConnectionProxies = new Map() - public static getUsableCachedConnection(purpose: ConnectionPurpose, streamName: string, host: string) { + public static getUsableCachedConnection(purpose: ConnectionPurpose, streamName: string, vhost: string, host: string) { const map = purpose === "publisher" ? ConnectionPool.publisherConnectionProxies : ConnectionPool.consumerConnectionProxies - const key = ConnectionPool.getCacheKey(streamName, host) + const key = ConnectionPool.getCacheKey(streamName, vhost, host) const proxies = map.get(key) || [] const connection = proxies.at(-1) const refCount = connection?.refCount return refCount !== undefined && refCount < getMaxSharedConnectionInstances() ? connection : undefined } - public static cacheConnection(purpose: ConnectionPurpose, streamName: string, host: string, client: Connection) { + public static cacheConnection( + purpose: ConnectionPurpose, + streamName: string, + vhost: string, + host: string, + client: Connection + ) { const map = purpose === "publisher" ? ConnectionPool.publisherConnectionProxies : ConnectionPool.consumerConnectionProxies - const key = ConnectionPool.getCacheKey(streamName, host) + const key = ConnectionPool.getCacheKey(streamName, vhost, host) const currentlyCached = map.get(key) || [] currentlyCached.push(client) map.set(key, currentlyCached) @@ -36,10 +42,10 @@ export class ConnectionPool { } public static removeCachedConnection(connection: Connection) { - const { leader, streamName, hostname: host } = connection + const { leader, streamName, hostname: host, vhost } = connection if (streamName === undefined) return const m = leader ? ConnectionPool.publisherConnectionProxies : ConnectionPool.consumerConnectionProxies - const k = ConnectionPool.getCacheKey(streamName, host) + const k = ConnectionPool.getCacheKey(streamName, vhost, host) const mappedClientList = m.get(k) if (mappedClientList) { const filtered = mappedClientList.filter((c) => c !== connection) @@ -47,7 +53,7 @@ export class ConnectionPool { } } - private static getCacheKey(streamName: string, host: string) { - return `${streamName}@${host}` + private static getCacheKey(streamName: string, vhost: string, host: string) { + return `${streamName}@${vhost}@${host}` } } diff --git a/src/consumer.ts b/src/consumer.ts index 0197a3f5..8fc85bc9 100644 --- a/src/consumer.ts +++ b/src/consumer.ts @@ -76,8 +76,8 @@ export class StreamConsumer implements Consumer { } public getConnectionInfo(): ConnectionInfo { - const { host, port, id, readable, localPort, ready } = this.connection.getConnectionInfo() - return { host, port, id, readable, localPort, ready } + const { host, port, id, readable, localPort, ready, vhost } = this.connection.getConnectionInfo() + return { host, port, id, readable, localPort, ready, vhost } } public get localOffset() { diff --git a/src/publisher.ts b/src/publisher.ts index 4e8d7825..5bc283d1 100644 --- a/src/publisher.ts +++ b/src/publisher.ts @@ -180,8 +180,8 @@ export class StreamPublisher implements Publisher { } public getConnectionInfo(): ConnectionInfo { - const { host, port, id, writable, localPort, ready } = this.connection.getConnectionInfo() - return { host, port, id, writable, localPort, ready } + const { host, port, id, writable, localPort, ready, vhost } = this.connection.getConnectionInfo() + return { host, port, id, writable, localPort, ready, vhost } } public on(event: "metadata_update", listener: MetadataUpdateListener): void diff --git a/test/e2e/stream_cache.test.ts b/test/e2e/stream_cache.test.ts new file mode 100644 index 00000000..45a18833 --- /dev/null +++ b/test/e2e/stream_cache.test.ts @@ -0,0 +1,78 @@ +import { expect } from "chai" +import got from "got" +import { Client } from "../../src" +import { createClient, createStreamName } from "../support/fake_data" +import { Rabbit, RabbitConnectionResponse } from "../support/rabbit" +import { getTestNodesFromEnv, password, username } from "../support/util" + +async function createVhost(vhost: string): Promise { + const uriVhost = encodeURIComponent(vhost) + const port = process.env.RABBIT_MQ_MANAGEMENT_PORT || 15672 + const firstNode = getTestNodesFromEnv().shift()! + await got.put(`http://${firstNode.host}:${port}/api/vhosts/${uriVhost}`, { + username: username, + password: password, + }) + await got + .put(`http://${firstNode.host}:${port}/api/permissions/${uriVhost}/${username}`, { + json: { + read: ".*", + write: ".*", + configure: ".*", + }, + username: username, + password: password, + }) + .json() +} + +async function deleteVhost(vhost: string): Promise { + const uriVhost = encodeURIComponent(vhost) + const port = process.env.RABBIT_MQ_MANAGEMENT_PORT || 15672 + const firstNode = getTestNodesFromEnv().shift()! + const r = await got.delete(`http://${firstNode.host}:${port}/api/vhosts/${uriVhost}`, { + username: username, + password: password, + }) + + return r.body +} + +describe("cache", () => { + const vhost1 = "vhost1" + let streamName: string + const rabbit = new Rabbit(username, password) + let client: Client + let client2: Client + before(async () => { + await createVhost(vhost1) + }) + beforeEach(async () => { + client = await createClient(username, password) + client2 = await createClient(username, password, undefined, undefined, undefined, undefined, undefined, vhost1) + streamName = createStreamName() + await client.createStream({ stream: streamName }) + await client2.createStream({ stream: streamName }) + }) + afterEach(async () => { + try { + await client.close() + await client2.close() + await deleteVhost(vhost1) + await rabbit.deleteStream(streamName) + await rabbit.closeAllConnections() + await rabbit.deleteAllQueues({ match: /my-stream-/ }) + } catch (_e) {} + }) + + it("should cache using the vhost as well as the stream name", async () => { + const publisher1 = await client.declarePublisher({ + stream: streamName, + }) + expect(publisher1.getConnectionInfo().vhost).eql("/") + const publisher2 = await client2.declarePublisher({ + stream: streamName, + }) + expect(publisher2.getConnectionInfo().vhost).eql(vhost1) + }) +}) diff --git a/test/support/fake_data.ts b/test/support/fake_data.ts index dc34c910..78c328f2 100644 --- a/test/support/fake_data.ts +++ b/test/support/fake_data.ts @@ -65,7 +65,8 @@ export async function createClient( frameMax?: number, bufferSizeSettings?: BufferSizeSettings, port?: number, - connectionName?: string + connectionName?: string, + vhost?: string ): Promise { const [firstNode] = getTestNodesFromEnv() return connect( @@ -74,7 +75,7 @@ export async function createClient( port: port ?? firstNode.port, username, password, - vhost: "/", + vhost: vhost ?? "/", frameMax: frameMax ?? 0, heartbeat: 0, listeners: listeners,