From f5f9bb6b550baf957710bf0f92bec8f8833629e2 Mon Sep 17 00:00:00 2001 From: tchapacan Date: Mon, 15 Dec 2025 01:15:17 +0100 Subject: [PATCH] feat: add rest client --- src/client/factory.ts | 3 +- src/client/index.ts | 5 + src/client/transports/json_rpc_transport.ts | 61 +--- src/client/transports/rest_transport.ts | 329 +++++++++++++++++ src/errors.ts | 15 + .../transports/rest/rest_transport_handler.ts | 18 +- src/sse_utils.ts | 82 ++++- ...ort.test.ts => json_rpc_transport.spec.ts} | 0 test/client/transports/rest_transport.spec.ts | 334 ++++++++++++++++++ test/client/util.ts | 57 +++ test/e2e.spec.ts | 233 ++++++------ test/sse_utils.spec.ts | 177 ++++++++++ 12 files changed, 1136 insertions(+), 178 deletions(-) create mode 100644 src/client/transports/rest_transport.ts rename test/client/transports/{json_rpc_transport.test.ts => json_rpc_transport.spec.ts} (100%) create mode 100644 test/client/transports/rest_transport.spec.ts create mode 100644 test/sse_utils.spec.ts diff --git a/src/client/factory.ts b/src/client/factory.ts index cb201f87..70943aec 100644 --- a/src/client/factory.ts +++ b/src/client/factory.ts @@ -3,6 +3,7 @@ import { AgentCard } from '../types.js'; import { AgentCardResolver } from './card-resolver.js'; import { Client, ClientConfig } from './multitransport-client.js'; import { JsonRpcTransportFactory } from './transports/json_rpc_transport.js'; +import { RestTransportFactory } from './transports/rest_transport.js'; import { TransportFactory } from './transports/transport.js'; export interface ClientFactoryOptions { @@ -34,7 +35,7 @@ export const ClientFactoryOptions = { * SDK default options for {@link ClientFactory}. */ default: { - transports: [new JsonRpcTransportFactory()], + transports: [new JsonRpcTransportFactory(), new RestTransportFactory()], } as Readonly, /** diff --git a/src/client/index.ts b/src/client/index.ts index 4a27baf6..e3113b64 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -18,6 +18,11 @@ export { JsonRpcTransportFactory, type JsonRpcTransportOptions, } from './transports/json_rpc_transport.js'; +export { + RestTransport, + RestTransportFactory, + type RestTransportOptions, +} from './transports/rest_transport.js'; export type { CallInterceptor, BeforeArgs, diff --git a/src/client/transports/json_rpc_transport.ts b/src/client/transports/json_rpc_transport.ts index 8b426faf..35aa18af 100644 --- a/src/client/transports/json_rpc_transport.ts +++ b/src/client/transports/json_rpc_transport.ts @@ -32,6 +32,7 @@ import { } from '../../types.js'; import { A2AStreamEventData, SendMessageResult } from '../client.js'; import { RequestOptions } from '../multitransport-client.js'; +import { parseSseStream } from '../../sse_utils.js'; import { Transport, TransportFactory } from './transport.js'; export interface JsonRpcTransportOptions { @@ -303,62 +304,8 @@ export class JsonRpcTransport implements Transport { ); } - yield* this._parseA2ASseStream(response, clientRequestId); - } - - private async *_parseA2ASseStream( - response: Response, - originalRequestId: number | string | null - ): AsyncGenerator { - if (!response.body) { - throw new Error('SSE response body is undefined. Cannot read stream.'); - } - const reader = response.body.pipeThrough(new TextDecoderStream()).getReader(); - let buffer = ''; - let eventDataBuffer = ''; - - try { - while (true) { - const { done, value } = await reader.read(); - if (done) { - if (eventDataBuffer.trim()) { - const result = this._processSseEventData( - eventDataBuffer, - originalRequestId - ); - yield result; - } - break; - } - - buffer += value; - let lineEndIndex; - while ((lineEndIndex = buffer.indexOf('\n')) >= 0) { - const line = buffer.substring(0, lineEndIndex).trim(); - buffer = buffer.substring(lineEndIndex + 1); - - if (line === '') { - if (eventDataBuffer) { - const result = this._processSseEventData( - eventDataBuffer, - originalRequestId - ); - yield result; - eventDataBuffer = ''; - } - } else if (line.startsWith('data:')) { - eventDataBuffer += line.substring(5).trimStart() + '\n'; - } - } - } - } catch (error) { - console.error( - 'Error reading or parsing SSE stream:', - (error instanceof Error && error.message) || 'Error unknown' - ); - throw error; - } finally { - reader.releaseLock(); + for await (const event of parseSseStream(response)) { + yield this._processSseEventData(event.data, clientRequestId); } } @@ -370,7 +317,7 @@ export class JsonRpcTransport implements Transport { throw new Error('Attempted to process empty SSE event data.'); } try { - const sseJsonRpcResponse = JSON.parse(jsonData.replace(/\n$/, '')); + const sseJsonRpcResponse = JSON.parse(jsonData); const a2aStreamResponse: JSONRPCResponse = sseJsonRpcResponse as JSONRPCResponse; if (a2aStreamResponse.id !== originalRequestId) { diff --git a/src/client/transports/rest_transport.ts b/src/client/transports/rest_transport.ts new file mode 100644 index 00000000..81aac324 --- /dev/null +++ b/src/client/transports/rest_transport.ts @@ -0,0 +1,329 @@ +import { TransportProtocolName } from '../../core.js'; +import { + A2A_ERROR_CODE, + AuthenticatedExtendedCardNotConfiguredError, + ContentTypeNotSupportedError, + InvalidAgentResponseError, + PushNotificationNotSupportedError, + TaskNotFoundError, + TaskNotCancelableError, + UnsupportedOperationError, +} from '../../errors.js'; +import { + AgentCard, + DeleteTaskPushNotificationConfigParams, + GetTaskPushNotificationConfigParams, + ListTaskPushNotificationConfigParams, + MessageSendParams, + TaskPushNotificationConfig, + TaskIdParams, + TaskQueryParams, + Task, +} from '../../types.js'; +import { A2AStreamEventData, SendMessageResult } from '../client.js'; +import { RequestOptions } from '../multitransport-client.js'; +import { parseSseStream } from '../../sse_utils.js'; +import { Transport, TransportFactory } from './transport.js'; + +export interface RestTransportOptions { + endpoint: string; + fetchImpl?: typeof fetch; +} + +interface RestErrorResponse { + code: number; + message: string; + data?: Record; +} + +export class RestTransport implements Transport { + private readonly customFetchImpl?: typeof fetch; + private readonly endpoint: string; + + constructor(options: RestTransportOptions) { + this.endpoint = options.endpoint.replace(/\/+$/, ''); + this.customFetchImpl = options.fetchImpl; + } + + async getExtendedAgentCard(options?: RequestOptions): Promise { + return this._sendRequest('GET', '/v1/card', undefined, options); + } + + async sendMessage( + params: MessageSendParams, + options?: RequestOptions + ): Promise { + return this._sendRequest('POST', '/v1/message:send', params, options); + } + + async *sendMessageStream( + params: MessageSendParams, + options?: RequestOptions + ): AsyncGenerator { + yield* this._sendStreamingRequest('/v1/message:stream', params, options); + } + + async setTaskPushNotificationConfig( + params: TaskPushNotificationConfig, + options?: RequestOptions + ): Promise { + return this._sendRequest( + 'POST', + `/v1/tasks/${encodeURIComponent(params.taskId)}/pushNotificationConfigs`, + { + pushNotificationConfig: params.pushNotificationConfig, + }, + options + ); + } + + async getTaskPushNotificationConfig( + params: GetTaskPushNotificationConfigParams, + options?: RequestOptions + ): Promise { + const { pushNotificationConfigId } = params; + if (!pushNotificationConfigId) { + throw new Error( + 'pushNotificationConfigId is required for getTaskPushNotificationConfig with REST transport.' + ); + } + return this._sendRequest( + 'GET', + `/v1/tasks/${encodeURIComponent(params.id)}/pushNotificationConfigs/${encodeURIComponent(pushNotificationConfigId)}`, + undefined, + options + ); + } + + async listTaskPushNotificationConfig( + params: ListTaskPushNotificationConfigParams, + options?: RequestOptions + ): Promise { + return this._sendRequest( + 'GET', + `/v1/tasks/${encodeURIComponent(params.id)}/pushNotificationConfigs`, + undefined, + options + ); + } + + async deleteTaskPushNotificationConfig( + params: DeleteTaskPushNotificationConfigParams, + options?: RequestOptions + ): Promise { + await this._sendRequest( + 'DELETE', + `/v1/tasks/${encodeURIComponent(params.id)}/pushNotificationConfigs/${encodeURIComponent(params.pushNotificationConfigId)}`, + undefined, + options + ); + } + + async getTask(params: TaskQueryParams, options?: RequestOptions): Promise { + const queryParams = new URLSearchParams(); + if (params.historyLength !== undefined) { + queryParams.set('historyLength', String(params.historyLength)); + } + const queryString = queryParams.toString(); + const path = `/v1/tasks/${encodeURIComponent(params.id)}${queryString ? `?${queryString}` : ''}`; + return this._sendRequest('GET', path, undefined, options); + } + + async cancelTask(params: TaskIdParams, options?: RequestOptions): Promise { + return this._sendRequest( + 'POST', + `/v1/tasks/${encodeURIComponent(params.id)}:cancel`, + undefined, + options + ); + } + + async *resubscribeTask( + params: TaskIdParams, + options?: RequestOptions + ): AsyncGenerator { + yield* this._sendStreamingRequest( + `/v1/tasks/${encodeURIComponent(params.id)}:subscribe`, + undefined, + options + ); + } + + private _fetch(...args: Parameters): ReturnType { + if (this.customFetchImpl) { + return this.customFetchImpl(...args); + } + if (typeof fetch === 'function') { + return fetch(...args); + } + throw new Error( + 'A `fetch` implementation was not provided and is not available in the global scope. ' + + 'Please provide a `fetchImpl` in the RestTransportOptions.' + ); + } + + private _buildHeaders( + options: RequestOptions | undefined, + acceptHeader: string = 'application/json' + ): HeadersInit { + return { + ...options?.serviceParameters, + 'Content-Type': 'application/json', + Accept: acceptHeader, + }; + } + + private async _sendRequest( + method: 'GET' | 'POST' | 'DELETE', + path: string, + body: unknown | undefined, + options: RequestOptions | undefined + ): Promise { + const url = `${this.endpoint}${path}`; + const requestInit: RequestInit = { + method, + headers: this._buildHeaders(options), + signal: options?.signal, + }; + + if (body !== undefined && method !== 'GET') { + requestInit.body = JSON.stringify(body); + } + + const response = await this._fetch(url, requestInit); + + if (!response.ok) { + await this._handleErrorResponse(response, path); + } + + if (response.status === 204) { + return undefined as TResponse; + } + + const result = await response.json(); + return result as TResponse; + } + + private async _handleErrorResponse(response: Response, path: string): Promise { + let errorBodyText = '(empty or non-JSON response)'; + let errorBody: RestErrorResponse | undefined; + + try { + errorBodyText = await response.text(); + if (errorBodyText) { + errorBody = JSON.parse(errorBodyText); + } + } catch (e) { + throw new Error( + `HTTP error for ${path}! Status: ${response.status} ${response.statusText}. Response: ${errorBodyText}`, + { cause: e } + ); + } + + if (errorBody && typeof errorBody.code === 'number') { + throw RestTransport.mapToError(errorBody); + } + + throw new Error( + `HTTP error for ${path}! Status: ${response.status} ${response.statusText}. Response: ${errorBodyText}` + ); + } + + private async *_sendStreamingRequest( + path: string, + body: unknown | undefined, + options?: RequestOptions + ): AsyncGenerator { + const url = `${this.endpoint}${path}`; + const requestInit: RequestInit = { + method: 'POST', + headers: this._buildHeaders(options, 'text/event-stream'), + signal: options?.signal, + }; + + if (body !== undefined) { + requestInit.body = JSON.stringify(body); + } + + const response = await this._fetch(url, requestInit); + + if (!response.ok) { + await this._handleErrorResponse(response, path); + } + + const contentType = response.headers.get('Content-Type'); + if (!contentType?.startsWith('text/event-stream')) { + throw new Error( + `Invalid response Content-Type for SSE stream. Expected 'text/event-stream', got '${contentType}'.` + ); + } + + for await (const event of parseSseStream(response)) { + if (event.type === 'error') { + const errorData = JSON.parse(event.data); + throw RestTransport.mapToError(errorData); + } + yield this._processSseEventData(event.data); + } + } + + private _processSseEventData(jsonData: string): A2AStreamEventData { + if (!jsonData.trim()) { + throw new Error('Attempted to process empty SSE event data.'); + } + + try { + const data = JSON.parse(jsonData); + return data as A2AStreamEventData; + } catch (e) { + console.error('Failed to parse SSE event data:', jsonData, e); + throw new Error( + `Failed to parse SSE event data: "${jsonData.substring(0, 100)}...". Original error: ${(e instanceof Error && e.message) || 'Unknown error'}` + ); + } + } + + private static mapToError(error: RestErrorResponse): Error { + switch (error.code) { + case A2A_ERROR_CODE.TASK_NOT_FOUND: + return new TaskNotFoundError(error.message); + case A2A_ERROR_CODE.TASK_NOT_CANCELABLE: + return new TaskNotCancelableError(error.message); + case A2A_ERROR_CODE.PUSH_NOTIFICATION_NOT_SUPPORTED: + return new PushNotificationNotSupportedError(error.message); + case A2A_ERROR_CODE.UNSUPPORTED_OPERATION: + return new UnsupportedOperationError(error.message); + case A2A_ERROR_CODE.CONTENT_TYPE_NOT_SUPPORTED: + return new ContentTypeNotSupportedError(error.message); + case A2A_ERROR_CODE.INVALID_AGENT_RESPONSE: + return new InvalidAgentResponseError(error.message); + case A2A_ERROR_CODE.AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED: + return new AuthenticatedExtendedCardNotConfiguredError(error.message); + default: + return new Error( + `REST error: ${error.message} (Code: ${error.code})${error.data ? ` Data: ${JSON.stringify(error.data)}` : ''}` + ); + } + } +} + +export interface RestTransportFactoryOptions { + fetchImpl?: typeof fetch; +} + +export class RestTransportFactory implements TransportFactory { + public static readonly name: TransportProtocolName = 'HTTP+JSON'; + + constructor(private readonly options?: RestTransportFactoryOptions) {} + + get protocolName(): string { + return RestTransportFactory.name; + } + + async create(url: string, _agentCard: AgentCard): Promise { + return new RestTransport({ + endpoint: url, + fetchImpl: this.options?.fetchImpl, + }); + } +} diff --git a/src/errors.ts b/src/errors.ts index 8b646233..9974977a 100644 --- a/src/errors.ts +++ b/src/errors.ts @@ -1,5 +1,20 @@ // Transport-agnostic errors according to https://a2a-protocol.org/latest/specification/#82-a2a-specific-errors; +export const A2A_ERROR_CODE = { + PARSE_ERROR: -32700, + INVALID_REQUEST: -32600, + METHOD_NOT_FOUND: -32601, + INVALID_PARAMS: -32602, + INTERNAL_ERROR: -32603, + TASK_NOT_FOUND: -32001, + TASK_NOT_CANCELABLE: -32002, + PUSH_NOTIFICATION_NOT_SUPPORTED: -32003, + UNSUPPORTED_OPERATION: -32004, + CONTENT_TYPE_NOT_SUPPORTED: -32005, + INVALID_AGENT_RESPONSE: -32006, + AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED: -32007, +} as const; + export class TaskNotFoundError extends Error { constructor(message?: string) { super(message ?? 'Task not found'); diff --git a/src/server/transports/rest/rest_transport_handler.ts b/src/server/transports/rest/rest_transport_handler.ts index 52b68c64..a471ef4d 100644 --- a/src/server/transports/rest/rest_transport_handler.ts +++ b/src/server/transports/rest/rest_transport_handler.ts @@ -32,6 +32,7 @@ import { TaskPushNotificationConfigInput, FileInput, } from './rest_types.js'; +import { A2A_ERROR_CODE } from '../../../errors.js'; // ============================================================================ // HTTP Status Codes and Error Mapping @@ -53,21 +54,6 @@ export const HTTP_STATUS = { NOT_IMPLEMENTED: 501, } as const; -/** - * A2A error codes mapped to JSON-RPC and protocol-specific errors. - */ -const A2A_ERROR_CODE = { - PARSE_ERROR: -32700, - INVALID_REQUEST: -32600, - METHOD_NOT_FOUND: -32601, - INVALID_PARAMS: -32602, - TASK_NOT_FOUND: -32001, - TASK_NOT_CANCELABLE: -32002, - PUSH_NOTIFICATION_NOT_SUPPORTED: -32003, - UNSUPPORTED_OPERATION: -32004, - UNAUTHORIZED: -32005, -} as const; - /** * Maps A2A error codes to appropriate HTTP status codes. * @@ -92,8 +78,6 @@ export function mapErrorToStatus(errorCode: number): number { case A2A_ERROR_CODE.PUSH_NOTIFICATION_NOT_SUPPORTED: case A2A_ERROR_CODE.UNSUPPORTED_OPERATION: return HTTP_STATUS.BAD_REQUEST; - case A2A_ERROR_CODE.UNAUTHORIZED: - return HTTP_STATUS.UNAUTHORIZED; default: return HTTP_STATUS.INTERNAL_SERVER_ERROR; } diff --git a/src/sse_utils.ts b/src/sse_utils.ts index 3f5bf7a4..c79d0f83 100644 --- a/src/sse_utils.ts +++ b/src/sse_utils.ts @@ -1,6 +1,6 @@ /** * Shared Server-Sent Events (SSE) utilities for both JSON-RPC and REST transports. - * This module provides common SSE formatting functions and headers. + * This module provides common SSE formatting and parsing functions. */ // ============================================================================ @@ -19,7 +19,19 @@ export const SSE_HEADERS = { } as const; // ============================================================================ -// SSE Event Formatting +// SSE Event Types +// ============================================================================ + +/** + * Represents a parsed SSE event with type and data. + */ +export interface SseEvent { + type: string; + data: string; +} + +// ============================================================================ +// SSE Event Formatting (Server-side) // ============================================================================ /** @@ -59,3 +71,69 @@ export function formatSSEEvent(event: unknown): string { export function formatSSEErrorEvent(error: unknown): string { return `event: error\ndata: ${JSON.stringify(error)}\n\n`; } + +// ============================================================================ +// SSE Event Parsing (Client-side) +// ============================================================================ + +/** + * Parses a Server-Sent Events (SSE) stream from a Response object. + * Yields parsed SSE events as they arrive. + * + * This parser expects well-formed SSE events with single-line JSON data, + * matching the format produced by formatSSEEvent and formatSSEErrorEvent. + * + * @param response - The fetch Response containing an SSE stream + * @yields SseEvent objects with type and data fields + * + * @example + * ```ts + * for await (const event of parseSseStream(response)) { + * if (event.type === 'error') { + * handleError(JSON.parse(event.data)); + * } else { + * handleData(JSON.parse(event.data)); + * } + * } + * ``` + */ +export async function* parseSseStream( + response: Response +): AsyncGenerator { + if (!response.body) { + throw new Error('SSE response body is undefined. Cannot read stream.'); + } + + let buffer = ''; + let eventType = 'message'; + let eventData = ''; + + for await (const value of response.body.pipeThrough(new TextDecoderStream())) { + buffer += value; + let lineEndIndex: number; + + while ((lineEndIndex = buffer.indexOf('\n')) >= 0) { + const line = buffer.substring(0, lineEndIndex).trim(); + buffer = buffer.substring(lineEndIndex + 1); + + if (line === '') { + // Empty line signals end of event + if (eventData) { + yield { type: eventType, data: eventData }; + eventData = ''; + eventType = 'message'; + } + } else if (line.startsWith('event:')) { + eventType = line.substring('event:'.length).trim(); + } else if (line.startsWith('data:')) { + // Expect well-formed JSON on a single data line + eventData = line.substring('data:'.length).trim(); + } + } + } + + // Yield any pending event at stream end + if (eventData) { + yield { type: eventType, data: eventData }; + } +} diff --git a/test/client/transports/json_rpc_transport.test.ts b/test/client/transports/json_rpc_transport.spec.ts similarity index 100% rename from test/client/transports/json_rpc_transport.test.ts rename to test/client/transports/json_rpc_transport.spec.ts diff --git a/test/client/transports/rest_transport.spec.ts b/test/client/transports/rest_transport.spec.ts new file mode 100644 index 00000000..867bd2b3 --- /dev/null +++ b/test/client/transports/rest_transport.spec.ts @@ -0,0 +1,334 @@ +import { + RestTransport, + RestTransportFactory, +} from '../../../src/client/transports/rest_transport.js'; +import sinon from 'sinon'; +import { describe, it, beforeEach, afterEach, expect } from 'vitest'; +import { TaskPushNotificationConfig } from '../../../src/types.js'; +import { RequestOptions } from '../../../src/client/multitransport-client.js'; +import { HTTP_EXTENSION_HEADER } from '../../../src/constants.js'; +import { ServiceParameters, withA2AExtensions } from '../../../src/client/service-parameters.js'; +import { + TaskNotFoundError, + TaskNotCancelableError, + PushNotificationNotSupportedError, +} from '../../../src/errors.js'; +import { + createMessageParams, + createMockAgentCard, + createMockMessage, + createMockTask, + createRestResponse, + createRestErrorResponse, +} from '../util.js'; + +describe('RestTransport', () => { + let transport: RestTransport; + let mockFetch: sinon.SinonStubbedFunction; + const endpoint = 'https://test.endpoint/api'; + + beforeEach(() => { + mockFetch = sinon.stub(); + transport = new RestTransport({ + endpoint, + fetchImpl: mockFetch, + }); + }); + + afterEach(() => { + sinon.restore(); + }); + + describe('constructor', () => { + it('should trim trailing slashes from endpoint', async () => { + const trailingSlashTransport = new RestTransport({ + endpoint: 'https://example.com/a2a/rest/', + fetchImpl: mockFetch, + }); + const mockResponse = createMockMessage(); + mockFetch.resolves(createRestResponse(mockResponse)); + + await trailingSlashTransport.sendMessage(createMessageParams()); + + const [url] = mockFetch.firstCall.args; + expect(url).to.equal('https://example.com/a2a/rest/v1/message:send'); + }); + + it('should trim multiple trailing slashes from endpoint', async () => { + const trailingSlashTransport = new RestTransport({ + endpoint: 'https://example.com/a2a/rest///', + fetchImpl: mockFetch, + }); + const mockResponse = createMockMessage(); + mockFetch.resolves(createRestResponse(mockResponse)); + + await trailingSlashTransport.sendMessage(createMessageParams()); + + const [url] = mockFetch.firstCall.args; + expect(url).to.equal('https://example.com/a2a/rest/v1/message:send'); + }); + }); + + describe('sendMessage', () => { + it('should send message successfully', async () => { + const messageParams = createMessageParams(); + const mockResponse = createMockMessage(); + + mockFetch.resolves(createRestResponse(mockResponse)); + + const result = await transport.sendMessage(messageParams); + + expect(result).to.deep.equal(mockResponse); + expect(mockFetch.calledOnce).to.be.true; + + const [url, options] = mockFetch.firstCall.args; + expect(url).to.equal(`${endpoint}/v1/message:send`); + expect(options?.method).to.equal('POST'); + expect((options?.headers as Record)['Content-Type']).to.equal( + 'application/json' + ); + }); + + it('should correctly add the extension headers', async () => { + const messageParams = createMessageParams(); + const expectedExtensions = 'extension1,extension2'; + const serviceParameters = ServiceParameters.create(withA2AExtensions(expectedExtensions)); + const options: RequestOptions = { serviceParameters }; + + mockFetch.resolves(createRestResponse(createMockMessage())); + + await transport.sendMessage(messageParams, options); + + const fetchArgs = mockFetch.firstCall.args[1]; + const headers = fetchArgs?.headers as Record; + expect(headers[HTTP_EXTENSION_HEADER]).to.equal(expectedExtensions); + }); + + it('should throw TaskNotFoundError on -32001', async () => { + const messageParams = createMessageParams(); + mockFetch.resolves(createRestErrorResponse(-32001, 'Task not found', 404)); + + await expect(transport.sendMessage(messageParams)).rejects.toThrow(TaskNotFoundError); + }); + }); + + describe('getTask', () => { + it('should get task successfully', async () => { + const taskId = 'task-123'; + const mockTask = createMockTask(taskId); + + mockFetch.resolves(createRestResponse(mockTask)); + + const result = await transport.getTask({ id: taskId }); + + expect(result).to.deep.equal(mockTask); + expect(mockFetch.calledOnce).to.be.true; + + const [url, options] = mockFetch.firstCall.args; + expect(url).to.equal(`${endpoint}/v1/tasks/${taskId}`); + expect(options?.method).to.equal('GET'); + }); + + it('should pass historyLength as query parameter', async () => { + const taskId = 'task-123'; + const historyLength = 10; + const mockTask = createMockTask(taskId); + + mockFetch.resolves(createRestResponse(mockTask)); + + const result = await transport.getTask({ id: taskId, historyLength }); + + expect(result).to.deep.equal(mockTask); + expect(mockFetch.calledOnce).to.be.true; + + const [url] = mockFetch.firstCall.args; + expect(url).to.equal(`${endpoint}/v1/tasks/${taskId}?historyLength=${historyLength}`); + }); + + it('should throw TaskNotFoundError when task does not exist', async () => { + mockFetch.resolves(createRestErrorResponse(-32001, 'Task not found', 404)); + + await expect(transport.getTask({ id: 'nonexistent' })).rejects.toThrow(TaskNotFoundError); + }); + }); + + describe('cancelTask', () => { + it('should cancel task successfully', async () => { + const taskId = 'task-123'; + const mockTask = createMockTask(taskId, 'canceled'); + + mockFetch.resolves(createRestResponse(mockTask)); + + const result = await transport.cancelTask({ id: taskId }); + + expect(result).to.deep.equal(mockTask); + expect(mockFetch.calledOnce).to.be.true; + + const [url, options] = mockFetch.firstCall.args; + expect(url).to.equal(`${endpoint}/v1/tasks/${taskId}:cancel`); + expect(options?.method).to.equal('POST'); + }); + + it('should throw TaskNotCancelableError on -32002', async () => { + mockFetch.resolves(createRestErrorResponse(-32002, 'Task cannot be canceled', 409)); + + await expect(transport.cancelTask({ id: 'task-123' })).rejects.toThrow( + TaskNotCancelableError + ); + }); + }); + + describe('getExtendedAgentCard', () => { + it('should get extended agent card successfully', async () => { + const mockCard = { + name: 'Test Agent', + url: endpoint, + version: '1.0.0', + protocolVersion: '0.3.0', + }; + + mockFetch.resolves(createRestResponse(mockCard)); + + const result = await transport.getExtendedAgentCard(); + + expect(result).to.deep.equal(mockCard); + expect(mockFetch.calledOnce).to.be.true; + + const [url, options] = mockFetch.firstCall.args; + expect(url).to.equal(`${endpoint}/v1/card`); + expect(options?.method).to.equal('GET'); + }); + }); + + describe('Push Notification Config', () => { + const taskId = 'task-123'; + const configId = 'config-456'; + const mockConfig: TaskPushNotificationConfig = { + taskId, + pushNotificationConfig: { + id: configId, + url: 'https://notify.example.com/webhook', + }, + }; + + describe('setTaskPushNotificationConfig', () => { + it('should set push notification config successfully', async () => { + mockFetch.resolves(createRestResponse(mockConfig)); + + const result = await transport.setTaskPushNotificationConfig(mockConfig); + + expect(result).to.deep.equal(mockConfig); + expect(mockFetch.calledOnce).to.be.true; + + const [url, options] = mockFetch.firstCall.args; + expect(url).to.equal(`${endpoint}/v1/tasks/${taskId}/pushNotificationConfigs`); + expect(options?.method).to.equal('POST'); + }); + + it('should throw PushNotificationNotSupportedError on -32003', async () => { + mockFetch.resolves( + createRestErrorResponse(-32003, 'Push notifications not supported', 400) + ); + + await expect(transport.setTaskPushNotificationConfig(mockConfig)).rejects.toThrow( + PushNotificationNotSupportedError + ); + }); + }); + + describe('getTaskPushNotificationConfig', () => { + it('should get push notification config successfully', async () => { + mockFetch.resolves(createRestResponse(mockConfig)); + + const result = await transport.getTaskPushNotificationConfig({ + id: taskId, + pushNotificationConfigId: configId, + }); + + expect(result).to.deep.equal(mockConfig); + expect(mockFetch.calledOnce).to.be.true; + + const [url] = mockFetch.firstCall.args; + expect(url).to.equal(`${endpoint}/v1/tasks/${taskId}/pushNotificationConfigs/${configId}`); + }); + + it('should throw error when pushNotificationConfigId is missing', async () => { + await expect( + transport.getTaskPushNotificationConfig({ + id: taskId, + pushNotificationConfigId: undefined as unknown as string, + }) + ).rejects.toThrow('pushNotificationConfigId is required'); + }); + }); + + describe('listTaskPushNotificationConfig', () => { + it('should list push notification configs successfully', async () => { + const mockConfigs = [ + mockConfig, + { ...mockConfig, pushNotificationConfig: { id: 'config-789' } }, + ]; + mockFetch.resolves(createRestResponse(mockConfigs)); + + const result = await transport.listTaskPushNotificationConfig({ id: taskId }); + + expect(result).to.deep.equal(mockConfigs); + expect(mockFetch.calledOnce).to.be.true; + + const [url, options] = mockFetch.firstCall.args; + expect(url).to.equal(`${endpoint}/v1/tasks/${taskId}/pushNotificationConfigs`); + expect(options?.method).to.equal('GET'); + }); + }); + + describe('deleteTaskPushNotificationConfig', () => { + it('should delete push notification config successfully', async () => { + mockFetch.resolves(new Response(null, { status: 204 })); + + await transport.deleteTaskPushNotificationConfig({ + id: taskId, + pushNotificationConfigId: configId, + }); + + expect(mockFetch.calledOnce).to.be.true; + + const [url, options] = mockFetch.firstCall.args; + expect(url).to.equal(`${endpoint}/v1/tasks/${taskId}/pushNotificationConfigs/${configId}`); + expect(options?.method).to.equal('DELETE'); + }); + }); + }); + + describe('Error Handling', () => { + it('should handle HTTP errors with non-JSON response', async () => { + mockFetch.resolves( + new Response('Internal Server Error', { + status: 500, + headers: { 'Content-Type': 'text/plain' }, + }) + ); + + await expect(transport.getTask({ id: 'task-123' })).rejects.toThrow('HTTP error'); + }); + + it('should handle network errors', async () => { + mockFetch.rejects(new Error('Network error')); + + await expect(transport.getTask({ id: 'task-123' })).rejects.toThrow('Network error'); + }); + }); +}); + +describe('RestTransportFactory', () => { + it('should have correct protocol name', () => { + const factory = new RestTransportFactory(); + expect(factory.protocolName).to.equal('HTTP+JSON'); + }); + + it('should create transport with correct endpoint', async () => { + const factory = new RestTransportFactory(); + const agentCard = createMockAgentCard({ url: 'https://example.com/api' }); + const transport = await factory.create(agentCard.url, agentCard); + expect(transport).to.be.instanceOf(RestTransport); + }); +}); diff --git a/test/client/util.ts b/test/client/util.ts index e1bc04a3..0ae65057 100644 --- a/test/client/util.ts +++ b/test/client/util.ts @@ -358,3 +358,60 @@ export function createMockFetch( return mockFetch as sinon.SinonStub & { capturedAuthHeaders: string[] }; } + +/** + * Creates a REST response (plain JSON, not JSON-RPC wrapped). + * Used for testing REST transport which doesn't use JSON-RPC envelope. + * + * @param data - The data to include in the response + * @param status - HTTP status code (defaults to 200) + * @param headers - Additional headers to include + * @returns A Response object with JSON content + */ +export function createRestResponse( + data: unknown, + status: number = 200, + headers: Record = {} +): Response { + const defaultHeaders = { 'Content-Type': 'application/json' }; + const responseHeaders = { ...defaultHeaders, ...headers }; + return new Response(JSON.stringify(data), { status, headers: responseHeaders }); +} + +/** + * Creates a REST error response with A2A error format. + * + * @param code - A2A error code (e.g., -32001 for TaskNotFound) + * @param message - Error message + * @param status - HTTP status code (defaults to 400) + * @param data - Optional additional error data + * @returns A Response object with error JSON content + */ +export function createRestErrorResponse( + code: number, + message: string, + status: number = 400, + data?: Record +): Response { + const errorBody = { code, message, ...(data && { data }) }; + return new Response(JSON.stringify(errorBody), { + status, + headers: { 'Content-Type': 'application/json' }, + }); +} + +/** + * Creates a mock task response for testing. + * + * @param id - Task ID (defaults to 'task-123') + * @param status - Task status state (defaults to 'completed') + * @returns A mock Task object + */ +export function createMockTask(id: string = 'task-123', status: string = 'completed'): any { + return { + id, + contextId: 'context-123', + status: { state: status }, + kind: 'task', + }; +} diff --git a/test/e2e.spec.ts b/test/e2e.spec.ts index 13fa61f7..cf0ebbdd 100644 --- a/test/e2e.spec.ts +++ b/test/e2e.spec.ts @@ -12,7 +12,8 @@ import { AgentCard, Message } from '../src/types.js'; import sinon from 'sinon'; import { agentCardHandler } from '../src/server/express/agent_card_handler.js'; import { jsonRpcHandler } from '../src/server/express/json_rpc_handler.js'; -import { ClientFactory } from '../src/client/factory.js'; +import { restHandler } from '../src/server/express/rest_handler.js'; +import { ClientFactory, ClientFactoryOptions } from '../src/client/factory.js'; import { Server } from 'http'; import { AddressInfo } from 'net'; import { A2AStreamEventData } from '../src/client/client.js'; @@ -30,112 +31,142 @@ class TestAgentExecutor implements AgentExecutor { cancelTask: (taskId: string, eventBus: ExecutionEventBus) => Promise; } -describe('Client E2E tests', () => { - let app: Express; - let server: Server; - let agentExecutor: TestAgentExecutor; - let clientFactory: ClientFactory; - let agentCard: AgentCard; - - beforeEach(() => { - agentExecutor = new TestAgentExecutor(); - agentCard = { - protocolVersion: '0.3.0', - name: 'Test Agent', - description: 'An agent for testing purposes', - preferredTransport: 'JSONRPC', - url: 'localhost', - version: '1.0.0', - capabilities: { - streaming: true, - pushNotifications: true, - }, - defaultInputModes: ['text/plain'], - defaultOutputModes: ['text/plain'], - skills: [], - }; - const requestHandler = new DefaultRequestHandler( - agentCard, - new InMemoryTaskStore(), - agentExecutor - ); - - app = express(); - - app.use( - '/.well-known/agent-card.json', - agentCardHandler({ agentCardProvider: requestHandler }) - ); - - app.use( - '/a2a/rpc', - jsonRpcHandler({ requestHandler: requestHandler, userBuilder: UserBuilder.noAuthentication }) - ); - - server = app.listen(); - - const address = server.address() as AddressInfo; - agentCard.url = `http://localhost:${address.port}/a2a/rpc`; - - clientFactory = new ClientFactory(); - }); +interface TransportConfig { + name: string; + preferredTransport: string; + serverPath: string; +} - afterEach(() => { - sinon.restore(); - server.close(); - }); +const transportConfigs: TransportConfig[] = [ + { + name: 'JSON-RPC', + preferredTransport: 'JSONRPC', + serverPath: '/a2a/rpc', + }, + { + name: 'REST', + preferredTransport: 'HTTP+JSON', + serverPath: '/a2a/rest', + }, +]; - describe('sendMessage', () => { - it('should send a message to the agent', async () => { - const expected = createTestMessage('1', 'test'); - agentExecutor.events = [expected]; - const client = await clientFactory.createFromAgentCard(agentCard); +describe('Client E2E tests', () => { + const clientFactory = new ClientFactory(ClientFactoryOptions.default); + + transportConfigs.forEach((transportConfig) => { + describe(`[${transportConfig.name}]`, () => { + let app: Express; + let server: Server; + let agentExecutor: TestAgentExecutor; + let agentCard: AgentCard; + + beforeEach(() => { + agentExecutor = new TestAgentExecutor(); + agentCard = { + protocolVersion: '0.3.0', + name: 'Test Agent', + description: 'An agent for testing purposes', + preferredTransport: transportConfig.preferredTransport, + url: 'localhost', + version: '1.0.0', + capabilities: { + streaming: true, + pushNotifications: true, + }, + defaultInputModes: ['text/plain'], + defaultOutputModes: ['text/plain'], + skills: [], + }; + const requestHandler = new DefaultRequestHandler( + agentCard, + new InMemoryTaskStore(), + agentExecutor + ); + + app = express(); + + app.use( + '/.well-known/agent-card.json', + agentCardHandler({ agentCardProvider: requestHandler }) + ); + + app.use( + '/a2a/rpc', + jsonRpcHandler({ + requestHandler: requestHandler, + userBuilder: UserBuilder.noAuthentication, + }) + ); + + app.use( + '/a2a/rest', + restHandler({ requestHandler: requestHandler, userBuilder: UserBuilder.noAuthentication }) + ); + + server = app.listen(); + + const address = server.address() as AddressInfo; + agentCard.url = `http://localhost:${address.port}${transportConfig.serverPath}`; + }); - const actual = await client.sendMessage({ - message: createTestMessage('1', 'test'), + afterEach(() => { + sinon.restore(); + server.close(); }); - expect(actual).to.deep.equal(expected); - }); - }); + describe('sendMessage', () => { + it('should send a message to the agent', async () => { + const expected = createTestMessage('1', 'test'); + agentExecutor.events = [expected]; + const client = await clientFactory.createFromAgentCard(agentCard); + + const actual = await client.sendMessage({ + message: createTestMessage('1', 'test'), + }); - describe('sendMessageStream', () => { - it('should send a message to the agent and read event stream', async () => { - const taskId = '1'; - const contextId = '2'; - const expected: AgentExecutionEvent[] = [ - { - id: taskId, - contextId, - status: { state: 'submitted' }, - kind: 'task', - }, - { - taskId, - contextId, - kind: 'status-update', - status: { state: 'working' }, - final: false, - }, - { - taskId, - contextId, - kind: 'status-update', - status: { state: 'completed' }, - final: true, - }, - ]; - agentExecutor.events = expected; - const client = await clientFactory.createFromAgentCard(agentCard); - - const actual: A2AStreamEventData[] = []; - for await (const message of client.sendMessageStream({ - message: createTestMessage('1', 'test'), - })) { - actual.push(message); - } - - expect(actual).to.deep.equal(expected); + expect(actual).to.deep.equal(expected); + }); + }); + + describe('sendMessageStream', () => { + it('should send a message to the agent and read event stream', async () => { + const taskId = '1'; + const contextId = '2'; + const expected: AgentExecutionEvent[] = [ + { + id: taskId, + contextId, + status: { state: 'submitted' }, + kind: 'task', + }, + { + taskId, + contextId, + kind: 'status-update', + status: { state: 'working' }, + final: false, + }, + { + taskId, + contextId, + kind: 'status-update', + status: { state: 'completed' }, + final: true, + }, + ]; + agentExecutor.events = expected; + const client = await clientFactory.createFromAgentCard(agentCard); + + const actual: A2AStreamEventData[] = []; + for await (const message of client.sendMessageStream({ + message: createTestMessage('1', 'test'), + })) { + actual.push(message); + } + + expect(actual).to.deep.equal(expected); + }); + }); }); }); }); diff --git a/test/sse_utils.spec.ts b/test/sse_utils.spec.ts new file mode 100644 index 00000000..8c96d770 --- /dev/null +++ b/test/sse_utils.spec.ts @@ -0,0 +1,177 @@ +import { describe, it, expect } from 'vitest'; +import { formatSSEEvent, formatSSEErrorEvent, parseSseStream, SseEvent } from '../src/sse_utils.js'; + +/** + * Creates a mock Response object from SSE-formatted strings. + * Used to test that the parser can understand what the formatter produces. + */ +function createMockResponse(sseData: string, chunkSize: number = 2): Response { + const encoder = new TextEncoder(); + const chunks: Uint8Array[] = []; + + for (let i = 0; i < sseData.length; i += chunkSize) { + chunks.push(encoder.encode(sseData.slice(i, i + chunkSize))); + } + + let chunkIndex = 0; + const stream = new ReadableStream({ + pull(controller) { + if (chunkIndex < chunks.length) { + controller.enqueue(chunks[chunkIndex]); + chunkIndex++; + } else { + controller.close(); + } + }, + }); + + return new Response(stream, { + headers: { 'Content-Type': 'text/event-stream' }, + }); +} + +describe('SSE Utils', () => { + describe('formatSSEEvent', () => { + it('should format a data event', () => { + const event = { kind: 'message', text: 'Hello' }; + + const formatted = formatSSEEvent(event); + + expect(formatted).toBe('data: {"kind":"message","text":"Hello"}\n\n'); + }); + + it('should format complex objects', () => { + const event = { nested: { value: 123 }, array: [1, 2, 3] }; + + const formatted = formatSSEEvent(event); + + expect(formatted).toBe('data: {"nested":{"value":123},"array":[1,2,3]}\n\n'); + }); + }); + + describe('formatSSEErrorEvent', () => { + it('should format an error event with event type', () => { + const error = { code: -32603, message: 'Internal error' }; + + const formatted = formatSSEErrorEvent(error); + + expect(formatted).toBe('event: error\ndata: {"code":-32603,"message":"Internal error"}\n\n'); + }); + }); + + describe('parseSseStream', () => { + it('should parse a single data event', async () => { + const sseData = 'data: {"kind":"message"}\n\n'; + const response = createMockResponse(sseData); + + const events: SseEvent[] = []; + for await (const event of parseSseStream(response)) { + events.push(event); + } + + expect(events).toHaveLength(1); + expect(events[0].type).toBe('message'); + expect(events[0].data).toBe('{"kind":"message"}'); + }); + + it('should parse an error event', async () => { + const sseData = 'event: error\ndata: {"code":-32001}\n\n'; + const response = createMockResponse(sseData); + + const events: SseEvent[] = []; + for await (const event of parseSseStream(response)) { + events.push(event); + } + + expect(events).toHaveLength(1); + expect(events[0].type).toBe('error'); + expect(events[0].data).toBe('{"code":-32001}'); + }); + + it('should parse multiple events', async () => { + const sseData = 'data: {"id":1}\n\ndata: {"id":2}\n\n'; + const response = createMockResponse(sseData); + + const events: SseEvent[] = []; + for await (const event of parseSseStream(response)) { + events.push(event); + } + + expect(events).toHaveLength(2); + expect(JSON.parse(events[0].data)).toEqual({ id: 1 }); + expect(JSON.parse(events[1].data)).toEqual({ id: 2 }); + }); + }); + + describe('Symmetry: parser understands formatter output', () => { + it('should parse what formatSSEEvent produces', async () => { + const originalData = { kind: 'task', id: '123', status: 'completed' }; + const formatted = formatSSEEvent(originalData); + const response = createMockResponse(formatted); + + const events: SseEvent[] = []; + for await (const event of parseSseStream(response)) { + events.push(event); + } + + expect(events).toHaveLength(1); + expect(events[0].type).toBe('message'); + expect(JSON.parse(events[0].data)).toEqual(originalData); + }); + + it('should parse what formatSSEErrorEvent produces', async () => { + const originalError = { code: -32001, message: 'Task not found', data: { taskId: 'abc' } }; + const formatted = formatSSEErrorEvent(originalError); + const response = createMockResponse(formatted); + + const events: SseEvent[] = []; + for await (const event of parseSseStream(response)) { + events.push(event); + } + + expect(events).toHaveLength(1); + expect(events[0].type).toBe('error'); + expect(JSON.parse(events[0].data)).toEqual(originalError); + }); + + it('should parse multiple formatted events in sequence', async () => { + const events_to_format = [ + { kind: 'status-update', status: 'working' }, + { kind: 'artifact', data: 'hello' }, + { kind: 'status-update', status: 'completed' }, + ]; + + const formatted = events_to_format.map(formatSSEEvent).join(''); + const response = createMockResponse(formatted); + + const parsedEvents: SseEvent[] = []; + for await (const event of parseSseStream(response)) { + parsedEvents.push(event); + } + + expect(parsedEvents).toHaveLength(3); + for (let i = 0; i < events_to_format.length; i++) { + expect(JSON.parse(parsedEvents[i].data)).toEqual(events_to_format[i]); + } + }); + + it('should parse mixed data and error events', async () => { + const dataEvent = { kind: 'message', text: 'hello' }; + const errorEvent = { code: -32603, message: 'Internal error' }; + + const formatted = formatSSEEvent(dataEvent) + formatSSEErrorEvent(errorEvent); + const response = createMockResponse(formatted); + + const parsedEvents: SseEvent[] = []; + for await (const event of parseSseStream(response)) { + parsedEvents.push(event); + } + + expect(parsedEvents).toHaveLength(2); + expect(parsedEvents[0].type).toBe('message'); + expect(JSON.parse(parsedEvents[0].data)).toEqual(dataEvent); + expect(parsedEvents[1].type).toBe('error'); + expect(JSON.parse(parsedEvents[1].data)).toEqual(errorEvent); + }); + }); +});