diff --git a/src/client/factory.ts b/src/client/factory.ts index 70943aec..61b5f736 100644 --- a/src/client/factory.ts +++ b/src/client/factory.ts @@ -75,7 +75,7 @@ export const ClientFactoryOptions = { }; export class ClientFactory { - private readonly transportsByName: Map; + private readonly transportsByName: CaseInsensitiveMap; private readonly agentCardResolver: AgentCardResolver; constructor(public readonly options: ClientFactoryOptions = ClientFactoryOptions.default) { @@ -84,8 +84,7 @@ export class ClientFactory { } this.transportsByName = transportsByName(options.transports); for (const transport of options.preferredTransports ?? []) { - const factory = this.options.transports.find((t) => t.protocolName === transport); - if (!factory) { + if (!this.transportsByName.has(transport)) { throw new Error( `Unknown preferred transport: ${transport}, available transports: ${[...this.transportsByName.keys()].join()}` ); @@ -100,7 +99,7 @@ export class ClientFactory { async createFromAgentCard(agentCard: AgentCard): Promise { const agentCardPreferred = agentCard.preferredTransport ?? JsonRpcTransportFactory.name; const additionalInterfaces = agentCard.additionalInterfaces ?? []; - const urlsPerAgentTransports = new Map([ + const urlsPerAgentTransports = new CaseInsensitiveMap([ [agentCardPreferred, agentCard.url], ...additionalInterfaces.map<[string, string]>((i) => [i.transport, i.url]), ]); @@ -165,8 +164,8 @@ function mergeTransports( function transportsByName( transports: ReadonlyArray | undefined -): Map { - const result = new Map(); +): CaseInsensitiveMap { + const result = new CaseInsensitiveMap(); if (!transports) { return result; } @@ -189,3 +188,29 @@ function mergeArrays( return [...(a1 ?? []), ...(a2 ?? [])]; } + +/** + * A Map that normalizes string keys to uppercase for case-insensitive lookups. + * This prevents errors from inconsistent casing in protocol names. + */ +class CaseInsensitiveMap extends Map { + private normalizeKey(key: string): string { + return key.toUpperCase(); + } + + override set(key: string, value: T): this { + return super.set(this.normalizeKey(key), value); + } + + override get(key: string): T | undefined { + return super.get(this.normalizeKey(key)); + } + + override has(key: string): boolean { + return super.has(this.normalizeKey(key)); + } + + override delete(key: string): boolean { + return super.delete(this.normalizeKey(key)); + } +} diff --git a/test/client/factory.spec.ts b/test/client/factory.spec.ts index b4fa52b1..96f76872 100644 --- a/test/client/factory.spec.ts +++ b/test/client/factory.spec.ts @@ -50,7 +50,7 @@ describe('ClientFactory', () => { preferredTransports: ['UnknownTransport'], }; expect(() => new ClientFactory(options)).to.throw( - 'Unknown preferred transport: UnknownTransport, available transports: Transport1' + 'Unknown preferred transport: UnknownTransport, available transports: TRANSPORT1' ); }); @@ -71,6 +71,30 @@ describe('ClientFactory', () => { expect(factory.options).to.equal(options); }); + + it('should accept preferred transport with different case', () => { + const options: ClientFactoryOptions = { + transports: [mockTransportFactory1], + preferredTransports: ['transport1'], // lowercase, but Transport1 is registered + }; + + // Should not throw + const factory = new ClientFactory(options); + + expect(factory.options).to.equal(options); + }); + + it('should detect duplicate transports with different case as duplicates', () => { + const transport1Lower = { + protocolName: 'transport1', // lowercase + create: vi.fn(), + }; + const options: ClientFactoryOptions = { + transports: [mockTransportFactory1, transport1Lower], // Transport1 and transport1 + }; + + expect(() => new ClientFactory(options)).to.throw('Duplicate protocol name: transport1'); + }); }); describe('createClient', () => { @@ -163,6 +187,46 @@ describe('ClientFactory', () => { expect(client.config).to.equal(clientConfig); }); + it('should match transport with case-insensitive protocol name', async () => { + // Transport factory uses "Transport1" but agent card uses "transport1" (lowercase) + agentCard.preferredTransport = 'transport1'; + const factory = new ClientFactory({ transports: [mockTransportFactory1] }); + + const client = await factory.createFromAgentCard(agentCard); + + expect(client).to.be.instanceOf(Client); + expect(mockTransportFactory1.create).toHaveBeenCalledExactlyOnceWith( + 'http://transport1.com', + agentCard + ); + }); + + it('should match HTTP+JSON transport regardless of case', async () => { + const httpJsonFactory = { + protocolName: 'HTTP+JSON', + create: vi.fn().mockResolvedValue(mockTransport), + }; + agentCard.preferredTransport = 'http+json'; // lowercase + const factory = new ClientFactory({ transports: [httpJsonFactory] }); + + await factory.createFromAgentCard(agentCard); + + expect(httpJsonFactory.create).toHaveBeenCalledTimes(1); + }); + + it('should match JSONRPC transport regardless of case', async () => { + const jsonRpcFactory = { + protocolName: 'JSONRPC', + create: vi.fn().mockResolvedValue(mockTransport), + }; + agentCard.preferredTransport = 'JsonRpc'; // mixed case + const factory = new ClientFactory({ transports: [jsonRpcFactory] }); + + await factory.createFromAgentCard(agentCard); + + expect(jsonRpcFactory.create).toHaveBeenCalledTimes(1); + }); + it('should use card resolver with default path', async () => { const cardResolver = { resolve: vi.fn().mockResolvedValue(agentCard),