From 8a7de69e25cac43439eedc024ae5fa620412f0cd Mon Sep 17 00:00:00 2001 From: Timo Glastra Date: Sat, 8 Jun 2024 17:52:39 +0200 Subject: [PATCH 1/3] feat: add message handler middleware and fallback Signed-off-by: Timo Glastra --- packages/core/src/agent/Dispatcher.ts | 49 +++- .../src/agent/MessageHandlerMiddleware.ts | 27 +++ packages/core/src/agent/MessageReceiver.ts | 10 +- .../src/agent/__tests__/Dispatcher.test.ts | 222 +++++++++++++++++- .../src/agent/models/InboundMessageContext.ts | 16 +- .../src/modules/connections/ConnectionsApi.ts | 4 +- .../core/src/plugins/DependencyManager.ts | 20 ++ packages/core/tests/middleware.test.ts | 127 ++++++++++ 8 files changed, 454 insertions(+), 21 deletions(-) create mode 100644 packages/core/src/agent/MessageHandlerMiddleware.ts create mode 100644 packages/core/tests/middleware.test.ts diff --git a/packages/core/src/agent/Dispatcher.ts b/packages/core/src/agent/Dispatcher.ts index 2c3c86769e..4f845f49d5 100644 --- a/packages/core/src/agent/Dispatcher.ts +++ b/packages/core/src/agent/Dispatcher.ts @@ -1,16 +1,19 @@ import type { AgentMessage } from './AgentMessage' import type { AgentMessageProcessedEvent } from './Events' +import type { MessageHandlerMiddleware } from './MessageHandlerMiddleware' import type { InboundMessageContext } from './models/InboundMessageContext' import { InjectionSymbols } from '../constants' -import { CredoError } from '../error/CredoError' +import { CredoError } from '../error' import { Logger } from '../logger' +import { ProblemReportError, ProblemReportReason } from '../modules/problem-reports' import { injectable, inject } from '../plugins' -import { parseMessageType } from '../utils/messageType' +import { canHandleMessageType, parseMessageType } from '../utils/messageType' import { ProblemReportMessage } from './../modules/problem-reports/messages/ProblemReportMessage' import { EventEmitter } from './EventEmitter' import { AgentEventTypes } from './Events' +import { MessageHandlerMiddlewareRunner } from './MessageHandlerMiddleware' import { MessageHandlerRegistry } from './MessageHandlerRegistry' import { MessageSender } from './MessageSender' import { OutboundMessageContext } from './models' @@ -34,22 +37,53 @@ class Dispatcher { this.logger = logger } + private defaultHandlerMiddleware: MessageHandlerMiddleware = async (inboundMessageContext, next) => { + const messageHandler = + inboundMessageContext.messageHandler ?? + inboundMessageContext.agentContext.dependencyManager.fallbackMessageHandler + + if (!messageHandler) { + throw new ProblemReportError( + `Error handling message ${inboundMessageContext.message.id} with type ${inboundMessageContext.message.type}. The message type is not supported`, + { + problemCode: ProblemReportReason.MessageParseFailure, + } + ) + } + + const outboundMessage = await messageHandler.handle(inboundMessageContext) + if (outboundMessage) { + inboundMessageContext.setResponseMessage(outboundMessage) + } + + await next() + } + public async dispatch(messageContext: InboundMessageContext): Promise { const { agentContext, connection, senderKey, recipientKey, message } = messageContext - const messageHandler = this.messageHandlerRegistry.getHandlerForMessageType(message.type) - if (!messageHandler) { - throw new CredoError(`No handler for message type "${message.type}" found`) + // Set default handler if available, middleware can still override the message handler + const messageHandler = this.messageHandlerRegistry.getHandlerForMessageType(message.type) + if (messageHandler) { + messageContext.setMessageHandler(messageHandler) } - let outboundMessage: OutboundMessageContext | void + let outboundMessage: OutboundMessageContext | undefined try { - outboundMessage = await messageHandler.handle(messageContext) + const middlewares = [...agentContext.dependencyManager.messageHandlerMiddlewares, this.defaultHandlerMiddleware] + await MessageHandlerMiddlewareRunner.run(middlewares, messageContext) + + outboundMessage = messageContext.responseMessage } catch (error) { const problemReportMessage = error.problemReport if (problemReportMessage instanceof ProblemReportMessage && messageContext.connection) { + const messageType = parseMessageType(messageContext.message.type) + if (canHandleMessageType(ProblemReportMessage, messageType)) { + throw new CredoError(`Not sending problem report in response to problem report: ${message}`) + } + const { protocolUri: problemReportProtocolUri } = parseMessageType(problemReportMessage.type) const { protocolUri: inboundProtocolUri } = parseMessageType(messageContext.message.type) @@ -91,6 +125,7 @@ class Dispatcher { await this.messageSender.sendMessage(outboundMessage) } + // Emit event that allows to hook into received messages this.eventEmitter.emit(agentContext, { type: AgentEventTypes.AgentMessageProcessed, diff --git a/packages/core/src/agent/MessageHandlerMiddleware.ts b/packages/core/src/agent/MessageHandlerMiddleware.ts new file mode 100644 index 0000000000..8b672cd08a --- /dev/null +++ b/packages/core/src/agent/MessageHandlerMiddleware.ts @@ -0,0 +1,27 @@ +import type { InboundMessageContext } from './models/InboundMessageContext' + +export interface MessageHandlerMiddleware { + (inboundMessageContext: InboundMessageContext, next: () => Promise): Promise +} + +export class MessageHandlerMiddlewareRunner { + public static async run(middlewares: MessageHandlerMiddleware[], inboundMessageContext: InboundMessageContext) { + const compose = (middlewares: MessageHandlerMiddleware[]) => { + return async function (inboundMessageContext: InboundMessageContext, next?: () => Promise) { + let index = -1 + async function dispatch(i: number): Promise { + if (i <= index) throw new Error('next() called multiple times') + index = i + let fn: MessageHandlerMiddleware | undefined = middlewares[i] + if (i === middlewares.length) fn = next + if (!fn) return + await fn(inboundMessageContext, () => dispatch(i + 1)) + } + await dispatch(0) + } + } + + const composed = compose(middlewares) + await composed(inboundMessageContext) + } +} diff --git a/packages/core/src/agent/MessageReceiver.ts b/packages/core/src/agent/MessageReceiver.ts index 19fca86a8b..02f94d1e7b 100644 --- a/packages/core/src/agent/MessageReceiver.ts +++ b/packages/core/src/agent/MessageReceiver.ts @@ -1,4 +1,3 @@ -import type { AgentMessage } from './AgentMessage' import type { DecryptedMessageContext } from './EnvelopeService' import type { TransportSession } from './TransportService' import type { AgentContext } from './context' @@ -16,6 +15,7 @@ import { isValidJweStructure } from '../utils/JWE' import { JsonTransformer } from '../utils/JsonTransformer' import { canHandleMessageType, parseMessageType, replaceLegacyDidSovPrefixOnMessage } from '../utils/messageType' +import { AgentMessage } from './AgentMessage' import { Dispatcher } from './Dispatcher' import { EnvelopeService } from './EnvelopeService' import { MessageHandlerRegistry } from './MessageHandlerRegistry' @@ -250,13 +250,7 @@ export class MessageReceiver { replaceLegacyDidSovPrefixOnMessage(message) const messageType = message['@type'] - const MessageClass = this.messageHandlerRegistry.getMessageClassForMessageType(messageType) - - if (!MessageClass) { - throw new ProblemReportError(`No message class found for message type "${messageType}"`, { - problemCode: ProblemReportReason.MessageParseFailure, - }) - } + const MessageClass = this.messageHandlerRegistry.getMessageClassForMessageType(messageType) ?? AgentMessage // Cast the plain JSON object to specific instance of Message extended from AgentMessage let messageTransformed: AgentMessage diff --git a/packages/core/src/agent/__tests__/Dispatcher.test.ts b/packages/core/src/agent/__tests__/Dispatcher.test.ts index 7bbcb89f95..6257f34515 100644 --- a/packages/core/src/agent/__tests__/Dispatcher.test.ts +++ b/packages/core/src/agent/__tests__/Dispatcher.test.ts @@ -1,3 +1,5 @@ +import type { ConnectionRecord } from '../../modules/connections' + import { Subject } from 'rxjs' import { getAgentConfig, getAgentContext } from '../../../tests/helpers' @@ -7,11 +9,22 @@ import { Dispatcher } from '../Dispatcher' import { EventEmitter } from '../EventEmitter' import { MessageHandlerRegistry } from '../MessageHandlerRegistry' import { MessageSender } from '../MessageSender' +import { getOutboundMessageContext } from '../getOutboundMessageContext' import { InboundMessageContext } from '../models/InboundMessageContext' +jest.mock('../MessageSender') + class CustomProtocolMessage extends AgentMessage { public readonly type = CustomProtocolMessage.type.messageTypeUri public static readonly type = parseMessageType('https://didcomm.org/fake-protocol/1.5/message') + + public constructor(options: { id?: string }) { + super() + + if (options) { + this.id = options.id ?? this.generateId() + } + } } describe('Dispatcher', () => { @@ -29,7 +42,7 @@ describe('Dispatcher', () => { messageHandlerRegistry, agentConfig.logger ) - const customProtocolMessage = new CustomProtocolMessage() + const customProtocolMessage = new CustomProtocolMessage({}) const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { agentContext }) const mockHandle = jest.fn() @@ -48,15 +61,218 @@ describe('Dispatcher', () => { new MessageHandlerRegistry(), agentConfig.logger ) - const customProtocolMessage = new CustomProtocolMessage() + const customProtocolMessage = new CustomProtocolMessage({ + id: '55170d10-b91f-4df2-9dcd-6deb4e806c1b', + }) const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { agentContext }) const mockHandle = jest.fn() messageHandlerRegistry.registerMessageHandler({ supportedMessages: [], handle: mockHandle }) await expect(dispatcher.dispatch(inboundMessageContext)).rejects.toThrow( - 'No handler for message type "https://didcomm.org/fake-protocol/1.5/message" found' + 'Error handling message 55170d10-b91f-4df2-9dcd-6deb4e806c1b with type https://didcomm.org/fake-protocol/1.5/message. The message type is not supported' + ) + }) + + it('calls the middleware in the order they are registered', async () => { + const agentContext = getAgentContext() + + const dispatcher = new Dispatcher( + new MessageSenderMock(), + eventEmitter, + new MessageHandlerRegistry(), + agentConfig.logger + ) + + const customProtocolMessage = new CustomProtocolMessage({ + id: '55170d10-b91f-4df2-9dcd-6deb4e806c1b', + }) + const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { agentContext }) + + const firstMiddleware = jest.fn().mockImplementation(async (_, next) => next()) + const secondMiddleware = jest.fn() + agentContext.dependencyManager.registerMessageHandlerMiddleware(firstMiddleware) + agentContext.dependencyManager.registerMessageHandlerMiddleware(secondMiddleware) + + await dispatcher.dispatch(inboundMessageContext) + + expect(firstMiddleware).toHaveBeenCalled() + expect(secondMiddleware).toHaveBeenCalled() + + // Verify the order of calls + const firstMiddlewareCallOrder = firstMiddleware.mock.invocationCallOrder[0] + const secondMiddlewareCallOrder = secondMiddleware.mock.invocationCallOrder[0] + expect(firstMiddlewareCallOrder).toBeLessThan(secondMiddlewareCallOrder) + }) + + it('calls the middleware in the order they are registered', async () => { + const agentContext = getAgentContext() + + const dispatcher = new Dispatcher( + new MessageSenderMock(), + eventEmitter, + new MessageHandlerRegistry(), + agentConfig.logger + ) + + const customProtocolMessage = new CustomProtocolMessage({ + id: '55170d10-b91f-4df2-9dcd-6deb4e806c1b', + }) + const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { agentContext }) + + const firstMiddleware = jest.fn().mockImplementation(async (_, next) => next()) + const secondMiddleware = jest.fn() + agentContext.dependencyManager.registerMessageHandlerMiddleware(firstMiddleware) + agentContext.dependencyManager.registerMessageHandlerMiddleware(secondMiddleware) + + await dispatcher.dispatch(inboundMessageContext) + + expect(firstMiddleware).toHaveBeenCalled() + expect(secondMiddleware).toHaveBeenCalled() + + // Verify the order of calls + const firstMiddlewareCallOrder = firstMiddleware.mock.invocationCallOrder[0] + const secondMiddlewareCallOrder = secondMiddleware.mock.invocationCallOrder[0] + expect(firstMiddlewareCallOrder).toBeLessThan(secondMiddlewareCallOrder) + }) + + it('correctly calls the fallback message handler if no message handler is registered for the message type', async () => { + const agentContext = getAgentContext() + + const dispatcher = new Dispatcher( + new MessageSenderMock(), + eventEmitter, + new MessageHandlerRegistry(), + agentConfig.logger + ) + + const customProtocolMessage = new CustomProtocolMessage({ + id: '55170d10-b91f-4df2-9dcd-6deb4e806c1b', + }) + const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { agentContext }) + + const fallbackMessageHandler = jest.fn() + agentContext.dependencyManager.setFallbackMessageHandler({ + supportedMessages: [], + handle: fallbackMessageHandler, + }) + + await dispatcher.dispatch(inboundMessageContext) + + expect(fallbackMessageHandler).toHaveBeenCalled() + }) + + it('will not call the message handler if the middleware does not call next (intercept incoming message handling)', async () => { + const messageHandlerRegistry = new MessageHandlerRegistry() + const agentContext = getAgentContext() + + const dispatcher = new Dispatcher( + new MessageSenderMock(), + eventEmitter, + messageHandlerRegistry, + agentConfig.logger ) + + const customProtocolMessage = new CustomProtocolMessage({ + id: '55170d10-b91f-4df2-9dcd-6deb4e806c1b', + }) + const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { agentContext }) + + const mockHandle = jest.fn() + messageHandlerRegistry.registerMessageHandler({ supportedMessages: [CustomProtocolMessage], handle: mockHandle }) + + const middleware = jest.fn() + agentContext.dependencyManager.registerMessageHandlerMiddleware(middleware) + await dispatcher.dispatch(inboundMessageContext) + expect(mockHandle).not.toHaveBeenCalled() + + // Not it should call it, as the middleware calls next + middleware.mockImplementationOnce((_, next) => next()) + await dispatcher.dispatch(inboundMessageContext) + expect(mockHandle).toHaveBeenCalled() + }) + + it('calls the message handler set by the middleware', async () => { + const agentContext = getAgentContext() + + const dispatcher = new Dispatcher( + new MessageSenderMock(), + eventEmitter, + new MessageHandlerRegistry(), + agentConfig.logger + ) + + const customProtocolMessage = new CustomProtocolMessage({ + id: '55170d10-b91f-4df2-9dcd-6deb4e806c1b', + }) + const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { agentContext }) + + const handle = jest.fn() + const middleware = jest + .fn() + .mockImplementationOnce(async (inboundMessageContext: InboundMessageContext, next) => { + inboundMessageContext.messageHandler = { + supportedMessages: [], + handle: handle, + } + + await next() + }) + + agentContext.dependencyManager.registerMessageHandlerMiddleware(middleware) + await dispatcher.dispatch(inboundMessageContext) + expect(middleware).toHaveBeenCalled() + expect(handle).toHaveBeenCalled() + }) + + it('sends the response message set by the middleware', async () => { + const agentContext = getAgentContext({ + agentConfig, + }) + const messageSenderMock = new MessageSenderMock() + + const dispatcher = new Dispatcher( + messageSenderMock, + eventEmitter, + new MessageHandlerRegistry(), + agentConfig.logger + ) + + const connectionMock = jest.fn() as unknown as ConnectionRecord + + const customProtocolMessage = new CustomProtocolMessage({ + id: '55170d10-b91f-4df2-9dcd-6deb4e806c1b', + }) + const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { + agentContext, + connection: connectionMock, + }) + + const middleware = jest.fn().mockImplementationOnce(async (inboundMessageContext: InboundMessageContext) => { + // We do not call next + inboundMessageContext.responseMessage = await getOutboundMessageContext(inboundMessageContext.agentContext, { + message: new CustomProtocolMessage({ + id: 'static-id', + }), + connectionRecord: inboundMessageContext.connection, + }) + }) + + agentContext.dependencyManager.registerMessageHandlerMiddleware(middleware) + await dispatcher.dispatch(inboundMessageContext) + expect(middleware).toHaveBeenCalled() + expect(messageSenderMock.sendMessage).toHaveBeenCalledWith({ + inboundMessageContext, + agentContext, + associatedRecord: undefined, + connection: connectionMock, + message: new CustomProtocolMessage({ + id: 'static-id', + }), + outOfBand: undefined, + serviceParams: undefined, + sessionId: undefined, + }) }) }) }) diff --git a/packages/core/src/agent/models/InboundMessageContext.ts b/packages/core/src/agent/models/InboundMessageContext.ts index 03bfef54f3..886210f5f5 100644 --- a/packages/core/src/agent/models/InboundMessageContext.ts +++ b/packages/core/src/agent/models/InboundMessageContext.ts @@ -1,6 +1,8 @@ +import type { OutboundMessageContext } from './OutboundMessageContext' import type { Key } from '../../crypto' import type { ConnectionRecord } from '../../modules/connections' import type { AgentMessage } from '../AgentMessage' +import type { MessageHandler } from '../MessageHandler' import type { AgentContext } from '../context' import { CredoError } from '../../error' @@ -15,14 +17,18 @@ export interface MessageContextParams { } export class InboundMessageContext { - public message: T public connection?: ConnectionRecord public sessionId?: string public senderKey?: Key public recipientKey?: Key public receivedAt: Date + public readonly agentContext: AgentContext + public message: T + public messageHandler?: MessageHandler + public responseMessage?: OutboundMessageContext + public constructor(message: T, context: MessageContextParams) { this.message = message this.recipientKey = context.recipientKey @@ -33,6 +39,14 @@ export class InboundMessageContext { this.receivedAt = context.receivedAt ?? new Date() } + public setMessageHandler(messageHandler: MessageHandler) { + this.messageHandler = messageHandler + } + + public setResponseMessage(outboundMessageContext: OutboundMessageContext) { + this.responseMessage = outboundMessageContext + } + /** * Assert the inbound message has a ready connection associated with it. * diff --git a/packages/core/src/modules/connections/ConnectionsApi.ts b/packages/core/src/modules/connections/ConnectionsApi.ts index 634fb51d16..c90b9561eb 100644 --- a/packages/core/src/modules/connections/ConnectionsApi.ts +++ b/packages/core/src/modules/connections/ConnectionsApi.ts @@ -277,11 +277,11 @@ export class ConnectionsApi { * @param connectionId the id of the connection for which to accept the response * @param responseRequested do we want a response to our ping * @param withReturnRouting do we want a response at the time of posting - * @returns TurstPingMessage + * @returns TrustPingMessage */ public async sendPing( connectionId: string, - { responseRequested = true, withReturnRouting = undefined }: SendPingOptions + { responseRequested = true, withReturnRouting = undefined }: SendPingOptions = {} ) { const connection = await this.getById(connectionId) diff --git a/packages/core/src/plugins/DependencyManager.ts b/packages/core/src/plugins/DependencyManager.ts index 166bc1380c..542c529810 100644 --- a/packages/core/src/plugins/DependencyManager.ts +++ b/packages/core/src/plugins/DependencyManager.ts @@ -1,5 +1,6 @@ import type { ModulesMap } from '../agent/AgentModules' import type { MessageHandler } from '../agent/MessageHandler' +import type { MessageHandlerMiddleware } from '../agent/MessageHandlerMiddleware' import type { Constructor } from '../utils/mixins' import type { DependencyContainer } from 'tsyringe' @@ -15,6 +16,9 @@ export class DependencyManager { public readonly container: DependencyContainer public readonly registeredModules: ModulesMap + public readonly messageHandlerMiddlewares: MessageHandlerMiddleware[] = [] + private _fallbackMessageHandler?: MessageHandler + public constructor( container: DependencyContainer = rootContainer.createChildContainer(), registeredModules: ModulesMap = {} @@ -49,6 +53,22 @@ export class DependencyManager { } } + public registerMessageHandlerMiddleware(messageHandlerMiddleware: MessageHandlerMiddleware) { + this.messageHandlerMiddlewares.push(messageHandlerMiddleware) + } + + public get fallbackMessageHandler() { + return this._fallbackMessageHandler + } + + /** + * Sets the fallback message handler, the message handler that will be called if no handler + * is registered for an incoming message type. + */ + public setFallbackMessageHandler(fallbackMessageHandler: MessageHandler) { + this._fallbackMessageHandler = fallbackMessageHandler + } + public registerSingleton(from: InjectionToken, to: InjectionToken): void public registerSingleton(token: Constructor): void // eslint-disable-next-line @typescript-eslint/no-explicit-any diff --git a/packages/core/tests/middleware.test.ts b/packages/core/tests/middleware.test.ts new file mode 100644 index 0000000000..77088122ec --- /dev/null +++ b/packages/core/tests/middleware.test.ts @@ -0,0 +1,127 @@ +import type { SubjectMessage } from '../../../tests/transport/SubjectInboundTransport' +import type { ConnectionRecord, InboundMessageContext, MessageHandler } from '../src' + +import { Subject } from 'rxjs' + +import { SubjectInboundTransport } from '../../../tests/transport/SubjectInboundTransport' +import { SubjectOutboundTransport } from '../../../tests/transport/SubjectOutboundTransport' +import { + TrustPingResponseMessage, + BasicMessage, + getOutboundMessageContext, + MessageSender, + AgentMessage, + JsonTransformer, + Agent, +} from '../src' + +import { + getInMemoryAgentOptions, + makeConnection, + waitForAgentMessageProcessedEvent, + waitForBasicMessage, +} from './helpers' + +const faberConfig = getInMemoryAgentOptions('Faber Message Handler Middleware', { + endpoints: ['rxjs:faber'], +}) + +const aliceConfig = getInMemoryAgentOptions('Alice Message Handler Middleware', { + endpoints: ['rxjs:alice'], +}) + +class FallbackMessageHandler implements MessageHandler { + public supportedMessages = [] + + public async handle(messageContext: InboundMessageContext) { + return getOutboundMessageContext(messageContext.agentContext, { + connectionRecord: messageContext.connection, + message: new BasicMessage({ + content: "Hey there, I'm not sure I understand the message you sent to me", + }), + }) + } +} + +describe('Message Handler Middleware E2E', () => { + let faberAgent: Agent + let aliceAgent: Agent + let faberConnection: ConnectionRecord + // eslint-disable-next-line @typescript-eslint/no-unused-vars + let aliceConnection: ConnectionRecord + + beforeEach(async () => { + const faberMessages = new Subject() + const aliceMessages = new Subject() + const subjectMap = { + 'rxjs:faber': faberMessages, + 'rxjs:alice': aliceMessages, + } + + faberAgent = new Agent(faberConfig) + faberAgent.registerInboundTransport(new SubjectInboundTransport(faberMessages)) + faberAgent.registerOutboundTransport(new SubjectOutboundTransport(subjectMap)) + await faberAgent.initialize() + + aliceAgent = new Agent(aliceConfig) + aliceAgent.registerInboundTransport(new SubjectInboundTransport(aliceMessages)) + aliceAgent.registerOutboundTransport(new SubjectOutboundTransport(subjectMap)) + await aliceAgent.initialize() + ;[aliceConnection, faberConnection] = await makeConnection(aliceAgent, faberAgent) + }) + + afterEach(async () => { + await faberAgent.shutdown() + await faberAgent.wallet.delete() + await aliceAgent.shutdown() + await aliceAgent.wallet.delete() + }) + + test('Correctly calls the fallback message handler if no message handler is defined', async () => { + // Fallback message handler + aliceAgent.dependencyManager.setFallbackMessageHandler(new FallbackMessageHandler()) + + const message = JsonTransformer.fromJSON( + { + '@type': 'https://credo.js.org/custom-messaging/1.0/say-hello', + '@id': 'b630b69a-2b82-4764-87ba-56aa2febfb97', + }, + AgentMessage + ) + + // Send a custom message + const messageSender = faberAgent.dependencyManager.resolve(MessageSender) + const outboundMessageContext = await getOutboundMessageContext(faberAgent.context, { + connectionRecord: faberConnection, + message, + }) + await messageSender.sendMessage(outboundMessageContext) + + // Expect the basic message sent by the fallback message handler + await waitForBasicMessage(faberAgent, { + content: "Hey there, I'm not sure I understand the message you sent to me", + }) + }) + + test('Correctly calls the registered message handler middleware', async () => { + aliceAgent.dependencyManager.registerMessageHandlerMiddleware( + async (inboundMessageContext: InboundMessageContext, next) => { + await next() + + if (inboundMessageContext.responseMessage) { + inboundMessageContext.responseMessage.message.setTiming({ + outTime: new Date('2021-01-01'), + }) + } + } + ) + + await faberAgent.connections.sendPing(faberConnection.id, {}) + const receiveMessage = await waitForAgentMessageProcessedEvent(faberAgent, { + messageType: TrustPingResponseMessage.type.messageTypeUri, + }) + + // Should have sent the message with the timing added in the middleware + expect(receiveMessage.timing?.outTime).toEqual(new Date('2021-01-01')) + }) +}) From 17150ab896cb2edeb001a9bdb2de46b5eb2f66f6 Mon Sep 17 00:00:00 2001 From: Timo Glastra Date: Sat, 8 Jun 2024 18:03:07 +0200 Subject: [PATCH 2/3] simplify api Signed-off-by: Timo Glastra --- packages/core/src/agent/Dispatcher.ts | 11 ++++++--- .../src/agent/__tests__/Dispatcher.test.ts | 5 +--- .../core/src/plugins/DependencyManager.ts | 4 ++-- packages/core/tests/middleware.test.ts | 24 +++++++------------ 4 files changed, 20 insertions(+), 24 deletions(-) diff --git a/packages/core/src/agent/Dispatcher.ts b/packages/core/src/agent/Dispatcher.ts index 4f845f49d5..b42f9aa6ca 100644 --- a/packages/core/src/agent/Dispatcher.ts +++ b/packages/core/src/agent/Dispatcher.ts @@ -38,9 +38,14 @@ class Dispatcher { } private defaultHandlerMiddleware: MessageHandlerMiddleware = async (inboundMessageContext, next) => { - const messageHandler = - inboundMessageContext.messageHandler ?? - inboundMessageContext.agentContext.dependencyManager.fallbackMessageHandler + let messageHandler = inboundMessageContext.messageHandler + + if (!messageHandler && inboundMessageContext.agentContext.dependencyManager.fallbackMessageHandler) { + messageHandler = { + supportedMessages: [], + handle: inboundMessageContext.agentContext.dependencyManager.fallbackMessageHandler, + } + } if (!messageHandler) { throw new ProblemReportError( diff --git a/packages/core/src/agent/__tests__/Dispatcher.test.ts b/packages/core/src/agent/__tests__/Dispatcher.test.ts index 6257f34515..7f8f2513c5 100644 --- a/packages/core/src/agent/__tests__/Dispatcher.test.ts +++ b/packages/core/src/agent/__tests__/Dispatcher.test.ts @@ -152,10 +152,7 @@ describe('Dispatcher', () => { const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { agentContext }) const fallbackMessageHandler = jest.fn() - agentContext.dependencyManager.setFallbackMessageHandler({ - supportedMessages: [], - handle: fallbackMessageHandler, - }) + agentContext.dependencyManager.setFallbackMessageHandler(fallbackMessageHandler) await dispatcher.dispatch(inboundMessageContext) diff --git a/packages/core/src/plugins/DependencyManager.ts b/packages/core/src/plugins/DependencyManager.ts index 542c529810..844a1dc480 100644 --- a/packages/core/src/plugins/DependencyManager.ts +++ b/packages/core/src/plugins/DependencyManager.ts @@ -17,7 +17,7 @@ export class DependencyManager { public readonly registeredModules: ModulesMap public readonly messageHandlerMiddlewares: MessageHandlerMiddleware[] = [] - private _fallbackMessageHandler?: MessageHandler + private _fallbackMessageHandler?: MessageHandler['handle'] public constructor( container: DependencyContainer = rootContainer.createChildContainer(), @@ -65,7 +65,7 @@ export class DependencyManager { * Sets the fallback message handler, the message handler that will be called if no handler * is registered for an incoming message type. */ - public setFallbackMessageHandler(fallbackMessageHandler: MessageHandler) { + public setFallbackMessageHandler(fallbackMessageHandler: MessageHandler['handle']) { this._fallbackMessageHandler = fallbackMessageHandler } diff --git a/packages/core/tests/middleware.test.ts b/packages/core/tests/middleware.test.ts index 77088122ec..cf76ca8031 100644 --- a/packages/core/tests/middleware.test.ts +++ b/packages/core/tests/middleware.test.ts @@ -1,5 +1,5 @@ import type { SubjectMessage } from '../../../tests/transport/SubjectInboundTransport' -import type { ConnectionRecord, InboundMessageContext, MessageHandler } from '../src' +import type { ConnectionRecord, InboundMessageContext } from '../src' import { Subject } from 'rxjs' @@ -30,19 +30,6 @@ const aliceConfig = getInMemoryAgentOptions('Alice Message Handler Middleware', endpoints: ['rxjs:alice'], }) -class FallbackMessageHandler implements MessageHandler { - public supportedMessages = [] - - public async handle(messageContext: InboundMessageContext) { - return getOutboundMessageContext(messageContext.agentContext, { - connectionRecord: messageContext.connection, - message: new BasicMessage({ - content: "Hey there, I'm not sure I understand the message you sent to me", - }), - }) - } -} - describe('Message Handler Middleware E2E', () => { let faberAgent: Agent let aliceAgent: Agent @@ -79,7 +66,14 @@ describe('Message Handler Middleware E2E', () => { test('Correctly calls the fallback message handler if no message handler is defined', async () => { // Fallback message handler - aliceAgent.dependencyManager.setFallbackMessageHandler(new FallbackMessageHandler()) + aliceAgent.dependencyManager.setFallbackMessageHandler((messageContext) => { + return getOutboundMessageContext(messageContext.agentContext, { + connectionRecord: messageContext.connection, + message: new BasicMessage({ + content: "Hey there, I'm not sure I understand the message you sent to me", + }), + }) + }) const message = JsonTransformer.fromJSON( { From 527d1610ee18ece52f286becf87ed51db0d6551d Mon Sep 17 00:00:00 2001 From: Timo Glastra Date: Sat, 8 Jun 2024 18:17:26 +0200 Subject: [PATCH 3/3] small cleanup Signed-off-by: Timo Glastra --- packages/core/src/agent/MessageHandlerMiddleware.ts | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/packages/core/src/agent/MessageHandlerMiddleware.ts b/packages/core/src/agent/MessageHandlerMiddleware.ts index 8b672cd08a..9eff446bbd 100644 --- a/packages/core/src/agent/MessageHandlerMiddleware.ts +++ b/packages/core/src/agent/MessageHandlerMiddleware.ts @@ -7,13 +7,12 @@ export interface MessageHandlerMiddleware { export class MessageHandlerMiddlewareRunner { public static async run(middlewares: MessageHandlerMiddleware[], inboundMessageContext: InboundMessageContext) { const compose = (middlewares: MessageHandlerMiddleware[]) => { - return async function (inboundMessageContext: InboundMessageContext, next?: () => Promise) { + return async function (inboundMessageContext: InboundMessageContext) { let index = -1 async function dispatch(i: number): Promise { if (i <= index) throw new Error('next() called multiple times') index = i - let fn: MessageHandlerMiddleware | undefined = middlewares[i] - if (i === middlewares.length) fn = next + const fn = middlewares[i] if (!fn) return await fn(inboundMessageContext, () => dispatch(i + 1)) }