diff --git a/src/server/authentication/user.ts b/src/server/authentication/user.ts index 6d26013e..6bf7bb40 100644 --- a/src/server/authentication/user.ts +++ b/src/server/authentication/user.ts @@ -1,8 +1,21 @@ +/** + * Represents a user accessing A2A server. + */ export interface User { + /** + * Indicates whether the user is authenticated. + */ get isAuthenticated(): boolean; + + /** + * A unique name (identifier) for the user. + */ get userName(): string; } +/** + * An implementation of {@link User} representing an unauthenticated user. + */ export class UnauthenticatedUser implements User { get isAuthenticated(): boolean { return false; diff --git a/src/server/request_handler/default_request_handler.ts b/src/server/request_handler/default_request_handler.ts index 6ba2f7da..360f587b 100644 --- a/src/server/request_handler/default_request_handler.ts +++ b/src/server/request_handler/default_request_handler.ts @@ -99,7 +99,7 @@ export class DefaultRequestHandler implements A2ARequestHandler { // incomingMessage would contain taskId, if a task already exists. if (incomingMessage.taskId) { - task = await this.taskStore.load(incomingMessage.taskId); + task = await this.taskStore.load(incomingMessage.taskId, context); if (!task) { throw A2AError.taskNotFound(incomingMessage.taskId); } @@ -111,7 +111,7 @@ export class DefaultRequestHandler implements A2ARequestHandler { } // Add incomingMessage to history and save the task. task.history = [...(task.history || []), incomingMessage]; - await this.taskStore.save(task); + await this.taskStore.save(task, context); } // Ensure taskId is present const taskId = incomingMessage.taskId || uuidv4(); @@ -119,7 +119,7 @@ export class DefaultRequestHandler implements A2ARequestHandler { if (incomingMessage.referenceTaskIds && incomingMessage.referenceTaskIds.length > 0) { referenceTasks = []; for (const refId of incomingMessage.referenceTaskIds) { - const refTask = await this.taskStore.load(refId); + const refTask = await this.taskStore.load(refId, context); if (refTask) { referenceTasks.push(refTask); } else { @@ -157,6 +157,7 @@ export class DefaultRequestHandler implements A2ARequestHandler { taskId: string, resultManager: ResultManager, eventQueue: ExecutionEventQueue, + context: ServerCallContext | undefined, options?: { firstResultResolver?: (value: Message | Task | PromiseLike) => void; firstResultRejector?: (reason?: unknown) => void; @@ -168,7 +169,7 @@ export class DefaultRequestHandler implements A2ARequestHandler { await resultManager.processEvent(event); try { - await this._sendPushNotificationIfNeeded(event); + await this._sendPushNotificationIfNeeded(event, context); } catch (error) { console.error(`Error sending push notification: ${error}`); } @@ -217,7 +218,7 @@ export class DefaultRequestHandler implements A2ARequestHandler { // Default to blocking behavior if 'blocking' is not explicitly false. const isBlocking = params.configuration?.blocking !== false; // Instantiate ResultManager before creating RequestContext - const resultManager = new ResultManager(this.taskStore); + const resultManager = new ResultManager(this.taskStore, context); resultManager.setContext(incomingMessage); // Set context for ResultManager const requestContext = await this._createRequestContext(incomingMessage, context); @@ -282,7 +283,7 @@ export class DefaultRequestHandler implements A2ARequestHandler { if (isBlocking) { // In blocking mode, wait for the full processing to complete. - await this._processEvents(taskId, resultManager, eventQueue); + await this._processEvents(taskId, resultManager, eventQueue, context); const finalResult = resultManager.getFinalResult(); if (!finalResult) { throw A2AError.internalError( @@ -294,7 +295,7 @@ export class DefaultRequestHandler implements A2ARequestHandler { } else { // In non-blocking mode, return a promise that will be settled by fullProcessing. return new Promise((resolve, reject) => { - this._processEvents(taskId, resultManager, eventQueue, { + this._processEvents(taskId, resultManager, eventQueue, context, { firstResultResolver: resolve, firstResultRejector: reject, }); @@ -318,7 +319,7 @@ export class DefaultRequestHandler implements A2ARequestHandler { } // Instantiate ResultManager before creating RequestContext - const resultManager = new ResultManager(this.taskStore); + const resultManager = new ResultManager(this.taskStore, context); resultManager.setContext(incomingMessage); // Set context for ResultManager const requestContext = await this._createRequestContext(incomingMessage, context); @@ -367,7 +368,7 @@ export class DefaultRequestHandler implements A2ARequestHandler { try { for await (const event of eventQueue.events()) { await resultManager.processEvent(event); // Update store in background - await this._sendPushNotificationIfNeeded(event); + await this._sendPushNotificationIfNeeded(event, context); yield event; // Stream the event to the client } } finally { @@ -376,8 +377,8 @@ export class DefaultRequestHandler implements A2ARequestHandler { } } - async getTask(params: TaskQueryParams, _context?: ServerCallContext): Promise { - const task = await this.taskStore.load(params.id); + async getTask(params: TaskQueryParams, context?: ServerCallContext): Promise { + const task = await this.taskStore.load(params.id, context); if (!task) { throw A2AError.taskNotFound(params.id); } @@ -392,8 +393,8 @@ export class DefaultRequestHandler implements A2ARequestHandler { return task; } - async cancelTask(params: TaskIdParams, _context?: ServerCallContext): Promise { - const task = await this.taskStore.load(params.id); + async cancelTask(params: TaskIdParams, context?: ServerCallContext): Promise { + const task = await this.taskStore.load(params.id, context); if (!task) { throw A2AError.taskNotFound(params.id); } @@ -410,7 +411,12 @@ export class DefaultRequestHandler implements A2ARequestHandler { const eventQueue = new ExecutionEventQueue(eventBus); await this.agentExecutor.cancelTask(params.id, eventBus); // Consume all the events until the task reaches a terminal state. - await this._processEvents(params.id, new ResultManager(this.taskStore), eventQueue); + await this._processEvents( + params.id, + new ResultManager(this.taskStore, context), + eventQueue, + context + ); } else { // Here we are marking task as cancelled. We are not waiting for the executor to actually cancel processing. task.status = { @@ -429,10 +435,10 @@ export class DefaultRequestHandler implements A2ARequestHandler { // Add cancellation message to history task.history = [...(task.history || []), task.status.message]; - await this.taskStore.save(task); + await this.taskStore.save(task, context); } - const latestTask = await this.taskStore.load(params.id); + const latestTask = await this.taskStore.load(params.id, context); if (!latestTask) { throw A2AError.internalError(`Task ${params.id} not found after cancellation.`); } @@ -444,12 +450,12 @@ export class DefaultRequestHandler implements A2ARequestHandler { async setTaskPushNotificationConfig( params: TaskPushNotificationConfig, - _context?: ServerCallContext + context?: ServerCallContext ): Promise { if (!this.agentCard.capabilities.pushNotifications) { throw A2AError.pushNotificationNotSupported(); } - const task = await this.taskStore.load(params.taskId); + const task = await this.taskStore.load(params.taskId, context); if (!task) { throw A2AError.taskNotFound(params.taskId); } @@ -468,12 +474,12 @@ export class DefaultRequestHandler implements A2ARequestHandler { async getTaskPushNotificationConfig( params: TaskIdParams | GetTaskPushNotificationConfigParams, - _context?: ServerCallContext + context?: ServerCallContext ): Promise { if (!this.agentCard.capabilities.pushNotifications) { throw A2AError.pushNotificationNotSupported(); } - const task = await this.taskStore.load(params.id); + const task = await this.taskStore.load(params.id, context); if (!task) { throw A2AError.taskNotFound(params.id); } @@ -503,12 +509,12 @@ export class DefaultRequestHandler implements A2ARequestHandler { async listTaskPushNotificationConfigs( params: ListTaskPushNotificationConfigParams, - _context?: ServerCallContext + context?: ServerCallContext ): Promise { if (!this.agentCard.capabilities.pushNotifications) { throw A2AError.pushNotificationNotSupported(); } - const task = await this.taskStore.load(params.id); + const task = await this.taskStore.load(params.id, context); if (!task) { throw A2AError.taskNotFound(params.id); } @@ -523,12 +529,12 @@ export class DefaultRequestHandler implements A2ARequestHandler { async deleteTaskPushNotificationConfig( params: DeleteTaskPushNotificationConfigParams, - _context?: ServerCallContext + context?: ServerCallContext ): Promise { if (!this.agentCard.capabilities.pushNotifications) { throw A2AError.pushNotificationNotSupported(); } - const task = await this.taskStore.load(params.id); + const task = await this.taskStore.load(params.id, context); if (!task) { throw A2AError.taskNotFound(params.id); } @@ -540,7 +546,7 @@ export class DefaultRequestHandler implements A2ARequestHandler { async *resubscribe( params: TaskIdParams, - _context?: ServerCallContext + context?: ServerCallContext ): AsyncGenerator< | Task // Initial task state | TaskStatusUpdateEvent @@ -552,7 +558,7 @@ export class DefaultRequestHandler implements A2ARequestHandler { throw A2AError.unsupportedOperation('Streaming (and thus resubscription) is not supported.'); } - const task = await this.taskStore.load(params.id); + const task = await this.taskStore.load(params.id, context); if (!task) { throw A2AError.taskNotFound(params.id); } @@ -600,7 +606,10 @@ export class DefaultRequestHandler implements A2ARequestHandler { } } - private async _sendPushNotificationIfNeeded(event: AgentExecutionEvent): Promise { + private async _sendPushNotificationIfNeeded( + event: AgentExecutionEvent, + context: ServerCallContext | undefined + ): Promise { if (!this.agentCard.capabilities.pushNotifications) { return; } @@ -618,7 +627,7 @@ export class DefaultRequestHandler implements A2ARequestHandler { return; } - const task = await this.taskStore.load(taskId); + const task = await this.taskStore.load(taskId, context); if (!task) { console.error(`Task ${taskId} not found.`); return; diff --git a/src/server/result_manager.ts b/src/server/result_manager.ts index adf8838b..9b1991f9 100644 --- a/src/server/result_manager.ts +++ b/src/server/result_manager.ts @@ -1,15 +1,19 @@ import { Message, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent } from '../types.js'; +import { ServerCallContext } from './context.js'; import { AgentExecutionEvent } from './events/execution_event_bus.js'; import { TaskStore } from './store.js'; export class ResultManager { - private taskStore: TaskStore; + private readonly taskStore: TaskStore; + private readonly serverCallContext?: ServerCallContext; + private currentTask?: Task; private latestUserMessage?: Message; // To add to history if a new task is created private finalMessageResult?: Message; // Stores the message if it's the final result - constructor(taskStore: TaskStore) { + constructor(taskStore: TaskStore, serverCallContext?: ServerCallContext) { this.taskStore = taskStore; + this.serverCallContext = serverCallContext; } public setContext(latestUserMessage: Message): void { @@ -62,7 +66,7 @@ export class ResultManager { } else if (!this.currentTask && updateEvent.taskId) { // Potentially an update for a task we haven't seen the 'task' event for yet, // or we are rehydrating. Attempt to load. - const loaded = await this.taskStore.load(updateEvent.taskId); + const loaded = await this.taskStore.load(updateEvent.taskId, this.serverCallContext); if (loaded) { this.currentTask = loaded; this.currentTask.status = updateEvent.status; @@ -119,7 +123,7 @@ export class ResultManager { await this.saveCurrentTask(); } else if (!this.currentTask && artifactEvent.taskId) { // Similar to status update, try to load if task not in memory - const loaded = await this.taskStore.load(artifactEvent.taskId); + const loaded = await this.taskStore.load(artifactEvent.taskId, this.serverCallContext); if (loaded) { this.currentTask = loaded; if (!this.currentTask.artifacts) this.currentTask.artifacts = []; @@ -150,7 +154,7 @@ export class ResultManager { private async saveCurrentTask(): Promise { if (this.currentTask) { - await this.taskStore.save(this.currentTask); + await this.taskStore.save(this.currentTask, this.serverCallContext); } } diff --git a/src/server/store.ts b/src/server/store.ts index 48f52c9e..1510aae7 100644 --- a/src/server/store.ts +++ b/src/server/store.ts @@ -1,4 +1,5 @@ import { Task } from '../types.js'; +import { ServerCallContext } from './context.js'; /** * Simplified interface for task storage providers. @@ -8,17 +9,19 @@ export interface TaskStore { /** * Saves a task. * Overwrites existing data if the task ID exists. - * @param data An object containing the task. + * @param task The task to save. + * @param context The context of the current call. * @returns A promise resolving when the save operation is complete. */ - save(task: Task): Promise; + save(task: Task, context?: ServerCallContext): Promise; /** * Loads a task by task ID. * @param taskId The ID of the task to load. + * @param context The context of the current call. * @returns A promise resolving to an object containing the Task, or undefined if not found. */ - load(taskId: string): Promise; + load(taskId: string, context?: ServerCallContext): Promise; } // ======================== diff --git a/test/server/default_request_handler.spec.ts b/test/server/default_request_handler.spec.ts index 0a138572..dcdb104e 100644 --- a/test/server/default_request_handler.spec.ts +++ b/test/server/default_request_handler.spec.ts @@ -76,9 +76,27 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { ], }; + const serverCallContext = new ServerCallContext(); + // Before each test, reset the components to a clean state beforeEach(() => { - mockTaskStore = new InMemoryTaskStore(); + // Wrap in-memory store into a store which ensures we pass server call context. + // The parameter is optional to avoid breaking changes, however it should be passed. + const inMemoryStore = new InMemoryTaskStore(); + mockTaskStore = { + save: async (task: Task, ctx?: ServerCallContext) => { + if (!ctx) { + throw new Error('Missing server call context'); + } + return inMemoryStore.save(task); + }, + load: async (id: string, ctx?: ServerCallContext) => { + if (!ctx) { + throw new Error('Missing server call context'); + } + return inMemoryStore.load(id); + }, + }; // Default mock for most tests mockAgentExecutor = new MockAgentExecutor(); executionEventBusManager = new DefaultExecutionEventBusManager(); @@ -123,7 +141,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { bus.finished(); }); - const result = await handler.sendMessage(params); + const result = await handler.sendMessage(params, serverCallContext); assert.deepEqual(result, agentResponse, "The result should be the agent's message"); assert.isTrue( @@ -184,7 +202,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { bus.finished(); }); - const result = await handler.sendMessage(params); + const result = await handler.sendMessage(params, serverCallContext); const taskResult = result as Task; assert.equal(taskResult.kind, 'task'); @@ -205,7 +223,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { message: createTestMessage('msg-fail-block', 'Test failure blocking'), }; - const blockingResult = await handler.sendMessage(blockingParams); + const blockingResult = await handler.sendMessage(blockingParams, serverCallContext); const blockingTask = blockingResult as Task; assert.equal(blockingTask.kind, 'task', 'Result should be a task'); assert.equal(blockingTask.status.state, 'failed', 'Task status should be failed'); @@ -251,7 +269,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { }); // This call should return as soon as the first 'task' event is published - const immediateResult = await handler.sendMessage(params); + const immediateResult = await handler.sendMessage(params, serverCallContext); // Assert that we got the initial task object back right away const taskResult = immediateResult as Task; @@ -271,7 +289,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { await clock.runAllAsync(); // Now, check the final state in the store to ensure background processing finished - const finalTask = await mockTaskStore.load(taskId); + const finalTask = await mockTaskStore.load(taskId, serverCallContext); assert.isDefined(finalTask); assert.equal( finalTask!.status.state, @@ -338,7 +356,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { }); // This call should return as soon as the first 'task' event is published - const immediateResult = await handler.sendMessage(params); + const immediateResult = await handler.sendMessage(params, serverCallContext); // Assert that we got the initial task object back right away const taskResult = immediateResult as Task; @@ -373,7 +391,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { configuration: { blocking: false, acceptedOutputModes: [] }, }; - const nonBlockingResult = await handler.sendMessage(nonBlockingParams); + const nonBlockingResult = await handler.sendMessage(nonBlockingParams, serverCallContext); const nonBlockingTask = nonBlockingResult as Task; assert.equal(nonBlockingTask.kind, 'task', 'Result should be a task'); assert.equal(nonBlockingTask.status.state, 'failed', 'Task status should be failed'); @@ -437,7 +455,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { bus.finished(); }); - const firstResult = await handler.sendMessage(firstParams); + const firstResult = await handler.sendMessage(firstParams, serverCallContext); const firstTask = firstResult as Task; // Check the first result is a task with `input-required` status @@ -527,7 +545,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { bus.finished(); }); - const secondResult = await handler.sendMessage(secondParams); + const secondResult = await handler.sendMessage(secondParams, serverCallContext); const secondTask = secondResult as Task; // Check the second result is a task with `completed` status @@ -638,7 +656,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { bus.finished(); }); - const firstResult = await handler.sendMessage(firstParams); + const firstResult = await handler.sendMessage(firstParams, serverCallContext); const firstTask = firstResult as Task; // Check the first result is a task with `input-required` status @@ -731,7 +749,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { bus.finished(); }); - const secondResult = await handler.sendMessage(secondParams); + const secondResult = await handler.sendMessage(secondParams, serverCallContext); // Check the second result is a task with `completed` status const secondTask = secondResult as Task; @@ -741,7 +759,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { await clock.runAllAsync(); // give time to the second task to publish all the updates - const finalTask = await mockTaskStore.load(taskId); + const finalTask = await mockTaskStore.load(taskId, serverCallContext); // Check the history assert.equal(finalTask.status.state, 'completed'); @@ -822,7 +840,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { bus.finished(); }); - const eventGenerator = handler.sendMessageStream(params); + const eventGenerator = handler.sendMessageStream(params, serverCallContext); const events = []; for await (const event of eventGenerator) { events.push(event); @@ -846,14 +864,14 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { status: { state: state as TaskState }, kind: 'task', }; - await mockTaskStore.save(fakeTask); + await mockTaskStore.save(fakeTask, serverCallContext); const params: MessageSendParams = { message: { ...createTestMessage('msg-1', 'test'), taskId: taskId }, }; try { - await handler.sendMessage(params); + await handler.sendMessage(params, serverCallContext); assert.fail(`Should have thrown for state: ${state}`); } catch (error: any) { expect(error.code).to.equal(-32600); // Invalid Request @@ -872,13 +890,13 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { status: { state: 'completed' }, kind: 'task', }; - await mockTaskStore.save(fakeTask); + await mockTaskStore.save(fakeTask, serverCallContext); const params: MessageSendParams = { message: { ...createTestMessage('msg-1', 'test'), taskId: taskId }, }; - const generator = handler.sendMessageStream(params); + const generator = handler.sendMessageStream(params, serverCallContext); try { await generator.next(); @@ -915,7 +933,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { bus.finished(); }); - const eventGenerator = handler.sendMessageStream(params); + const eventGenerator = handler.sendMessageStream(params, serverCallContext); const events = []; for await (const event of eventGenerator) { events.push(event); @@ -965,7 +983,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { bus.finished(); }); - const stream1_generator = handler.sendMessageStream(params); + const stream1_generator = handler.sendMessageStream(params, serverCallContext); const stream1_iterator = stream1_generator[Symbol.asyncIterator](); const firstEventResult = await stream1_iterator.next(); @@ -976,7 +994,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { const secondEvent = secondEventResult.value as TaskStatusUpdateEvent; assert.equal(secondEvent.taskId, taskId, 'Should get the task status update event second'); - const stream2_generator = handler.resubscribe({ id: taskId }); + const stream2_generator = handler.resubscribe({ id: taskId }, serverCallContext); const results1: any[] = [firstEvent, secondEvent]; const results2: any[] = []; @@ -1015,9 +1033,9 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { kind: 'task', history: [], }; - await mockTaskStore.save(fakeTask); + await mockTaskStore.save(fakeTask, serverCallContext); - const result = await handler.getTask({ id: 'task-exist' }); + const result = await handler.getTask({ id: 'task-exist' }, serverCallContext); assert.deepEqual(result, fakeTask); }); @@ -1029,7 +1047,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { status: { state: 'working' }, kind: 'task', }; - await mockTaskStore.save(fakeTask); + await mockTaskStore.save(fakeTask, serverCallContext); const pushConfig: PushNotificationConfig = { id: 'config-1', @@ -1041,7 +1059,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { taskId, pushNotificationConfig: pushConfig, }; - const setResponse = await handler.setTaskPushNotificationConfig(setParams); + const setResponse = await handler.setTaskPushNotificationConfig(setParams, serverCallContext); assert.deepEqual( setResponse.pushNotificationConfig, pushConfig, @@ -1052,7 +1070,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { id: taskId, pushNotificationConfigId: 'config-1', }; - const getResponse = await handler.getTaskPushNotificationConfig(getParams); + const getResponse = await handler.getTaskPushNotificationConfig(getParams, serverCallContext); assert.deepEqual( getResponse.pushNotificationConfig, pushConfig, @@ -1062,70 +1080,94 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { it('set/getTaskPushNotificationConfig: should save and retrieve config by task ID for backward compatibility', async () => { const taskId = 'task-push-compat'; - await mockTaskStore.save({ - id: taskId, - contextId: 'ctx-compat', - status: { state: 'working' }, - kind: 'task', - }); + await mockTaskStore.save( + { + id: taskId, + contextId: 'ctx-compat', + status: { state: 'working' }, + kind: 'task', + }, + serverCallContext + ); // Config ID defaults to task ID const pushConfig: PushNotificationConfig = { url: 'https://example.com/notify-compat', }; - await handler.setTaskPushNotificationConfig({ - taskId, - pushNotificationConfig: pushConfig, - }); + await handler.setTaskPushNotificationConfig( + { + taskId, + pushNotificationConfig: pushConfig, + }, + serverCallContext + ); - const getResponse = await handler.getTaskPushNotificationConfig({ - id: taskId, - }); + const getResponse = await handler.getTaskPushNotificationConfig( + { + id: taskId, + }, + serverCallContext + ); expect(getResponse.pushNotificationConfig.id).to.equal(taskId); expect(getResponse.pushNotificationConfig.url).to.equal(pushConfig.url); }); it('setTaskPushNotificationConfig: should overwrite an existing config with the same ID', async () => { const taskId = 'task-overwrite'; - await mockTaskStore.save({ - id: taskId, - contextId: 'ctx-overwrite', - status: { state: 'working' }, - kind: 'task', - }); + await mockTaskStore.save( + { + id: taskId, + contextId: 'ctx-overwrite', + status: { state: 'working' }, + kind: 'task', + }, + serverCallContext + ); const initialConfig: PushNotificationConfig = { id: 'config-same', url: 'https://initial.url', }; - await handler.setTaskPushNotificationConfig({ - taskId, - pushNotificationConfig: initialConfig, - }); + await handler.setTaskPushNotificationConfig( + { + taskId, + pushNotificationConfig: initialConfig, + }, + serverCallContext + ); const newConfig: PushNotificationConfig = { id: 'config-same', url: 'https://new.url', }; - await handler.setTaskPushNotificationConfig({ - taskId, - pushNotificationConfig: newConfig, - }); + await handler.setTaskPushNotificationConfig( + { + taskId, + pushNotificationConfig: newConfig, + }, + serverCallContext + ); - const configs = await handler.listTaskPushNotificationConfigs({ - id: taskId, - }); + const configs = await handler.listTaskPushNotificationConfigs( + { + id: taskId, + }, + serverCallContext + ); expect(configs).to.have.lengthOf(1); expect(configs[0].pushNotificationConfig.url).to.equal('https://new.url'); }); it('listTaskPushNotificationConfigs: should return all configs for a task', async () => { const taskId = 'task-list-configs'; - await mockTaskStore.save({ - id: taskId, - contextId: 'ctx-list', - status: { state: 'working' }, - kind: 'task', - }); + await mockTaskStore.save( + { + id: taskId, + contextId: 'ctx-list', + status: { state: 'working' }, + kind: 'task', + }, + serverCallContext + ); const config1: PushNotificationConfig = { id: 'cfg1', url: 'https://url1.com', @@ -1134,17 +1176,26 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { id: 'cfg2', url: 'https://url2.com', }; - await handler.setTaskPushNotificationConfig({ - taskId, - pushNotificationConfig: config1, - }); - await handler.setTaskPushNotificationConfig({ - taskId, - pushNotificationConfig: config2, - }); + await handler.setTaskPushNotificationConfig( + { + taskId, + pushNotificationConfig: config1, + }, + serverCallContext + ); + await handler.setTaskPushNotificationConfig( + { + taskId, + pushNotificationConfig: config2, + }, + serverCallContext + ); const listParams: ListTaskPushNotificationConfigParams = { id: taskId }; - const listResponse = await handler.listTaskPushNotificationConfigs(listParams); + const listResponse = await handler.listTaskPushNotificationConfigs( + listParams, + serverCallContext + ); expect(listResponse).to.be.an('array').with.lengthOf(2); assert.deepInclude(listResponse, { @@ -1159,12 +1210,15 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { it('deleteTaskPushNotificationConfig: should remove a specific config', async () => { const taskId = 'task-delete-config'; - await mockTaskStore.save({ - id: taskId, - contextId: 'ctx-delete', - status: { state: 'working' }, - kind: 'task', - }); + await mockTaskStore.save( + { + id: taskId, + contextId: 'ctx-delete', + status: { state: 'working' }, + kind: 'task', + }, + serverCallContext + ); const config1: PushNotificationConfig = { id: 'cfg-del-1', url: 'https://url1.com', @@ -1173,53 +1227,74 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { id: 'cfg-del-2', url: 'https://url2.com', }; - await handler.setTaskPushNotificationConfig({ - taskId, - pushNotificationConfig: config1, - }); - await handler.setTaskPushNotificationConfig({ - taskId, - pushNotificationConfig: config2, - }); + await handler.setTaskPushNotificationConfig( + { + taskId, + pushNotificationConfig: config1, + }, + serverCallContext + ); + await handler.setTaskPushNotificationConfig( + { + taskId, + pushNotificationConfig: config2, + }, + serverCallContext + ); const deleteParams: DeleteTaskPushNotificationConfigParams = { id: taskId, pushNotificationConfigId: 'cfg-del-1', }; - await handler.deleteTaskPushNotificationConfig(deleteParams); + await handler.deleteTaskPushNotificationConfig(deleteParams, serverCallContext); - const remainingConfigs = await handler.listTaskPushNotificationConfigs({ - id: taskId, - }); + const remainingConfigs = await handler.listTaskPushNotificationConfigs( + { + id: taskId, + }, + serverCallContext + ); expect(remainingConfigs).to.have.lengthOf(1); expect(remainingConfigs[0].pushNotificationConfig.id).to.equal('cfg-del-2'); }); it('deleteTaskPushNotificationConfig: should remove the whole entry if last config is deleted', async () => { const taskId = 'task-delete-last-config'; - await mockTaskStore.save({ - id: taskId, - contextId: 'ctx-delete-last', - status: { state: 'working' }, - kind: 'task', - }); + await mockTaskStore.save( + { + id: taskId, + contextId: 'ctx-delete-last', + status: { state: 'working' }, + kind: 'task', + }, + serverCallContext + ); const config: PushNotificationConfig = { id: 'cfg-last', url: 'https://last.com', }; - await handler.setTaskPushNotificationConfig({ - taskId, - pushNotificationConfig: config, - }); + await handler.setTaskPushNotificationConfig( + { + taskId, + pushNotificationConfig: config, + }, + serverCallContext + ); - await handler.deleteTaskPushNotificationConfig({ - id: taskId, - pushNotificationConfigId: 'cfg-last', - }); + await handler.deleteTaskPushNotificationConfig( + { + id: taskId, + pushNotificationConfigId: 'cfg-last', + }, + serverCallContext + ); - const configs = await handler.listTaskPushNotificationConfigs({ - id: taskId, - }); + const configs = await handler.listTaskPushNotificationConfigs( + { + id: taskId, + }, + serverCallContext + ); expect(configs).to.be.an('array').with.lengthOf(0); }); @@ -1256,7 +1331,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { fakeTaskExecute(ctx, bus); }); - await handler.sendMessage(params); + await handler.sendMessage(params, serverCallContext); const expectedTask: Task = { id: taskId, @@ -1331,7 +1406,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { fakeTaskExecute(ctx, bus); }); - const eventGenerator = handler.sendMessageStream(params); + const eventGenerator = handler.sendMessageStream(params, serverCallContext); const events = []; for await (const event of eventGenerator) { events.push(event); @@ -1410,7 +1485,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { for (const method of methodsToTest) { try { - await (handler as any)[method.name](method.params); + await (handler as any)[method.name](method.params, serverCallContext); assert.fail(`Method ${method.name} should have thrown for non-existent task.`); } catch (error: any) { expect(error).to.be.instanceOf(A2AError); @@ -1432,12 +1507,15 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { ); const taskId = 'task-unsupported'; - await mockTaskStore.save({ - id: taskId, - contextId: 'ctx-unsupported', - status: { state: 'working' }, - kind: 'task', - }); + await mockTaskStore.save( + { + id: taskId, + contextId: 'ctx-unsupported', + status: { state: 'working' }, + kind: 'task', + }, + serverCallContext + ); const config: PushNotificationConfig = { id: 'cfg-u', url: 'https://u.com', @@ -1484,7 +1562,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { const streamParams: MessageSendParams = { message: createTestMessage('msg-9', 'Start and cancel'), }; - const streamGenerator = handler.sendMessageStream(streamParams); + const streamGenerator = handler.sendMessageStream(streamParams, serverCallContext); const streamEvents: any[] = []; (async () => { @@ -1501,14 +1579,14 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { const taskId = createdTask.id; // Now, issue the cancel request - const cancelResponse = await handler.cancelTask({ id: taskId }); + const cancelResponse = await handler.cancelTask({ id: taskId }, serverCallContext); // Let the executor's loop run to completion to detect the cancellation await clock.runAllAsync(); assert.isTrue(cancellableExecutor.cancelTaskSpy.calledOnceWith(taskId, sinon.match.any)); - const finalTask = await handler.getTask({ id: taskId }); + const finalTask = await handler.getTask({ id: taskId }, serverCallContext); assert.equal(finalTask.status.state, 'canceled'); assert.equal(cancelResponse.status.state, 'canceled'); @@ -1529,7 +1607,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { const streamParams: MessageSendParams = { message: createTestMessage('msg-9', 'Start and cancel'), }; - const streamGenerator = handler.sendMessageStream(streamParams); + const streamGenerator = handler.sendMessageStream(streamParams, serverCallContext); const streamEvents: any[] = []; (async () => { @@ -1548,7 +1626,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { let cancelResponse: Task; let thrownError: any; try { - cancelResponse = await handler.cancelTask({ id: taskId }); + cancelResponse = await handler.cancelTask({ id: taskId }, serverCallContext); } catch (error: any) { thrownError = error; } finally { @@ -1570,10 +1648,10 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { status: { state: 'completed' }, kind: 'task', }; - await mockTaskStore.save(fakeTask); + await mockTaskStore.save(fakeTask, serverCallContext); try { - await handler.cancelTask({ id: taskId }); + await handler.cancelTask({ id: taskId }, serverCallContext); assert.fail('Should have thrown a TaskNotCancelableError'); } catch (error: any) { assert.equal(error.code, -32002); @@ -1603,19 +1681,22 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { }); bus.finished(); }); - await handler.sendMessage(params); + await handler.sendMessage(params, serverCallContext); expect(capturedContextId).to.equal('incoming-ctx-id'); }); it('should use contextId from task if not present in incomingMessage (contextId assignment logic)', async () => { const taskId = 'task-ctx-id'; const taskContextId = 'task-context-id'; - await mockTaskStore.save({ - id: taskId, - contextId: taskContextId, - status: { state: 'working' }, - kind: 'task', - }); + await mockTaskStore.save( + { + id: taskId, + contextId: taskContextId, + status: { state: 'working' }, + kind: 'task', + }, + serverCallContext + ); const params: MessageSendParams = { message: { messageId: 'msg-ctx2', @@ -1636,7 +1717,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { }); bus.finished(); }); - await handler.sendMessage(params); + await handler.sendMessage(params, serverCallContext); expect(capturedContextId).to.equal(taskContextId); }); @@ -1660,7 +1741,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { }); bus.finished(); }); - await handler.sendMessage(params); + await handler.sendMessage(params, serverCallContext); expect(capturedContextId).to.be.a('string').and.not.empty; }); @@ -1711,7 +1792,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { status: { state: 'submitted' as TaskState }, kind: 'task', }; - await mockTaskStore.save(fakeTask); + await mockTaskStore.save(fakeTask, serverCallContext); await handler.sendMessage( params, new ServerCallContext(