From ff8e31114afdf27f12e7874c244df14b90c6f91d Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 10 Feb 2023 17:37:43 +1100 Subject: [PATCH] feat: client static middleware registration Related #501 Related #502 [ci skip] --- src/RPC/RPCClient.ts | 83 +++++++++----------- src/RPC/RPCServer.ts | 2 +- src/RPC/utils.ts | 49 +++++++++++- tests/RPC/RPCClient.test.ts | 53 +++++++------ tests/RPC/RPCServer.test.ts | 6 +- tests/clientRPC/handlers/agentUnlock.test.ts | 8 +- 6 files changed, 117 insertions(+), 84 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index 50545ec3b5..e75d1b54b4 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -17,6 +17,10 @@ import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; import Logger from '@matrixai/logger'; import * as rpcErrors from './errors'; import * as rpcUtils from './utils'; +import { + clientInputTransformStream, + clientOutputTransformStream, +} from './utils'; import { never } from '../utils'; // eslint-disable-next-line @@ -26,16 +30,24 @@ class RPCClient { static async createRPCClient({ manifest, streamPairCreateCallback, + middleware = rpcUtils.defaultClientMiddlewareWrapper(), logger = new Logger(this.name), }: { manifest: M; streamPairCreateCallback: StreamPairCreateCallback; - logger: Logger; + middleware?: MiddlewareFactory< + Uint8Array, + JsonRpcRequest, + JsonRpcResponse, + Uint8Array + >; + logger?: Logger; }) { logger.info(`Creating ${this.name}`); const rpcClient = new this({ manifest, streamPairCreateCallback, + middleware, logger, }); logger.info(`Created ${this.name}`); @@ -44,6 +56,12 @@ class RPCClient { protected logger: Logger; protected streamPairCreateCallback: StreamPairCreateCallback; + protected middleware: MiddlewareFactory< + Uint8Array, + JsonRpcRequest, + JsonRpcResponse, + Uint8Array + >; protected callerTypes: Record; // Method proxies public readonly methodsProxy = new Proxy( @@ -90,14 +108,22 @@ class RPCClient { public constructor({ manifest, streamPairCreateCallback, + middleware, logger, }: { manifest: M; streamPairCreateCallback: StreamPairCreateCallback; + middleware: MiddlewareFactory< + Uint8Array, + JsonRpcRequest, + JsonRpcResponse, + Uint8Array + >; logger: Logger; }) { this.callerTypes = rpcUtils.getHandlerTypes(manifest); this.streamPairCreateCallback = streamPairCreateCallback; + this.middleware = middleware; this.logger = logger; } @@ -199,36 +225,23 @@ class RPCClient { public async rawDuplexStreamCaller( method: string, ): Promise> { - // Creating caller side transforms - const outputMessageTransforStream = - rpcUtils.clientOutputTransformStream(); - const inputMessageTransformStream = - rpcUtils.clientInputTransformStream(method); - let reverseStream = outputMessageTransforStream.writable; - let forwardStream = inputMessageTransformStream.readable; - // Setting up middleware chains - for (const middlewareFactory of this.middleware) { - const middleware = middlewareFactory(); - forwardStream = forwardStream.pipeThrough(middleware.forward); - void middleware.reverse.readable.pipeTo(reverseStream).catch(() => {}); - reverseStream = middleware.reverse.writable; - } + const outputMessageTransformStream = clientOutputTransformStream(); + const inputMessageTransformStream = clientInputTransformStream(method); + const middleware = this.middleware(); // Hooking up agnostic stream side const streamPair = await this.streamPairCreateCallback(); void streamPair.readable - .pipeThrough( - rpcUtils.binaryToJsonMessageStream(rpcUtils.parseJsonRpcResponse), - ) - .pipeTo(reverseStream) + .pipeThrough(middleware.reverse) + .pipeTo(outputMessageTransformStream.writable) .catch(() => {}); - void forwardStream - .pipeThrough(rpcUtils.jsonMessageToBinaryStream()) + void inputMessageTransformStream.readable + .pipeThrough(middleware.forward) .pipeTo(streamPair.writable) .catch(() => {}); // Returning interface return { - readable: outputMessageTransforStream.readable, + readable: outputMessageTransformStream.readable, writable: inputMessageTransformStream.writable, }; } @@ -273,32 +286,6 @@ class RPCClient { writable: callerInterface.writable, }; } - - protected middleware: Array< - MiddlewareFactory< - JsonRpcRequest, - JsonRpcRequest, - JsonRpcResponse, - JsonRpcResponse - > - > = []; - - @ready(new rpcErrors.ErrorRpcDestroyed()) - public registerMiddleware( - middlewareFactory: MiddlewareFactory< - JsonRpcRequest, - JsonRpcRequest, - JsonRpcResponse, - JsonRpcResponse - >, - ) { - this.middleware.push(middlewareFactory); - } - - @ready(new rpcErrors.ErrorRpcDestroyed()) - public clearMiddleware() { - this.middleware = []; - } } export default RPCClient; diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index 78713736f7..a25813ae7e 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -37,7 +37,7 @@ interface RPCServer extends CreateDestroy {} class RPCServer { static async createRPCServer({ manifest, - middleware = rpcUtils.defaultMiddlewareWrapper(), + middleware = rpcUtils.defaultServerMiddlewareWrapper(), logger = new Logger(this.name), }: { manifest: ServerManifest; diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index dad91d1320..4a069e3f12 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -586,7 +586,7 @@ const defaultMiddleware: MiddlewareFactory< }; }; -const defaultMiddlewareWrapper = ( +const defaultServerMiddlewareWrapper = ( middleware: MiddlewareFactory< JsonRpcRequest, JsonRpcRequest, @@ -627,6 +627,50 @@ const defaultMiddlewareWrapper = ( }; }; +const defaultClientMiddlewareWrapper = ( + middleware: MiddlewareFactory< + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse + > = defaultMiddleware, +): MiddlewareFactory< + Uint8Array, + JsonRpcRequest, + JsonRpcResponse, + Uint8Array +> => { + return () => { + const outputTransformStream = binaryToJsonMessageStream( + parseJsonRpcResponse, + undefined, + ); + const inputTransformStream = new TransformStream< + JsonRpcRequest, + JsonRpcRequest + >(); + + const middleMiddleware = middleware(); + const forwardReadable = inputTransformStream.readable + .pipeThrough(middleMiddleware.forward) // Usual middleware here + .pipeThrough(jsonMessageToBinaryStream()); + const reverseReadable = outputTransformStream.readable.pipeThrough( + middleMiddleware.reverse, + ); // Usual middleware here + + return { + forward: { + readable: forwardReadable, + writable: inputTransformStream.writable, + }, + reverse: { + readable: reverseReadable, + writable: outputTransformStream.writable, + }, + }; + }; +}; + export { binaryToJsonMessageStream, jsonMessageToBinaryStream, @@ -648,5 +692,6 @@ export { extractFirstMessageTransform, getHandlerTypes, defaultMiddleware, - defaultMiddlewareWrapper, + defaultServerMiddlewareWrapper, + defaultClientMiddlewareWrapper, }; diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index ae2fda61e8..d675d407e9 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -18,6 +18,7 @@ import { ServerCaller, UnaryCaller, } from '@/RPC/callers'; +import * as rpcUtils from '@/RPC/utils'; import * as rpcTestUtils from './utils'; describe(`${RPCClient.name}`, () => { @@ -309,22 +310,22 @@ describe(`${RPCClient.name}`, () => { const rpcClient = await RPCClient.createRPCClient({ manifest: {}, streamPairCreateCallback: async () => streamPair, + middleware: rpcUtils.defaultClientMiddlewareWrapper(() => { + return { + forward: new TransformStream({ + transform: (chunk, controller) => { + controller.enqueue({ + ...chunk, + params: 'one', + }); + }, + }), + reverse: new TransformStream(), + }; + }), logger, }); - rpcClient.registerMiddleware(() => { - return { - forward: new TransformStream({ - transform: (chunk, controller) => { - controller.enqueue({ - ...chunk, - params: 'one', - }); - }, - }), - reverse: new TransformStream(), - }; - }); const callerInterface = await rpcClient.rawDuplexStreamCaller< JSONValue, JSONValue @@ -373,22 +374,22 @@ describe(`${RPCClient.name}`, () => { const rpcClient = await RPCClient.createRPCClient({ manifest: {}, streamPairCreateCallback: async () => streamPair, + middleware: rpcUtils.defaultClientMiddlewareWrapper(() => { + return { + forward: new TransformStream(), + reverse: new TransformStream({ + transform: (chunk, controller) => { + controller.enqueue({ + ...chunk, + result: 'one', + }); + }, + }), + }; + }), logger, }); - rpcClient.registerMiddleware(() => { - return { - forward: new TransformStream(), - reverse: new TransformStream({ - transform: (chunk, controller) => { - controller.enqueue({ - ...chunk, - result: 'one', - }); - }, - }), - }; - }); const callerInterface = await rpcClient.rawDuplexStreamCaller< JSONValue, JSONValue diff --git a/tests/RPC/RPCServer.test.ts b/tests/RPC/RPCServer.test.ts index a613e2ec54..02249df46c 100644 --- a/tests/RPC/RPCServer.test.ts +++ b/tests/RPC/RPCServer.test.ts @@ -456,7 +456,7 @@ describe(`${RPCServer.name}`, () => { } }; } - const middleware = rpcUtils.defaultMiddlewareWrapper(() => { + const middleware = rpcUtils.defaultServerMiddlewareWrapper(() => { return { forward: new TransformStream({ transform: (chunk, controller) => { @@ -501,7 +501,7 @@ describe(`${RPCServer.name}`, () => { } }; } - const middleware = rpcUtils.defaultMiddlewareWrapper(() => { + const middleware = rpcUtils.defaultServerMiddlewareWrapper(() => { return { forward: new TransformStream(), reverse: new TransformStream({ @@ -549,7 +549,7 @@ describe(`${RPCServer.name}`, () => { } }; } - const middleware = rpcUtils.defaultMiddlewareWrapper(() => { + const middleware = rpcUtils.defaultServerMiddlewareWrapper(() => { let first = true; let reverseController: TransformStreamDefaultController; return { diff --git a/tests/clientRPC/handlers/agentUnlock.test.ts b/tests/clientRPC/handlers/agentUnlock.test.ts index 1a533a0701..45939e327d 100644 --- a/tests/clientRPC/handlers/agentUnlock.test.ts +++ b/tests/clientRPC/handlers/agentUnlock.test.ts @@ -97,7 +97,7 @@ describe('agentUnlock', () => { manifest: { agentUnlock: new AgentUnlockHandler({ logger }), }, - middleware: rpcUtils.defaultMiddlewareWrapper( + middleware: rpcUtils.defaultServerMiddlewareWrapper( clientRPCUtils.authenticationMiddlewareServer(sessionManager, keyRing), ), logger, @@ -118,11 +118,11 @@ describe('agentUnlock', () => { logger.getChild('client'), ); }, + middleware: rpcUtils.defaultClientMiddlewareWrapper( + clientRPCUtils.authenticationMiddlewareClient(session), + ), logger, }); - rpcClient.registerMiddleware( - clientRPCUtils.authenticationMiddlewareClient(session), - ); // Doing the test const result = await rpcClient.methods.agentUnlock({