Skip to content

Commit

Permalink
feat: middleware
Browse files Browse the repository at this point in the history
Related #502
Related #500
Related #502

[ci skip]
  • Loading branch information
tegefaulkes committed Feb 14, 2023
1 parent 260f23b commit 74450a9
Show file tree
Hide file tree
Showing 8 changed files with 599 additions and 81 deletions.
61 changes: 55 additions & 6 deletions src/RPC/RPCClient.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import type { StreamPairCreateCallback } from './types';
import type { JSONValue, POJO } from 'types';
import type { ReadableWritablePair } from 'stream/web';
import type {
JsonRpcRequest,
JsonRpcResponse,
MiddlewareFactory,
Middleware,
} from './types';
import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy';
import Logger from '@matrixai/logger';
import * as rpcErrors from './errors';
Expand Down Expand Up @@ -50,14 +56,24 @@ class RPCClient {
_metadata: POJO,
): Promise<ReadableWritablePair<O, I>> {
const streamPair = await this.streamPairCreateCallback();
const outputStream = streamPair.readable
.pipeThrough(
new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcResponse),
)
.pipeThrough(new rpcUtils.ClientOutputTransformerStream<O>());
let reverseMiddlewareStream = streamPair.readable.pipeThrough(
new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcResponse),
);
for (const middleWare of this.reverseMiddleware) {
const middle = middleWare();
reverseMiddlewareStream = middle(reverseMiddlewareStream);
}
const outputStream = reverseMiddlewareStream.pipeThrough(
new rpcUtils.ClientOutputTransformerStream<O>(),
);
const inputMessageTransformer =
new rpcUtils.ClientInputTransformerStream<I>(method);
void inputMessageTransformer.readable
let forwardMiddlewareStream = inputMessageTransformer.readable;
for (const middleware of this.forwardMiddleWare) {
const middle = middleware();
forwardMiddlewareStream = middle(forwardMiddlewareStream);
}
void forwardMiddlewareStream
.pipeThrough(new rpcUtils.JsonMessageToJsonStream())
.pipeTo(streamPair.writable)
.catch(() => {});
Expand Down Expand Up @@ -188,6 +204,39 @@ class RPCClient {
await writer.close();
return callerInterface.output;
}

protected forwardMiddleWare: Array<
MiddlewareFactory<Middleware<JsonRpcRequest<JSONValue>>>
> = [];
protected reverseMiddleware: Array<
MiddlewareFactory<Middleware<JsonRpcResponse<JSONValue>>>
> = [];

@ready(new rpcErrors.ErrorRpcDestroyed())
public registerForwardMiddleware(
middlewareFactory: MiddlewareFactory<Middleware<JsonRpcRequest<JSONValue>>>,
) {
this.forwardMiddleWare.push(middlewareFactory);
}

@ready(new rpcErrors.ErrorRpcDestroyed())
public clearForwardMiddleware() {
this.reverseMiddleware = [];
}

@ready(new rpcErrors.ErrorRpcDestroyed())
public registerReverseMiddleware(
middlewareFactory: MiddlewareFactory<
Middleware<JsonRpcResponse<JSONValue>>
>,
) {
this.reverseMiddleware.push(middlewareFactory);
}

@ready(new rpcErrors.ErrorRpcDestroyed())
public clearReverseMiddleware() {
this.reverseMiddleware = [];
}
}

export default RPCClient;
177 changes: 121 additions & 56 deletions src/RPC/RPCServer.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
import type {
ServerStreamHandler,
ClientStreamHandler,
DuplexStreamHandler,
JsonRpcError,
JsonRpcMessage,
JsonRpcRequest,
JsonRpcResponse,
JsonRpcResponseError,
JsonRpcResponseResult,
ClientStreamHandler,
ServerStreamHandler,
UnaryHandler,
} from './types';
import type { ReadableWritablePair } from 'stream/web';
import type { JSONValue, POJO } from '../types';
import type { ConnectionInfo } from '../network/types';
import type { UnaryHandler } from './types';
import type { RPCErrorEvent } from './utils';
import type {
MiddlewareFactory,
MiddlewareShort,
Middleware,
} from 'tokens/types';
import { ReadableStream } from 'stream/web';
import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy';
import Logger from '@matrixai/logger';
import { PromiseCancellable } from '@matrixai/async-cancellable';
import * as rpcErrors from './errors';
import * as rpcUtils from './utils';
import * as rpcErrors from './errors';

