Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/server/authentication/user.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
/**
* Represents a user accessing A2A server.
*/
export interface User {
/**
* Indicates whether the user is authenticated.
*/
get isAuthenticated(): boolean;

/**
* A unique name (identifier) for the user.
*/
get userName(): string;
}

/**
* An implementation of {@link User} representing an unauthenticated user.
*/
export class UnauthenticatedUser implements User {
get isAuthenticated(): boolean {
return false;
Expand Down
65 changes: 37 additions & 28 deletions src/server/request_handler/default_request_handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ export class DefaultRequestHandler implements A2ARequestHandler {

// incomingMessage would contain taskId, if a task already exists.
if (incomingMessage.taskId) {
task = await this.taskStore.load(incomingMessage.taskId);
task = await this.taskStore.load(incomingMessage.taskId, context);
if (!task) {
throw A2AError.taskNotFound(incomingMessage.taskId);
}
Expand All @@ -111,15 +111,15 @@ export class DefaultRequestHandler implements A2ARequestHandler {
}
// Add incomingMessage to history and save the task.
task.history = [...(task.history || []), incomingMessage];
await this.taskStore.save(task);
await this.taskStore.save(task, context);
}
// Ensure taskId is present
const taskId = incomingMessage.taskId || uuidv4();

if (incomingMessage.referenceTaskIds && incomingMessage.referenceTaskIds.length > 0) {
referenceTasks = [];
for (const refId of incomingMessage.referenceTaskIds) {
const refTask = await this.taskStore.load(refId);
const refTask = await this.taskStore.load(refId, context);
if (refTask) {
referenceTasks.push(refTask);
} else {
Expand Down Expand Up @@ -157,6 +157,7 @@ export class DefaultRequestHandler implements A2ARequestHandler {
taskId: string,
resultManager: ResultManager,
eventQueue: ExecutionEventQueue,
context: ServerCallContext | undefined,
options?: {
firstResultResolver?: (value: Message | Task | PromiseLike<Message | Task>) => void;
firstResultRejector?: (reason?: unknown) => void;
Expand All @@ -168,7 +169,7 @@ export class DefaultRequestHandler implements A2ARequestHandler {
await resultManager.processEvent(event);

try {
await this._sendPushNotificationIfNeeded(event);
await this._sendPushNotificationIfNeeded(event, context);
} catch (error) {
console.error(`Error sending push notification: ${error}`);
}
Expand Down Expand Up @@ -217,7 +218,7 @@ export class DefaultRequestHandler implements A2ARequestHandler {
// Default to blocking behavior if 'blocking' is not explicitly false.
const isBlocking = params.configuration?.blocking !== false;
// Instantiate ResultManager before creating RequestContext
const resultManager = new ResultManager(this.taskStore);
const resultManager = new ResultManager(this.taskStore, context);
resultManager.setContext(incomingMessage); // Set context for ResultManager

const requestContext = await this._createRequestContext(incomingMessage, context);
Expand Down Expand Up @@ -282,7 +283,7 @@ export class DefaultRequestHandler implements A2ARequestHandler {

if (isBlocking) {
// In blocking mode, wait for the full processing to complete.
await this._processEvents(taskId, resultManager, eventQueue);
await this._processEvents(taskId, resultManager, eventQueue, context);
const finalResult = resultManager.getFinalResult();
if (!finalResult) {
throw A2AError.internalError(
Expand All @@ -294,7 +295,7 @@ export class DefaultRequestHandler implements A2ARequestHandler {
} else {
// In non-blocking mode, return a promise that will be settled by fullProcessing.
return new Promise<Message | Task>((resolve, reject) => {
this._processEvents(taskId, resultManager, eventQueue, {
this._processEvents(taskId, resultManager, eventQueue, context, {
firstResultResolver: resolve,
firstResultRejector: reject,
});
Expand All @@ -318,7 +319,7 @@ export class DefaultRequestHandler implements A2ARequestHandler {
}

// Instantiate ResultManager before creating RequestContext
const resultManager = new ResultManager(this.taskStore);
const resultManager = new ResultManager(this.taskStore, context);
resultManager.setContext(incomingMessage); // Set context for ResultManager

const requestContext = await this._createRequestContext(incomingMessage, context);
Expand Down Expand Up @@ -367,7 +368,7 @@ export class DefaultRequestHandler implements A2ARequestHandler {
try {
for await (const event of eventQueue.events()) {
await resultManager.processEvent(event); // Update store in background
await this._sendPushNotificationIfNeeded(event);
await this._sendPushNotificationIfNeeded(event, context);
yield event; // Stream the event to the client
}
} finally {
Expand All @@ -376,8 +377,8 @@ export class DefaultRequestHandler implements A2ARequestHandler {
}
}

async getTask(params: TaskQueryParams, _context?: ServerCallContext): Promise<Task> {
const task = await this.taskStore.load(params.id);
async getTask(params: TaskQueryParams, context?: ServerCallContext): Promise<Task> {
const task = await this.taskStore.load(params.id, context);
if (!task) {
throw A2AError.taskNotFound(params.id);
}
Expand All @@ -392,8 +393,8 @@ export class DefaultRequestHandler implements A2ARequestHandler {
return task;
}

async cancelTask(params: TaskIdParams, _context?: ServerCallContext): Promise<Task> {
const task = await this.taskStore.load(params.id);
async cancelTask(params: TaskIdParams, context?: ServerCallContext): Promise<Task> {
const task = await this.taskStore.load(params.id, context);
if (!task) {
throw A2AError.taskNotFound(params.id);
}
Expand All @@ -410,7 +411,12 @@ export class DefaultRequestHandler implements A2ARequestHandler {
const eventQueue = new ExecutionEventQueue(eventBus);
await this.agentExecutor.cancelTask(params.id, eventBus);
// Consume all the events until the task reaches a terminal state.
await this._processEvents(params.id, new ResultManager(this.taskStore), eventQueue);
await this._processEvents(
params.id,
new ResultManager(this.taskStore, context),
eventQueue,
context
);
} else {
// Here we are marking task as cancelled. We are not waiting for the executor to actually cancel processing.
task.status = {
Expand All @@ -429,10 +435,10 @@ export class DefaultRequestHandler implements A2ARequestHandler {
// Add cancellation message to history
task.history = [...(task.history || []), task.status.message];

await this.taskStore.save(task);
await this.taskStore.save(task, context);
}

const latestTask = await this.taskStore.load(params.id);
const latestTask = await this.taskStore.load(params.id, context);
if (!latestTask) {
throw A2AError.internalError(`Task ${params.id} not found after cancellation.`);
}
Expand All @@ -444,12 +450,12 @@ export class DefaultRequestHandler implements A2ARequestHandler {

async setTaskPushNotificationConfig(
params: TaskPushNotificationConfig,
_context?: ServerCallContext
context?: ServerCallContext
): Promise<TaskPushNotificationConfig> {
if (!this.agentCard.capabilities.pushNotifications) {
throw A2AError.pushNotificationNotSupported();
}
const task = await this.taskStore.load(params.taskId);
const task = await this.taskStore.load(params.taskId, context);
if (!task) {
throw A2AError.taskNotFound(params.taskId);
}
Expand All @@ -468,12 +474,12 @@ export class DefaultRequestHandler implements A2ARequestHandler {

async getTaskPushNotificationConfig(
params: TaskIdParams | GetTaskPushNotificationConfigParams,
_context?: ServerCallContext
context?: ServerCallContext
): Promise<TaskPushNotificationConfig> {
if (!this.agentCard.capabilities.pushNotifications) {
throw A2AError.pushNotificationNotSupported();
}
const task = await this.taskStore.load(params.id);
const task = await this.taskStore.load(params.id, context);
if (!task) {
throw A2AError.taskNotFound(params.id);
}
Expand Down Expand Up @@ -503,12 +509,12 @@ export class DefaultRequestHandler implements A2ARequestHandler {

async listTaskPushNotificationConfigs(
params: ListTaskPushNotificationConfigParams,
_context?: ServerCallContext
context?: ServerCallContext
): Promise<TaskPushNotificationConfig[]> {
if (!this.agentCard.capabilities.pushNotifications) {
throw A2AError.pushNotificationNotSupported();
}
const task = await this.taskStore.load(params.id);
const task = await this.taskStore.load(params.id, context);
if (!task) {
throw A2AError.taskNotFound(params.id);
}
Expand All @@ -523,12 +529,12 @@ export class DefaultRequestHandler implements A2ARequestHandler {

async deleteTaskPushNotificationConfig(
params: DeleteTaskPushNotificationConfigParams,
_context?: ServerCallContext
context?: ServerCallContext
): Promise<void> {
if (!this.agentCard.capabilities.pushNotifications) {
throw A2AError.pushNotificationNotSupported();
}
const task = await this.taskStore.load(params.id);
const task = await this.taskStore.load(params.id, context);
if (!task) {
throw A2AError.taskNotFound(params.id);
}
Expand All @@ -540,7 +546,7 @@ export class DefaultRequestHandler implements A2ARequestHandler {

async *resubscribe(
params: TaskIdParams,
_context?: ServerCallContext
context?: ServerCallContext
): AsyncGenerator<
| Task // Initial task state
| TaskStatusUpdateEvent
Expand All @@ -552,7 +558,7 @@ export class DefaultRequestHandler implements A2ARequestHandler {
throw A2AError.unsupportedOperation('Streaming (and thus resubscription) is not supported.');
}

const task = await this.taskStore.load(params.id);
const task = await this.taskStore.load(params.id, context);
if (!task) {
throw A2AError.taskNotFound(params.id);
}
Expand Down Expand Up @@ -600,7 +606,10 @@ export class DefaultRequestHandler implements A2ARequestHandler {
}
}

private async _sendPushNotificationIfNeeded(event: AgentExecutionEvent): Promise<void> {
private async _sendPushNotificationIfNeeded(
event: AgentExecutionEvent,
context: ServerCallContext | undefined
): Promise<void> {
if (!this.agentCard.capabilities.pushNotifications) {
return;
}
Expand All @@ -618,7 +627,7 @@ export class DefaultRequestHandler implements A2ARequestHandler {
return;
}

const task = await this.taskStore.load(taskId);
const task = await this.taskStore.load(taskId, context);
if (!task) {
console.error(`Task ${taskId} not found.`);
return;
Expand Down
14 changes: 9 additions & 5 deletions src/server/result_manager.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import { Message, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent } from '../types.js';
import { ServerCallContext } from './context.js';
import { AgentExecutionEvent } from './events/execution_event_bus.js';
import { TaskStore } from './store.js';

export class ResultManager {
private taskStore: TaskStore;
private readonly taskStore: TaskStore;
private readonly serverCallContext?: ServerCallContext;

private currentTask?: Task;
private latestUserMessage?: Message; // To add to history if a new task is created
private finalMessageResult?: Message; // Stores the message if it's the final result

constructor(taskStore: TaskStore) {
constructor(taskStore: TaskStore, serverCallContext?: ServerCallContext) {
this.taskStore = taskStore;
this.serverCallContext = serverCallContext;
}

public setContext(latestUserMessage: Message): void {
Expand Down Expand Up @@ -62,7 +66,7 @@ export class ResultManager {
} else if (!this.currentTask && updateEvent.taskId) {
// Potentially an update for a task we haven't seen the 'task' event for yet,
// or we are rehydrating. Attempt to load.
const loaded = await this.taskStore.load(updateEvent.taskId);
const loaded = await this.taskStore.load(updateEvent.taskId, this.serverCallContext);
if (loaded) {
this.currentTask = loaded;
this.currentTask.status = updateEvent.status;
Expand Down Expand Up @@ -119,7 +123,7 @@ export class ResultManager {
await this.saveCurrentTask();
} else if (!this.currentTask && artifactEvent.taskId) {
// Similar to status update, try to load if task not in memory
const loaded = await this.taskStore.load(artifactEvent.taskId);
const loaded = await this.taskStore.load(artifactEvent.taskId, this.serverCallContext);
if (loaded) {
this.currentTask = loaded;
if (!this.currentTask.artifacts) this.currentTask.artifacts = [];
Expand Down Expand Up @@ -150,7 +154,7 @@ export class ResultManager {

private async saveCurrentTask(): Promise<void> {
if (this.currentTask) {
await this.taskStore.save(this.currentTask);
await this.taskStore.save(this.currentTask, this.serverCallContext);
}
}

Expand Down
9 changes: 6 additions & 3 deletions src/server/store.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Task } from '../types.js';
import { ServerCallContext } from './context.js';

/**
* Simplified interface for task storage providers.
Expand All @@ -8,17 +9,19 @@ export interface TaskStore {
/**
* Saves a task.
* Overwrites existing data if the task ID exists.
* @param data An object containing the task.
* @param task The task to save.
* @param context The context of the current call.
* @returns A promise resolving when the save operation is complete.
*/
save(task: Task): Promise<void>;
save(task: Task, context?: ServerCallContext): Promise<void>;

/**
* Loads a task by task ID.
* @param taskId The ID of the task to load.
* @param context The context of the current call.
* @returns A promise resolving to an object containing the Task, or undefined if not found.
*/
load(taskId: string): Promise<Task | undefined>;
load(taskId: string, context?: ServerCallContext): Promise<Task | undefined>;
}

// ========================
Expand Down
Loading