diff --git a/packages/core/src/agent/Dispatcher.ts b/packages/core/src/agent/Dispatcher.ts index 2c3c86769e..b42f9aa6ca 100644 --- a/packages/core/src/agent/Dispatcher.ts +++ b/packages/core/src/agent/Dispatcher.ts @@ -1,16 +1,19 @@ import type { AgentMessage } from './AgentMessage' import type { AgentMessageProcessedEvent } from './Events' +import type { MessageHandlerMiddleware } from './MessageHandlerMiddleware' import type { InboundMessageContext } from './models/InboundMessageContext' import { InjectionSymbols } from '../constants' -import { CredoError } from '../error/CredoError' +import { CredoError } from '../error' import { Logger } from '../logger' +import { ProblemReportError, ProblemReportReason } from '../modules/problem-reports' import { injectable, inject } from '../plugins' -import { parseMessageType } from '../utils/messageType' +import { canHandleMessageType, parseMessageType } from '../utils/messageType' import { ProblemReportMessage } from './../modules/problem-reports/messages/ProblemReportMessage' import { EventEmitter } from './EventEmitter' import { AgentEventTypes } from './Events' +import { MessageHandlerMiddlewareRunner } from './MessageHandlerMiddleware' import { MessageHandlerRegistry } from './MessageHandlerRegistry' import { MessageSender } from './MessageSender' import { OutboundMessageContext } from './models' @@ -34,22 +37,58 @@ class Dispatcher { this.logger = logger } + private defaultHandlerMiddleware: MessageHandlerMiddleware = async (inboundMessageContext, next) => { + let messageHandler = inboundMessageContext.messageHandler + + if (!messageHandler && inboundMessageContext.agentContext.dependencyManager.fallbackMessageHandler) { + messageHandler = { + supportedMessages: [], + handle: inboundMessageContext.agentContext.dependencyManager.fallbackMessageHandler, + } + } + + if (!messageHandler) { + throw new ProblemReportError( + `Error handling message ${inboundMessageContext.message.id} with type ${inboundMessageContext.message.type}. The message type is not supported`, + { + problemCode: ProblemReportReason.MessageParseFailure, + } + ) + } + + const outboundMessage = await messageHandler.handle(inboundMessageContext) + if (outboundMessage) { + inboundMessageContext.setResponseMessage(outboundMessage) + } + + await next() + } + public async dispatch(messageContext: InboundMessageContext): Promise { const { agentContext, connection, senderKey, recipientKey, message } = messageContext - const messageHandler = this.messageHandlerRegistry.getHandlerForMessageType(message.type) - if (!messageHandler) { - throw new CredoError(`No handler for message type "${message.type}" found`) + // Set default handler if available, middleware can still override the message handler + const messageHandler = this.messageHandlerRegistry.getHandlerForMessageType(message.type) + if (messageHandler) { + messageContext.setMessageHandler(messageHandler) } - let outboundMessage: OutboundMessageContext | void + let outboundMessage: OutboundMessageContext | undefined try { - outboundMessage = await messageHandler.handle(messageContext) + const middlewares = [...agentContext.dependencyManager.messageHandlerMiddlewares, this.defaultHandlerMiddleware] + await MessageHandlerMiddlewareRunner.run(middlewares, messageContext) + + outboundMessage = messageContext.responseMessage } catch (error) { const problemReportMessage = error.problemReport if (problemReportMessage instanceof ProblemReportMessage && messageContext.connection) { + const messageType = parseMessageType(messageContext.message.type) + if (canHandleMessageType(ProblemReportMessage, messageType)) { + throw new CredoError(`Not sending problem report in response to problem report: ${message}`) + } + const { protocolUri: problemReportProtocolUri } = parseMessageType(problemReportMessage.type) const { protocolUri: inboundProtocolUri } = parseMessageType(messageContext.message.type) @@ -91,6 +130,7 @@ class Dispatcher { await this.messageSender.sendMessage(outboundMessage) } + // Emit event that allows to hook into received messages this.eventEmitter.emit(agentContext, { type: AgentEventTypes.AgentMessageProcessed, diff --git a/packages/core/src/agent/MessageHandlerMiddleware.ts b/packages/core/src/agent/MessageHandlerMiddleware.ts new file mode 100644 index 0000000000..9eff446bbd --- /dev/null +++ b/packages/core/src/agent/MessageHandlerMiddleware.ts @@ -0,0 +1,26 @@ +import type { InboundMessageContext } from './models/InboundMessageContext' + +export interface MessageHandlerMiddleware { + (inboundMessageContext: InboundMessageContext, next: () => Promise): Promise +} + +export class MessageHandlerMiddlewareRunner { + public static async run(middlewares: MessageHandlerMiddleware[], inboundMessageContext: InboundMessageContext) { + const compose = (middlewares: MessageHandlerMiddleware[]) => { + return async function (inboundMessageContext: InboundMessageContext) { + let index = -1 + async function dispatch(i: number): Promise { + if (i <= index) throw new Error('next() called multiple times') + index = i + const fn = middlewares[i] + if (!fn) return + await fn(inboundMessageContext, () => dispatch(i + 1)) + } + await dispatch(0) + } + } + + const composed = compose(middlewares) + await composed(inboundMessageContext) + } +} diff --git a/packages/core/src/agent/MessageReceiver.ts b/packages/core/src/agent/MessageReceiver.ts index 19fca86a8b..02f94d1e7b 100644 --- a/packages/core/src/agent/MessageReceiver.ts +++ b/packages/core/src/agent/MessageReceiver.ts @@ -1,4 +1,3 @@ -import type { AgentMessage } from './AgentMessage' import type { DecryptedMessageContext } from './EnvelopeService' import type { TransportSession } from './TransportService' import type { AgentContext } from './context' @@ -16,6 +15,7 @@ import { isValidJweStructure } from '../utils/JWE' import { JsonTransformer } from '../utils/JsonTransformer' import { canHandleMessageType, parseMessageType, replaceLegacyDidSovPrefixOnMessage } from '../utils/messageType' +import { AgentMessage } from './AgentMessage' import { Dispatcher } from './Dispatcher' import { EnvelopeService } from './EnvelopeService' import { MessageHandlerRegistry } from './MessageHandlerRegistry' @@ -250,13 +250,7 @@ export class MessageReceiver { replaceLegacyDidSovPrefixOnMessage(message) const messageType = message['@type'] - const MessageClass = this.messageHandlerRegistry.getMessageClassForMessageType(messageType) - - if (!MessageClass) { - throw new ProblemReportError(`No message class found for message type "${messageType}"`, { - problemCode: ProblemReportReason.MessageParseFailure, - }) - } + const MessageClass = this.messageHandlerRegistry.getMessageClassForMessageType(messageType) ?? AgentMessage // Cast the plain JSON object to specific instance of Message extended from AgentMessage let messageTransformed: AgentMessage diff --git a/packages/core/src/agent/__tests__/Dispatcher.test.ts b/packages/core/src/agent/__tests__/Dispatcher.test.ts index 7bbcb89f95..7f8f2513c5 100644 --- a/packages/core/src/agent/__tests__/Dispatcher.test.ts +++ b/packages/core/src/agent/__tests__/Dispatcher.test.ts @@ -1,3 +1,5 @@ +import type { ConnectionRecord } from '../../modules/connections' + import { Subject } from 'rxjs' import { getAgentConfig, getAgentContext } from '../../../tests/helpers' @@ -7,11 +9,22 @@ import { Dispatcher } from '../Dispatcher' import { EventEmitter } from '../EventEmitter' import { MessageHandlerRegistry } from '../MessageHandlerRegistry' import { MessageSender } from '../MessageSender' +import { getOutboundMessageContext } from '../getOutboundMessageContext' import { InboundMessageContext } from '../models/InboundMessageContext' +jest.mock('../MessageSender') + class CustomProtocolMessage extends AgentMessage { public readonly type = CustomProtocolMessage.type.messageTypeUri public static readonly type = parseMessageType('https://didcomm.org/fake-protocol/1.5/message') + + public constructor(options: { id?: string }) { + super() + + if (options) { + this.id = options.id ?? this.generateId() + } + } } describe('Dispatcher', () => { @@ -29,7 +42,7 @@ describe('Dispatcher', () => { messageHandlerRegistry, agentConfig.logger ) - const customProtocolMessage = new CustomProtocolMessage() + const customProtocolMessage = new CustomProtocolMessage({}) const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { agentContext }) const mockHandle = jest.fn() @@ -48,15 +61,215 @@ describe('Dispatcher', () => { new MessageHandlerRegistry(), agentConfig.logger ) - const customProtocolMessage = new CustomProtocolMessage() + const customProtocolMessage = new CustomProtocolMessage({ + id: '55170d10-b91f-4df2-9dcd-6deb4e806c1b', + }) const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { agentContext }) const mockHandle = jest.fn() messageHandlerRegistry.registerMessageHandler({ supportedMessages: [], handle: mockHandle }) await expect(dispatcher.dispatch(inboundMessageContext)).rejects.toThrow( - 'No handler for message type "https://didcomm.org/fake-protocol/1.5/message" found' + 'Error handling message 55170d10-b91f-4df2-9dcd-6deb4e806c1b with type https://didcomm.org/fake-protocol/1.5/message. The message type is not supported' + ) + }) + + it('calls the middleware in the order they are registered', async () => { + const agentContext = getAgentContext() + + const dispatcher = new Dispatcher( + new MessageSenderMock(), + eventEmitter, + new MessageHandlerRegistry(), + agentConfig.logger + ) + + const customProtocolMessage = new CustomProtocolMessage({ + id: '55170d10-b91f-4df2-9dcd-6deb4e806c1b', + }) + const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { agentContext }) + + const firstMiddleware = jest.fn().mockImplementation(async (_, next) => next()) + const secondMiddleware = jest.fn() + agentContext.dependencyManager.registerMessageHandlerMiddleware(firstMiddleware) + agentContext.dependencyManager.registerMessageHandlerMiddleware(secondMiddleware) + + await dispatcher.dispatch(inboundMessageContext) + + expect(firstMiddleware).toHaveBeenCalled() + expect(secondMiddleware).toHaveBeenCalled() + + // Verify the order of calls + const firstMiddlewareCallOrder = firstMiddleware.mock.invocationCallOrder[0] + const secondMiddlewareCallOrder = secondMiddleware.mock.invocationCallOrder[0] + expect(firstMiddlewareCallOrder).toBeLessThan(secondMiddlewareCallOrder) + }) + + it('calls the middleware in the order they are registered', async () => { + const agentContext = getAgentContext() + + const dispatcher = new Dispatcher( + new MessageSenderMock(), + eventEmitter, + new MessageHandlerRegistry(), + agentConfig.logger + ) + + const customProtocolMessage = new CustomProtocolMessage({ + id: '55170d10-b91f-4df2-9dcd-6deb4e806c1b', + }) + const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { agentContext }) + + const firstMiddleware = jest.fn().mockImplementation(async (_, next) => next()) + const secondMiddleware = jest.fn() + agentContext.dependencyManager.registerMessageHandlerMiddleware(firstMiddleware) + agentContext.dependencyManager.registerMessageHandlerMiddleware(secondMiddleware) + + await dispatcher.dispatch(inboundMessageContext) + + expect(firstMiddleware).toHaveBeenCalled() + expect(secondMiddleware).toHaveBeenCalled() + + // Verify the order of calls + const firstMiddlewareCallOrder = firstMiddleware.mock.invocationCallOrder[0] + const secondMiddlewareCallOrder = secondMiddleware.mock.invocationCallOrder[0] + expect(firstMiddlewareCallOrder).toBeLessThan(secondMiddlewareCallOrder) + }) + + it('correctly calls the fallback message handler if no message handler is registered for the message type', async () => { + const agentContext = getAgentContext() + + const dispatcher = new Dispatcher( + new MessageSenderMock(), + eventEmitter, + new MessageHandlerRegistry(), + agentConfig.logger + ) + + const customProtocolMessage = new CustomProtocolMessage({ + id: '55170d10-b91f-4df2-9dcd-6deb4e806c1b', + }) + const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { agentContext }) + + const fallbackMessageHandler = jest.fn() + agentContext.dependencyManager.setFallbackMessageHandler(fallbackMessageHandler) + + await dispatcher.dispatch(inboundMessageContext) + + expect(fallbackMessageHandler).toHaveBeenCalled() + }) + + it('will not call the message handler if the middleware does not call next (intercept incoming message handling)', async () => { + const messageHandlerRegistry = new MessageHandlerRegistry() + const agentContext = getAgentContext() + + const dispatcher = new Dispatcher( + new MessageSenderMock(), + eventEmitter, + messageHandlerRegistry, + agentConfig.logger ) + + const customProtocolMessage = new CustomProtocolMessage({ + id: '55170d10-b91f-4df2-9dcd-6deb4e806c1b', + }) + const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { agentContext }) + + const mockHandle = jest.fn() + messageHandlerRegistry.registerMessageHandler({ supportedMessages: [CustomProtocolMessage], handle: mockHandle }) + + const middleware = jest.fn() + agentContext.dependencyManager.registerMessageHandlerMiddleware(middleware) + await dispatcher.dispatch(inboundMessageContext) + expect(mockHandle).not.toHaveBeenCalled() + + // Not it should call it, as the middleware calls next + middleware.mockImplementationOnce((_, next) => next()) + await dispatcher.dispatch(inboundMessageContext) + expect(mockHandle).toHaveBeenCalled() + }) + + it('calls the message handler set by the middleware', async () => { + const agentContext = getAgentContext() + + const dispatcher = new Dispatcher( + new MessageSenderMock(), + eventEmitter, + new MessageHandlerRegistry(), + agentConfig.logger + ) + + const customProtocolMessage = new CustomProtocolMessage({ + id: '55170d10-b91f-4df2-9dcd-6deb4e806c1b', + }) + const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { agentContext }) + + const handle = jest.fn() + const middleware = jest + .fn() + .mockImplementationOnce(async (inboundMessageContext: InboundMessageContext, next) => { + inboundMessageContext.messageHandler = { + supportedMessages: [], + handle: handle, + } + + await next() + }) + + agentContext.dependencyManager.registerMessageHandlerMiddleware(middleware) + await dispatcher.dispatch(inboundMessageContext) + expect(middleware).toHaveBeenCalled() + expect(handle).toHaveBeenCalled() + }) + + it('sends the response message set by the middleware', async () => { + const agentContext = getAgentContext({ + agentConfig, + }) + const messageSenderMock = new MessageSenderMock() + + const dispatcher = new Dispatcher( + messageSenderMock, + eventEmitter, + new MessageHandlerRegistry(), + agentConfig.logger + ) + + const connectionMock = jest.fn() as unknown as ConnectionRecord + + const customProtocolMessage = new CustomProtocolMessage({ + id: '55170d10-b91f-4df2-9dcd-6deb4e806c1b', + }) + const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { + agentContext, + connection: connectionMock, + }) + + const middleware = jest.fn().mockImplementationOnce(async (inboundMessageContext: InboundMessageContext) => { + // We do not call next + inboundMessageContext.responseMessage = await getOutboundMessageContext(inboundMessageContext.agentContext, { + message: new CustomProtocolMessage({ + id: 'static-id', + }), + connectionRecord: inboundMessageContext.connection, + }) + }) + + agentContext.dependencyManager.registerMessageHandlerMiddleware(middleware) + await dispatcher.dispatch(inboundMessageContext) + expect(middleware).toHaveBeenCalled() + expect(messageSenderMock.sendMessage).toHaveBeenCalledWith({ + inboundMessageContext, + agentContext, + associatedRecord: undefined, + connection: connectionMock, + message: new CustomProtocolMessage({ + id: 'static-id', + }), + outOfBand: undefined, + serviceParams: undefined, + sessionId: undefined, + }) }) }) }) diff --git a/packages/core/src/agent/models/InboundMessageContext.ts b/packages/core/src/agent/models/InboundMessageContext.ts index 03bfef54f3..886210f5f5 100644 --- a/packages/core/src/agent/models/InboundMessageContext.ts +++ b/packages/core/src/agent/models/InboundMessageContext.ts @@ -1,6 +1,8 @@ +import type { OutboundMessageContext } from './OutboundMessageContext' import type { Key } from '../../crypto' import type { ConnectionRecord } from '../../modules/connections' import type { AgentMessage } from '../AgentMessage' +import type { MessageHandler } from '../MessageHandler' import type { AgentContext } from '../context' import { CredoError } from '../../error' @@ -15,14 +17,18 @@ export interface MessageContextParams { } export class InboundMessageContext { - public message: T public connection?: ConnectionRecord public sessionId?: string public senderKey?: Key public recipientKey?: Key public receivedAt: Date + public readonly agentContext: AgentContext + public message: T + public messageHandler?: MessageHandler + public responseMessage?: OutboundMessageContext + public constructor(message: T, context: MessageContextParams) { this.message = message this.recipientKey = context.recipientKey @@ -33,6 +39,14 @@ export class InboundMessageContext { this.receivedAt = context.receivedAt ?? new Date() } + public setMessageHandler(messageHandler: MessageHandler) { + this.messageHandler = messageHandler + } + + public setResponseMessage(outboundMessageContext: OutboundMessageContext) { + this.responseMessage = outboundMessageContext + } + /** * Assert the inbound message has a ready connection associated with it. * diff --git a/packages/core/src/modules/connections/ConnectionsApi.ts b/packages/core/src/modules/connections/ConnectionsApi.ts index 634fb51d16..c90b9561eb 100644 --- a/packages/core/src/modules/connections/ConnectionsApi.ts +++ b/packages/core/src/modules/connections/ConnectionsApi.ts @@ -277,11 +277,11 @@ export class ConnectionsApi { * @param connectionId the id of the connection for which to accept the response * @param responseRequested do we want a response to our ping * @param withReturnRouting do we want a response at the time of posting - * @returns TurstPingMessage + * @returns TrustPingMessage */ public async sendPing( connectionId: string, - { responseRequested = true, withReturnRouting = undefined }: SendPingOptions + { responseRequested = true, withReturnRouting = undefined }: SendPingOptions = {} ) { const connection = await this.getById(connectionId) diff --git a/packages/core/src/plugins/DependencyManager.ts b/packages/core/src/plugins/DependencyManager.ts index 166bc1380c..844a1dc480 100644 --- a/packages/core/src/plugins/DependencyManager.ts +++ b/packages/core/src/plugins/DependencyManager.ts @@ -1,5 +1,6 @@ import type { ModulesMap } from '../agent/AgentModules' import type { MessageHandler } from '../agent/MessageHandler' +import type { MessageHandlerMiddleware } from '../agent/MessageHandlerMiddleware' import type { Constructor } from '../utils/mixins' import type { DependencyContainer } from 'tsyringe' @@ -15,6 +16,9 @@ export class DependencyManager { public readonly container: DependencyContainer public readonly registeredModules: ModulesMap + public readonly messageHandlerMiddlewares: MessageHandlerMiddleware[] = [] + private _fallbackMessageHandler?: MessageHandler['handle'] + public constructor( container: DependencyContainer = rootContainer.createChildContainer(), registeredModules: ModulesMap = {} @@ -49,6 +53,22 @@ export class DependencyManager { } } + public registerMessageHandlerMiddleware(messageHandlerMiddleware: MessageHandlerMiddleware) { + this.messageHandlerMiddlewares.push(messageHandlerMiddleware) + } + + public get fallbackMessageHandler() { + return this._fallbackMessageHandler + } + + /** + * Sets the fallback message handler, the message handler that will be called if no handler + * is registered for an incoming message type. + */ + public setFallbackMessageHandler(fallbackMessageHandler: MessageHandler['handle']) { + this._fallbackMessageHandler = fallbackMessageHandler + } + public registerSingleton(from: InjectionToken, to: InjectionToken): void public registerSingleton(token: Constructor): void // eslint-disable-next-line @typescript-eslint/no-explicit-any diff --git a/packages/core/tests/middleware.test.ts b/packages/core/tests/middleware.test.ts new file mode 100644 index 0000000000..cf76ca8031 --- /dev/null +++ b/packages/core/tests/middleware.test.ts @@ -0,0 +1,121 @@ +import type { SubjectMessage } from '../../../tests/transport/SubjectInboundTransport' +import type { ConnectionRecord, InboundMessageContext } from '../src' + +import { Subject } from 'rxjs' + +import { SubjectInboundTransport } from '../../../tests/transport/SubjectInboundTransport' +import { SubjectOutboundTransport } from '../../../tests/transport/SubjectOutboundTransport' +import { + TrustPingResponseMessage, + BasicMessage, + getOutboundMessageContext, + MessageSender, + AgentMessage, + JsonTransformer, + Agent, +} from '../src' + +import { + getInMemoryAgentOptions, + makeConnection, + waitForAgentMessageProcessedEvent, + waitForBasicMessage, +} from './helpers' + +const faberConfig = getInMemoryAgentOptions('Faber Message Handler Middleware', { + endpoints: ['rxjs:faber'], +}) + +const aliceConfig = getInMemoryAgentOptions('Alice Message Handler Middleware', { + endpoints: ['rxjs:alice'], +}) + +describe('Message Handler Middleware E2E', () => { + let faberAgent: Agent + let aliceAgent: Agent + let faberConnection: ConnectionRecord + // eslint-disable-next-line @typescript-eslint/no-unused-vars + let aliceConnection: ConnectionRecord + + beforeEach(async () => { + const faberMessages = new Subject() + const aliceMessages = new Subject() + const subjectMap = { + 'rxjs:faber': faberMessages, + 'rxjs:alice': aliceMessages, + } + + faberAgent = new Agent(faberConfig) + faberAgent.registerInboundTransport(new SubjectInboundTransport(faberMessages)) + faberAgent.registerOutboundTransport(new SubjectOutboundTransport(subjectMap)) + await faberAgent.initialize() + + aliceAgent = new Agent(aliceConfig) + aliceAgent.registerInboundTransport(new SubjectInboundTransport(aliceMessages)) + aliceAgent.registerOutboundTransport(new SubjectOutboundTransport(subjectMap)) + await aliceAgent.initialize() + ;[aliceConnection, faberConnection] = await makeConnection(aliceAgent, faberAgent) + }) + + afterEach(async () => { + await faberAgent.shutdown() + await faberAgent.wallet.delete() + await aliceAgent.shutdown() + await aliceAgent.wallet.delete() + }) + + test('Correctly calls the fallback message handler if no message handler is defined', async () => { + // Fallback message handler + aliceAgent.dependencyManager.setFallbackMessageHandler((messageContext) => { + return getOutboundMessageContext(messageContext.agentContext, { + connectionRecord: messageContext.connection, + message: new BasicMessage({ + content: "Hey there, I'm not sure I understand the message you sent to me", + }), + }) + }) + + const message = JsonTransformer.fromJSON( + { + '@type': 'https://credo.js.org/custom-messaging/1.0/say-hello', + '@id': 'b630b69a-2b82-4764-87ba-56aa2febfb97', + }, + AgentMessage + ) + + // Send a custom message + const messageSender = faberAgent.dependencyManager.resolve(MessageSender) + const outboundMessageContext = await getOutboundMessageContext(faberAgent.context, { + connectionRecord: faberConnection, + message, + }) + await messageSender.sendMessage(outboundMessageContext) + + // Expect the basic message sent by the fallback message handler + await waitForBasicMessage(faberAgent, { + content: "Hey there, I'm not sure I understand the message you sent to me", + }) + }) + + test('Correctly calls the registered message handler middleware', async () => { + aliceAgent.dependencyManager.registerMessageHandlerMiddleware( + async (inboundMessageContext: InboundMessageContext, next) => { + await next() + + if (inboundMessageContext.responseMessage) { + inboundMessageContext.responseMessage.message.setTiming({ + outTime: new Date('2021-01-01'), + }) + } + } + ) + + await faberAgent.connections.sendPing(faberConnection.id, {}) + const receiveMessage = await waitForAgentMessageProcessedEvent(faberAgent, { + messageType: TrustPingResponseMessage.type.messageTypeUri, + }) + + // Should have sent the message with the timing added in the middleware + expect(receiveMessage.timing?.outTime).toEqual(new Date('2021-01-01')) + }) +})