interface RPCServer extends CreateDestroy {}
@CreateDestroy()
Expand Down Expand Up @@ -152,21 +158,32 @@ class RPCServer {
void handlerProm
.finally(() => this.activeStreams.delete(handlerProm))
.catch(() => {});
// Setting up forward middleware
let middlewareStream = streamPair.readable.pipeThrough(
new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcRequest),
);
const shortMessageQueue: Array<JsonRpcResponse> = [];
for (const forwardMiddleWareFactory of this.forwardMiddleWare) {
const middleware = forwardMiddleWareFactory();
middlewareStream = middleware(
middlewareStream,
(value: JsonRpcResponse) => shortMessageQueue.push(value),
);
}
// While ReadableStream can be converted to AsyncIterable, we want it as
// a generator.
const inputGen = async function* () {
const pojoStream = streamPair.readable.pipeThrough(
new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcRequest),
);
for await (const dataMessage of pojoStream) {
for await (const dataMessage of middlewareStream) {
yield dataMessage;
}
};
const container = this.container;
const handlerMap = this.handlerMap;
const ctx = { signal: abortController.signal };
const events = this.events;
const outputGen = async function* (): AsyncGenerator<JsonRpcMessage> {
const outputGen = async function* (): AsyncGenerator<
JsonRpcResponse<JSONValue>
> {
// Step 1, authentication and establishment
// read the first message, lets assume the first message is always leading
// metadata.
Expand All @@ -184,7 +201,6 @@ class RPCServer {
yield data.params as JSONValue;
}
};
// TODO: validation on metadata
const handler = handlerMap.get(method);
if (handler == null) {
// Failed to find handler, this is an error. We should respond with
Expand All @@ -194,69 +210,80 @@ class RPCServer {
);
}
if (ctx.signal.aborted) throw ctx.signal.reason;
try {
for await (const response of handler(
dataGen(),
container,
connectionInfo,
ctx,
)) {
const responseMessage: JsonRpcResponseResult<JSONValue> = {
jsonrpc: '2.0',
result: response,
id: null,
};
yield responseMessage;
}
} catch (e) {
if (rpcUtils.isReturnableError(e)) {
// We want to convert this error to an error message and pass it along
const rpcError: JsonRpcError = {
code: e.exitCode,
message: e.description,
data: rpcUtils.fromError(e),
};
const rpcErrorMessage: JsonRpcResponseError = {
jsonrpc: '2.0',
error: rpcError,
id: null,
};
yield rpcErrorMessage;
} else {
// These errors are emitted to the event system
events.dispatchEvent(
new rpcUtils.RPCErrorEvent({
detail: {
error: e,
},
}),
);
}
for await (const response of handler(
dataGen(),
container,
connectionInfo,
ctx,
)) {
const responseMessage: JsonRpcResponseResult<JSONValue> = {
jsonrpc: '2.0',
result: response,
id: null,
};
yield responseMessage;
}
resolve();
};

const outputGenerator = outputGen();

const outputStream = new ReadableStream<JsonRpcMessage>({
let reverseMiddlewareStream = new ReadableStream<
JsonRpcResponse<JSONValue>
>({
pull: async (controller) => {
const { value, done } = await outputGenerator.next();
if (done) {
try {
const { value, done } = await outputGenerator.next();
if (done) {
controller.close();
resolve();
return;
}
controller.enqueue(value);
} catch (e) {
if (rpcUtils.isReturnableError(e)) {
// We want to convert this error to an error message and pass it along
const rpcError: JsonRpcError = {
code: e.exitCode,
message: e.description,
data: rpcUtils.fromError(e),
};
const rpcErrorMessage: JsonRpcResponseError = {
jsonrpc: '2.0',
error: rpcError,
id: null,
};
controller.enqueue(rpcErrorMessage);
} else {
// These errors are emitted to the event system
events.dispatchEvent(
new rpcUtils.RPCErrorEvent({
detail: {
error: e,
},
}),
);
}
controller.close();
return;
resolve();
}
controller.enqueue(value);
},
cancel: async (reason) => {
await outputGenerator.throw(reason);
},
});
void outputStream
// Setting up reverse middleware
for (const reverseMiddleWareFactory of this.reverseMiddleware) {
const middleware = reverseMiddleWareFactory();
reverseMiddlewareStream = middleware(reverseMiddlewareStream);
}
reverseMiddlewareStream
.pipeThrough(new rpcUtils.QueueMergingTransformStream(shortMessageQueue))
.pipeThrough(new rpcUtils.JsonMessageToJsonStream())
.pipeTo(streamPair.writable)
.catch(() => {});
}

@ready(new rpcErrors.ErrorRpcDestroyed())
public addEventListener(
type: 'error',
callback: (event: RPCErrorEvent) => void,
Expand All @@ -265,13 +292,51 @@ class RPCServer {
this.events.addEventListener(type, callback, options);
}

@ready(new rpcErrors.ErrorRpcDestroyed())
public removeEventListener(
type: 'error',
callback: (event: RPCErrorEvent) => void,
options?: boolean | AddEventListenerOptions | undefined,
) {
this.events.removeEventListener(type, callback, options);
}

protected forwardMiddleWare: Array<
MiddlewareFactory<
MiddlewareShort<JsonRpcRequest<JSONValue>, JsonRpcResponse<JSONValue>>
>
> = [];
protected reverseMiddleware: Array<
MiddlewareFactory<Middleware<JsonRpcResponse<JSONValue>>>
> = [];

@ready(new rpcErrors.ErrorRpcDestroyed())
public registerForwardMiddleware(
middlewareFactory: MiddlewareFactory<
MiddlewareShort<JsonRpcRequest<JSONValue>, JsonRpcResponse<JSONValue>>
>,
) {
this.forwardMiddleWare.push(middlewareFactory);
}

@ready(new rpcErrors.ErrorRpcDestroyed())
public clearForwardMiddleware() {
this.reverseMiddleware = [];
}

@ready(new rpcErrors.ErrorRpcDestroyed())
public registerReverseMiddleware(
middlewareFactory: MiddlewareFactory<
Middleware<JsonRpcResponse<JSONValue>>
>,
) {
this.reverseMiddleware.push(middlewareFactory);
}

@ready(new rpcErrors.ErrorRpcDestroyed())
public clearReverseMiddleware() {
this.reverseMiddleware = [];
}
}

export default RPCServer;
7 changes: 6 additions & 1 deletion src/RPC/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ class ErrorRpcRemoteError<T> extends ErrorRpc<T> {
exitCode = sysexits.UNAVAILABLE;
}

class ErrorRpcPlaceholderConnectionError<T> extends ErrorRpc<T> {
class ErrorRpcNoMessageError<T> extends ErrorRpc<T> {
static description = 'For errors not to be conveyed to the client';
}

class ErrorRpcPlaceholderConnectionError<T> extends ErrorRpcNoMessageError<T> {
static description = 'placeholder error for connection stream failure';
exitCode = sysexits.UNAVAILABLE;
}
Expand All @@ -63,5 +67,6 @@ export {
ErrorRpcProtocal,
ErrorRpcMessageLength,
ErrorRpcRemoteError,
ErrorRpcNoMessageError,
ErrorRpcPlaceholderConnectionError,
};
11 changes: 11 additions & 0 deletions src/RPC/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type { JSONValue, POJO } from '../types';
import type { ConnectionInfo } from '../network/types';
import type { ContextCancellable } from '../contexts/types';
import type { ReadableWritablePair } from 'stream/web';
import type { ReadableStream } from 'stream/web';

/**
* This is the JSON RPC request object. this is the generic message type used for the RPC.
Expand Down Expand Up @@ -127,6 +128,13 @@ type StreamPairCreateCallback = () => Promise<
ReadableWritablePair<Uint8Array, Uint8Array>
>;

type MiddlewareShort<T, K> = (
input: ReadableStream<T>,
short: (value: K) => void,
) => ReadableStream<T>;
type Middleware<T> = (input: ReadableStream<T>) => ReadableStream<T>;
type MiddlewareFactory<T> = () => T;

export type {
JsonRpcRequestMessage,
JsonRpcRequestNotification,
Expand All @@ -141,4 +149,7 @@ export type {
ClientStreamHandler,
UnaryHandler,
StreamPairCreateCallback,
MiddlewareShort,
Middleware,
MiddlewareFactory,
};
Loading

0 comments on commit 74450a9

Please sign in to comment.