diff --git a/src/server/index.ts b/src/server/index.ts index d845707e..84b3393e 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -14,6 +14,7 @@ export { ExecutionEventQueue } from './events/execution_event_queue.js'; export type { A2ARequestHandler } from './request_handler/a2a_request_handler.js'; export { DefaultRequestHandler } from './request_handler/default_request_handler.js'; +export type { ExtendedAgentCardProvider } from './request_handler/default_request_handler.js'; export { ResultManager } from './result_manager.js'; export type { TaskStore } from './store.js'; export { InMemoryTaskStore } from './store.js'; diff --git a/src/server/request_handler/a2a_request_handler.ts b/src/server/request_handler/a2a_request_handler.ts index c31aefbc..ec4d386a 100644 --- a/src/server/request_handler/a2a_request_handler.ts +++ b/src/server/request_handler/a2a_request_handler.ts @@ -17,7 +17,7 @@ import { ServerCallContext } from '../context.js'; export interface A2ARequestHandler { getAgentCard(): Promise; - getAuthenticatedExtendedAgentCard(): Promise; + getAuthenticatedExtendedAgentCard(context?: ServerCallContext): Promise; sendMessage(params: MessageSendParams, context?: ServerCallContext): Promise; diff --git a/src/server/request_handler/default_request_handler.ts b/src/server/request_handler/default_request_handler.ts index ac8afebb..3b52be37 100644 --- a/src/server/request_handler/default_request_handler.ts +++ b/src/server/request_handler/default_request_handler.ts @@ -39,12 +39,12 @@ const terminalStates: TaskState[] = ['completed', 'failed', 'canceled', 'rejecte export class DefaultRequestHandler implements A2ARequestHandler { private readonly agentCard: AgentCard; - private readonly extendedAgentCard?: AgentCard; private readonly taskStore: TaskStore; private readonly agentExecutor: AgentExecutor; private readonly eventBusManager: ExecutionEventBusManager; private readonly pushNotificationStore?: PushNotificationStore; private readonly pushNotificationSender?: PushNotificationSender; + private readonly extendedAgentCardProvider?: AgentCard | ExtendedAgentCardProvider; constructor( agentCard: AgentCard, @@ -53,13 +53,13 @@ export class DefaultRequestHandler implements A2ARequestHandler { eventBusManager: ExecutionEventBusManager = new DefaultExecutionEventBusManager(), pushNotificationStore?: PushNotificationStore, pushNotificationSender?: PushNotificationSender, - extendedAgentCard?: AgentCard + extendedAgentCardProvider?: AgentCard | ExtendedAgentCardProvider ) { this.agentCard = agentCard; this.taskStore = taskStore; this.agentExecutor = agentExecutor; this.eventBusManager = eventBusManager; - this.extendedAgentCard = extendedAgentCard; + this.extendedAgentCardProvider = extendedAgentCardProvider; // If push notifications are supported, use the provided store and sender. // Otherwise, use the default in-memory store and sender. @@ -74,12 +74,20 @@ export class DefaultRequestHandler implements A2ARequestHandler { return this.agentCard; } - async getAuthenticatedExtendedAgentCard(): Promise { - if (!this.extendedAgentCard) { + async getAuthenticatedExtendedAgentCard(context?: ServerCallContext): Promise { + if (!this.agentCard.supportsAuthenticatedExtendedCard) { + throw A2AError.unsupportedOperation('Agent does not support authenticated extended card.'); + } + if (!this.extendedAgentCardProvider) { throw A2AError.authenticatedExtendedCardNotConfigured(); } - - return this.extendedAgentCard; + if (typeof this.extendedAgentCardProvider === 'function') { + return this.extendedAgentCardProvider(context); + } + if (context?.user?.isAuthenticated()) { + return this.extendedAgentCardProvider; + } + return this.agentCard; } private async _createRequestContext( @@ -681,3 +689,5 @@ export class DefaultRequestHandler implements A2ARequestHandler { } } } + +export type ExtendedAgentCardProvider = (context?: ServerCallContext) => Promise; diff --git a/src/server/transports/jsonrpc_transport_handler.ts b/src/server/transports/jsonrpc_transport_handler.ts index 423ba98e..072e1c55 100644 --- a/src/server/transports/jsonrpc_transport_handler.ts +++ b/src/server/transports/jsonrpc_transport_handler.ts @@ -60,16 +60,10 @@ export class JsonRpcTransportHandler { const { method, id: requestId = null } = rpcRequest; try { - if (method === 'agent/getAuthenticatedExtendedCard') { - const result = await this.requestHandler.getAuthenticatedExtendedAgentCard(); - return { - jsonrpc: '2.0', - id: requestId, - result: result, - } as JSONRPCResponse; - } - - if (!this.paramsAreValid(rpcRequest.params)) { + if ( + method !== 'agent/getAuthenticatedExtendedCard' && + !this.paramsAreValid(rpcRequest.params) + ) { throw A2AError.invalidParams(`Invalid method parameters.`); } @@ -148,6 +142,9 @@ export class JsonRpcTransportHandler { context ); break; + case 'agent/getAuthenticatedExtendedCard': + result = await this.requestHandler.getAuthenticatedExtendedAgentCard(context); + break; default: throw A2AError.methodNotFound(method); } diff --git a/test/server/default_request_handler.spec.ts b/test/server/default_request_handler.spec.ts index bd1f8499..111d9925 100644 --- a/test/server/default_request_handler.spec.ts +++ b/test/server/default_request_handler.spec.ts @@ -12,6 +12,8 @@ import { InMemoryPushNotificationStore, RequestContext, ExecutionEventBus, + ExtendedAgentCardProvider, + User, } from '../../src/server/index.js'; import { AgentCard, @@ -1729,4 +1731,142 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { 'requestedExtensions should contain the expected extension' ); }); + + describe('getAuthenticatedExtendedAgentCard tests', async () => { + class A2AUser implements User { + constructor(private _isAuthenticated: boolean) {} + + isAuthenticated(): boolean { + return this._isAuthenticated; + } + userName(): string { + return 'test-user'; + } + } + + const extendedAgentcardProvider: ExtendedAgentCardProvider = async (context?) => { + if (context?.user?.isAuthenticated()) { + return extendedAgentCard; + } + // Remove the extensions that are not allowed for unauthenticated clients + extendedAgentCard.capabilities.extensions = [{ uri: 'requested-extension-uri' }]; + return extendedAgentCard; + }; + + const agentCardWithExtendedSupport: AgentCard = { + name: 'Test Agent', + description: 'An agent for testing purposes', + url: 'http://localhost:8080', + version: '1.0.0', + protocolVersion: '0.3.0', + capabilities: { + extensions: [{ uri: 'requested-extension-uri' }], + streaming: true, + pushNotifications: true, + }, + defaultInputModes: ['text/plain'], + defaultOutputModes: ['text/plain'], + skills: [ + { + id: 'test-skill', + name: 'Test Skill', + description: 'A skill for testing', + tags: ['test'], + }, + ], + supportsAuthenticatedExtendedCard: true, + }; + + const extendedAgentCard: AgentCard = { + name: 'Test ExtendedAgentCard Agent', + description: 'An agent for testing the extended agent card functionality', + url: 'http://localhost:8080', + version: '1.0.0', + protocolVersion: '0.3.0', + capabilities: { + extensions: [ + { uri: 'requested-extension-uri' }, + { uri: 'extension-uri-for-authenticated-clients' }, + ], + streaming: true, + pushNotifications: true, + }, + defaultInputModes: ['text/plain'], + defaultOutputModes: ['text/plain'], + skills: [ + { + id: 'test-skill', + name: 'Test Skill', + description: 'A skill for testing', + tags: ['test'], + }, + ], + }; + + it('getAuthenticatedExtendedAgentCard should fail if the agent card does not support extended agent card', async () => { + let caughtError; + try { + await handler.getAuthenticatedExtendedAgentCard(); + } catch (error: any) { + caughtError = error; + } finally { + expect(caughtError).to.be.instanceOf(A2AError); + expect(caughtError.code).to.equal(-32004); + expect(caughtError.message).to.contain('Unsupported operation'); + } + }); + + it('getAuthenticatedExtendedAgentCard should fail if ExtendedAgentCardProvider is not provided', async () => { + handler = new DefaultRequestHandler( + agentCardWithExtendedSupport, + mockTaskStore, + mockAgentExecutor, + executionEventBusManager + ); + let caughtError; + try { + await handler.getAuthenticatedExtendedAgentCard(); + } catch (error: any) { + caughtError = error; + } finally { + expect(caughtError).to.be.instanceOf(A2AError); + expect(caughtError.code).to.equal(-32007); + expect(caughtError.message).to.contain('Extended card not configured'); + } + }); + + it('getAuthenticatedExtendedAgentCard should return extended card if user is authenticated with ExtendedAgentCardProvider as AgentCard', async () => { + handler = new DefaultRequestHandler( + agentCardWithExtendedSupport, + mockTaskStore, + mockAgentExecutor, + executionEventBusManager, + undefined, + undefined, + extendedAgentCard + ); + + const context = new ServerCallContext(undefined, new A2AUser(true)); + const agentCard = await handler.getAuthenticatedExtendedAgentCard(context); + assert.deepEqual(agentCard, extendedAgentCard); + }); + + it('getAuthenticatedExtendedAgentCard should return capped extended card if user is not authenticated with ExtendedAgentCardProvider as callback', async () => { + handler = new DefaultRequestHandler( + agentCardWithExtendedSupport, + mockTaskStore, + mockAgentExecutor, + executionEventBusManager, + undefined, + undefined, + extendedAgentcardProvider + ); + + const context = new ServerCallContext(undefined, new A2AUser(false)); + const agentCard = await handler.getAuthenticatedExtendedAgentCard(context); + assert(agentCard.capabilities.extensions.length === 1); + assert.deepEqual(agentCard.capabilities.extensions[0], { uri: 'requested-extension-uri' }); + assert.deepEqual(agentCard.name, extendedAgentCard.name); + }); + }); });