Skip to content

Commit

Permalink
feat: client static middleware registration
Browse files Browse the repository at this point in the history
Related #501
Related #502

[ci skip]
  • Loading branch information
tegefaulkes committed Feb 10, 2023
1 parent 0489e51 commit ff8e311
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 84 deletions.
83 changes: 35 additions & 48 deletions src/RPC/RPCClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy';
import Logger from '@matrixai/logger';
import * as rpcErrors from './errors';
import * as rpcUtils from './utils';
import {
clientInputTransformStream,
clientOutputTransformStream,
} from './utils';
import { never } from '../utils';

// eslint-disable-next-line
Expand All @@ -26,16 +30,24 @@ class RPCClient<M extends ClientManifest> {
static async createRPCClient<M extends ClientManifest>({
manifest,
streamPairCreateCallback,
middleware = rpcUtils.defaultClientMiddlewareWrapper(),
logger = new Logger(this.name),
}: {
manifest: M;
streamPairCreateCallback: StreamPairCreateCallback;
logger: Logger;
middleware?: MiddlewareFactory<
Uint8Array,
JsonRpcRequest,
JsonRpcResponse,
Uint8Array
>;
logger?: Logger;
}) {
logger.info(`Creating ${this.name}`);
const rpcClient = new this({
manifest,
streamPairCreateCallback,
middleware,
logger,
});
logger.info(`Created ${this.name}`);
Expand All @@ -44,6 +56,12 @@ class RPCClient<M extends ClientManifest> {

protected logger: Logger;
protected streamPairCreateCallback: StreamPairCreateCallback;
protected middleware: MiddlewareFactory<
Uint8Array,
JsonRpcRequest,
JsonRpcResponse,
Uint8Array
>;
protected callerTypes: Record<string, HandlerType>;
// Method proxies
public readonly methodsProxy = new Proxy(
Expand Down Expand Up @@ -90,14 +108,22 @@ class RPCClient<M extends ClientManifest> {
public constructor({
manifest,
streamPairCreateCallback,
middleware,
logger,
}: {
manifest: M;
streamPairCreateCallback: StreamPairCreateCallback;
middleware: MiddlewareFactory<
Uint8Array,
JsonRpcRequest,
JsonRpcResponse,
Uint8Array
>;
logger: Logger;
}) {
this.callerTypes = rpcUtils.getHandlerTypes(manifest);
this.streamPairCreateCallback = streamPairCreateCallback;
this.middleware = middleware;
this.logger = logger;
}

Expand Down Expand Up @@ -199,36 +225,23 @@ class RPCClient<M extends ClientManifest> {
public async rawDuplexStreamCaller<I extends JSONValue, O extends JSONValue>(
method: string,
): Promise<ReadableWritablePair<O, I>> {
// Creating caller side transforms
const outputMessageTransforStream =
rpcUtils.clientOutputTransformStream<O>();
const inputMessageTransformStream =
rpcUtils.clientInputTransformStream<I>(method);
let reverseStream = outputMessageTransforStream.writable;
let forwardStream = inputMessageTransformStream.readable;
// Setting up middleware chains
for (const middlewareFactory of this.middleware) {
const middleware = middlewareFactory();
forwardStream = forwardStream.pipeThrough(middleware.forward);
void middleware.reverse.readable.pipeTo(reverseStream).catch(() => {});
reverseStream = middleware.reverse.writable;
}
const outputMessageTransformStream = clientOutputTransformStream<O>();
const inputMessageTransformStream = clientInputTransformStream<I>(method);
const middleware = this.middleware();
// Hooking up agnostic stream side
const streamPair = await this.streamPairCreateCallback();
void streamPair.readable
.pipeThrough(
rpcUtils.binaryToJsonMessageStream(rpcUtils.parseJsonRpcResponse),
)
.pipeTo(reverseStream)
.pipeThrough(middleware.reverse)
.pipeTo(outputMessageTransformStream.writable)
.catch(() => {});
void forwardStream
.pipeThrough(rpcUtils.jsonMessageToBinaryStream())
void inputMessageTransformStream.readable
.pipeThrough(middleware.forward)
.pipeTo(streamPair.writable)
.catch(() => {});

// Returning interface
return {
readable: outputMessageTransforStream.readable,
readable: outputMessageTransformStream.readable,
writable: inputMessageTransformStream.writable,
};
}
Expand Down Expand Up @@ -273,32 +286,6 @@ class RPCClient<M extends ClientManifest> {
writable: callerInterface.writable,
};
}

protected middleware: Array<
MiddlewareFactory<
JsonRpcRequest,
JsonRpcRequest,
JsonRpcResponse,
JsonRpcResponse
>
> = [];

@ready(new rpcErrors.ErrorRpcDestroyed())
public registerMiddleware(
middlewareFactory: MiddlewareFactory<
JsonRpcRequest,
JsonRpcRequest,
JsonRpcResponse,
JsonRpcResponse
>,
) {
this.middleware.push(middlewareFactory);
}

@ready(new rpcErrors.ErrorRpcDestroyed())
public clearMiddleware() {
this.middleware = [];
}
}

export default RPCClient;
2 changes: 1 addition & 1 deletion src/RPC/RPCServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ interface RPCServer extends CreateDestroy {}
class RPCServer {
static async createRPCServer({
manifest,
middleware = rpcUtils.defaultMiddlewareWrapper(),
middleware = rpcUtils.defaultServerMiddlewareWrapper(),
logger = new Logger(this.name),
}: {
manifest: ServerManifest;
Expand Down
49 changes: 47 additions & 2 deletions src/RPC/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ const defaultMiddleware: MiddlewareFactory<
};
};

const defaultMiddlewareWrapper = (
const defaultServerMiddlewareWrapper = (
middleware: MiddlewareFactory<
JsonRpcRequest,
JsonRpcRequest,
Expand Down Expand Up @@ -627,6 +627,50 @@ const defaultMiddlewareWrapper = (
};
};

const defaultClientMiddlewareWrapper = (
middleware: MiddlewareFactory<
JsonRpcRequest,
JsonRpcRequest,
JsonRpcResponse,
JsonRpcResponse
> = defaultMiddleware,
): MiddlewareFactory<
Uint8Array,
JsonRpcRequest,
JsonRpcResponse,
Uint8Array
> => {
return () => {
const outputTransformStream = binaryToJsonMessageStream(
parseJsonRpcResponse,
undefined,
);
const inputTransformStream = new TransformStream<
JsonRpcRequest,
JsonRpcRequest
>();

const middleMiddleware = middleware();
const forwardReadable = inputTransformStream.readable
.pipeThrough(middleMiddleware.forward) // Usual middleware here
.pipeThrough(jsonMessageToBinaryStream());
const reverseReadable = outputTransformStream.readable.pipeThrough(
middleMiddleware.reverse,
); // Usual middleware here

return {
forward: {
readable: forwardReadable,
writable: inputTransformStream.writable,
},
reverse: {
readable: reverseReadable,
writable: outputTransformStream.writable,
},
};
};
};

export {
binaryToJsonMessageStream,
jsonMessageToBinaryStream,
Expand All @@ -648,5 +692,6 @@ export {
extractFirstMessageTransform,
getHandlerTypes,
defaultMiddleware,
defaultMiddlewareWrapper,
defaultServerMiddlewareWrapper,
defaultClientMiddlewareWrapper,
};
53 changes: 27 additions & 26 deletions tests/RPC/RPCClient.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
ServerCaller,
UnaryCaller,
} from '@/RPC/callers';
import * as rpcUtils from '@/RPC/utils';
import * as rpcTestUtils from './utils';

describe(`${RPCClient.name}`, () => {
Expand Down Expand Up @@ -309,22 +310,22 @@ describe(`${RPCClient.name}`, () => {
const rpcClient = await RPCClient.createRPCClient({
manifest: {},
streamPairCreateCallback: async () => streamPair,
middleware: rpcUtils.defaultClientMiddlewareWrapper(() => {
return {
forward: new TransformStream<JsonRpcRequest, JsonRpcRequest>({
transform: (chunk, controller) => {
controller.enqueue({
...chunk,
params: 'one',
});
},
}),
reverse: new TransformStream(),
};
}),
logger,
});

rpcClient.registerMiddleware(() => {
return {
forward: new TransformStream<JsonRpcRequest, JsonRpcRequest>({
transform: (chunk, controller) => {
controller.enqueue({
...chunk,
params: 'one',
});
},
}),
reverse: new TransformStream(),
};
});
const callerInterface = await rpcClient.rawDuplexStreamCaller<
JSONValue,
JSONValue
Expand Down Expand Up @@ -373,22 +374,22 @@ describe(`${RPCClient.name}`, () => {
const rpcClient = await RPCClient.createRPCClient({
manifest: {},
streamPairCreateCallback: async () => streamPair,
middleware: rpcUtils.defaultClientMiddlewareWrapper(() => {
return {
forward: new TransformStream(),
reverse: new TransformStream<JsonRpcResponse, JsonRpcResponse>({
transform: (chunk, controller) => {
controller.enqueue({
...chunk,
result: 'one',
});
},
}),
};
}),
logger,
});

rpcClient.registerMiddleware(() => {
return {
forward: new TransformStream(),
reverse: new TransformStream<JsonRpcResponse, JsonRpcResponse>({
transform: (chunk, controller) => {
controller.enqueue({
...chunk,
result: 'one',
});
},
}),
};
});
const callerInterface = await rpcClient.rawDuplexStreamCaller<
JSONValue,
JSONValue
Expand Down
6 changes: 3 additions & 3 deletions tests/RPC/RPCServer.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ describe(`${RPCServer.name}`, () => {
}
};
}
const middleware = rpcUtils.defaultMiddlewareWrapper(() => {
const middleware = rpcUtils.defaultServerMiddlewareWrapper(() => {
return {
forward: new TransformStream({
transform: (chunk, controller) => {
Expand Down Expand Up @@ -501,7 +501,7 @@ describe(`${RPCServer.name}`, () => {
}
};
}
const middleware = rpcUtils.defaultMiddlewareWrapper(() => {
const middleware = rpcUtils.defaultServerMiddlewareWrapper(() => {
return {
forward: new TransformStream(),
reverse: new TransformStream({
Expand Down Expand Up @@ -549,7 +549,7 @@ describe(`${RPCServer.name}`, () => {
}
};
}
const middleware = rpcUtils.defaultMiddlewareWrapper(() => {
const middleware = rpcUtils.defaultServerMiddlewareWrapper(() => {
let first = true;
let reverseController: TransformStreamDefaultController<JsonRpcResponse>;
return {
Expand Down
8 changes: 4 additions & 4 deletions tests/clientRPC/handlers/agentUnlock.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ describe('agentUnlock', () => {
manifest: {
agentUnlock: new AgentUnlockHandler({ logger }),
},
middleware: rpcUtils.defaultMiddlewareWrapper(
middleware: rpcUtils.defaultServerMiddlewareWrapper(
clientRPCUtils.authenticationMiddlewareServer(sessionManager, keyRing),
),
logger,
Expand All @@ -118,11 +118,11 @@ describe('agentUnlock', () => {
logger.getChild('client'),
);
},
middleware: rpcUtils.defaultClientMiddlewareWrapper(
clientRPCUtils.authenticationMiddlewareClient(session),
),
logger,
});
rpcClient.registerMiddleware(
clientRPCUtils.authenticationMiddlewareClient(session),
);

// Doing the test
const result = await rpcClient.methods.agentUnlock({
Expand Down

0 comments on commit ff8e311

Please sign in to comment.