Skip to content

Commit

Permalink
wip: short circuit path for middleware
Browse files Browse the repository at this point in the history
Related #500
Related #502

[ci skip]
  • Loading branch information
tegefaulkes committed Jan 25, 2023
1 parent 730e712 commit 0e70a4c
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 19 deletions.
27 changes: 21 additions & 6 deletions src/RPC/RPCServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ import type { ReadableWritablePair } from 'stream/web';
import type { JSONValue, POJO } from '../types';
import type { ConnectionInfo } from '../network/types';
import type { RPCErrorEvent } from './utils';
import type { MiddlewareFactory } from 'tokens/types';
import type {
MiddlewareFactory,
MiddlewareForward,
MiddlewareReverse,
} from 'tokens/types';
import { ReadableStream } from 'stream/web';
import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy';
import Logger from '@matrixai/logger';
Expand Down Expand Up @@ -158,9 +162,13 @@ class RPCServer {
let middlewareStream = streamPair.readable.pipeThrough(
new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcRequest),
);
const shortMessageQueue: Array<JsonRpcResponse> = [];
for (const forwardMiddleWareFactory of this.forwardMiddleWare) {
const middleware = forwardMiddleWareFactory();
middlewareStream = middleware(middlewareStream);
middlewareStream = middleware(
middlewareStream,
(value: JsonRpcResponse) => shortMessageQueue.push(value),
);
}
// While ReadableStream can be converted to AsyncIterable, we want it as
// a generator.
Expand Down Expand Up @@ -269,6 +277,7 @@ class RPCServer {
reverseMiddlewareStream = middleware(reverseMiddlewareStream);
}
reverseMiddlewareStream
.pipeThrough(new rpcUtils.QueueMergingTransformStream(shortMessageQueue))
.pipeThrough(new rpcUtils.JsonMessageToJsonStream())
.pipeTo(streamPair.writable)
.catch(() => {});
Expand All @@ -293,15 +302,19 @@ class RPCServer {
}

protected forwardMiddleWare: Array<
MiddlewareFactory<JsonRpcRequest<JSONValue>>
MiddlewareFactory<
MiddlewareForward<JsonRpcRequest<JSONValue>, JsonRpcResponse<JSONValue>>
>
> = [];
protected reverseMiddleware: Array<
MiddlewareFactory<JsonRpcResponse<JSONValue>>
MiddlewareFactory<MiddlewareReverse<JsonRpcResponse<JSONValue>>>
> = [];

@ready(new rpcErrors.ErrorRpcDestroyed())
public registerForwardMiddleware(
middlewareFactory: MiddlewareFactory<JsonRpcRequest<JSONValue>>,
middlewareFactory: MiddlewareFactory<
MiddlewareForward<JsonRpcRequest<JSONValue>, JsonRpcResponse<JSONValue>>
>,
) {
this.forwardMiddleWare.push(middlewareFactory);
}
Expand All @@ -313,7 +326,9 @@ class RPCServer {

@ready(new rpcErrors.ErrorRpcDestroyed())
public registerReverseMiddleware(
middlewareFactory: MiddlewareFactory<JsonRpcResponse<JSONValue>>,
middlewareFactory: MiddlewareFactory<
MiddlewareReverse<JsonRpcResponse<JSONValue>>
>,
) {
this.reverseMiddleware.push(middlewareFactory);
}
Expand Down
7 changes: 6 additions & 1 deletion src/RPC/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ class ErrorRpcRemoteError<T> extends ErrorRpc<T> {
exitCode = sysexits.UNAVAILABLE;
}

class ErrorRpcPlaceholderConnectionError<T> extends ErrorRpc<T> {
class ErrorRpcNoMessageError<T> extends ErrorRpc<T> {
static description = 'For errors not to be conveyed to the client';
}

class ErrorRpcPlaceholderConnectionError<T> extends ErrorRpcNoMessageError<T> {
static description = 'placeholder error for connection stream failure';
exitCode = sysexits.UNAVAILABLE;
}
Expand All @@ -63,5 +67,6 @@ export {
ErrorRpcProtocal,
ErrorRpcMessageLength,
ErrorRpcRemoteError,
ErrorRpcNoMessageError,
ErrorRpcPlaceholderConnectionError,
};
69 changes: 68 additions & 1 deletion src/RPC/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type {
Transformer,
TransformerTransformCallback,
TransformerStartCallback,
TransformerFlushCallback,
} from 'stream/web';
import type {
JsonRpcError,
Expand All @@ -20,6 +21,7 @@ import * as rpcErrors from './errors';
import * as utils from '../utils';
import * as validationErrors from '../validation/errors';
import * as errors from '../errors';
import { promise } from '../utils';
const jsonStreamParsers = require('@streamparser/json');

class JsonToJsonMessage<T extends JsonRpcMessage>
Expand Down Expand Up @@ -479,7 +481,7 @@ class ClientOutputTransformerStream<
}

function isReturnableError(e: Error): boolean {
if (e instanceof rpcErrors.ErrorRpcPlaceholderConnectionError) return false;
if (e instanceof rpcErrors.ErrorRpcNoMessageError) return false;
return true;
}

Expand All @@ -500,6 +502,69 @@ class RPCErrorEvent extends Event {
}
}

const controllerTransformationFactory = <T>() => {
const controllerProm = promise<TransformStreamDefaultController<T>>();

class ControllerTransform<T> implements Transformer<T, T> {
start: TransformerStartCallback<T> = async (controller) => {
// @ts-ignore: type mismatch oddity
controllerProm.resolveP(controller);
};

transform: TransformerTransformCallback<T, T> = async (
chunk,
controller,
) => {
controller.enqueue(chunk);
};
}

class ControllerTransformStream<T> extends TransformStream<T, T> {
constructor() {
super(new ControllerTransform());
}
}
return {
controllerP: controllerProm.p,
controllerTransformStream: new ControllerTransformStream<T>(),
};
};

class QueueMergingTransform<T> implements Transformer<T, T> {
constructor(protected messageQueue: Array<T>) {}

start: TransformerStartCallback<T> = async (controller) => {
while (true) {
const value = this.messageQueue.shift();
if (value == null) break;
controller.enqueue(value);
}
};

transform: TransformerTransformCallback<T, T> = async (chunk, controller) => {
while (true) {
const value = this.messageQueue.shift();
if (value == null) break;
controller.enqueue(value);
}
controller.enqueue(chunk);
};

flush: TransformerFlushCallback<T> = (controller) => {
while (true) {
const value = this.messageQueue.shift();
if (value == null) break;
controller.enqueue(value);
}
};
}

class QueueMergingTransformStream<T> extends TransformStream<T, T> {
constructor(messageQueue: Array<T>) {
super(new QueueMergingTransform(messageQueue));
}
}

export {
JsonToJsonMessageStream,
JsonMessageToJsonStream,
Expand All @@ -516,4 +581,6 @@ export {
ClientOutputTransformerStream,
isReturnableError,
RPCErrorEvent,
controllerTransformationFactory,
QueueMergingTransformStream,
};
11 changes: 8 additions & 3 deletions src/tokens/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,12 @@ type SignedTokenEncoded = {
signatures: Array<TokenHeaderSignatureEncoded>;
};

type Middleware<T> = (input: ReadableStream<T>) => ReadableStream<T>;
type MiddlewareFactory<T> = () => Middleware<T>;
type MiddlewareForward<T, K> = (
input: ReadableStream<T>,
short: (value: K) => void,
) => ReadableStream<T>;
type MiddlewareReverse<T> = (input: ReadableStream<T>) => ReadableStream<T>;
type MiddlewareFactory<T> = () => T;

export type {
TokenPayload,
Expand All @@ -136,6 +140,7 @@ export type {
SignedToken,
SignedTokenJSON,
SignedTokenEncoded,
Middleware,
MiddlewareForward,
MiddlewareReverse,
MiddlewareFactory,
};
28 changes: 20 additions & 8 deletions tests/RPC/RPCServer.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -498,16 +498,16 @@ describe(`${RPCServer.name}`, () => {
const invalidTokenMessageArb = rpcTestUtils.jsonRpcRequestMessageArb(
undefined,
fc.record({
metadata: fc.constant({
token: validToken,
metadata: fc.record({
token: fc.string().filter((v) => v !== validToken),
}),
data: rpcTestUtils.safeJsonValueArb,
}),
);

// FIXME: This is a test for the authentication.
// It needs short circuit on the middleware.
testProp(
testProp.only(
'forward middleware authentication',
[invalidTokenMessageArb],
async (message) => {
Expand Down Expand Up @@ -539,13 +539,24 @@ describe(`${RPCServer.name}`, () => {
implements
Transformer<JsonRpcRequest<TestType>, JsonRpcRequest<TestType>>
{
constructor(
protected short: (value: JsonRpcResponse<JSONValue>) => void,
) {}
first = true;
transform: TransformerTransformCallback<
JsonRpcRequest<TestType>,
JsonRpcRequest<TestType>
> = async (chunk, controller) => {
if (this.first && chunk.params?.metadata.token !== validToken) {
controller.error(Error('test'));
this.short({
jsonrpc: '2.0',
id: null,
error: {
code: 1,
message: 'failure of somekind',
},
});
controller.error(new rpcErrors.ErrorRpcNoMessageError());
}
this.first = false;
controller.enqueue(chunk);
Expand All @@ -555,19 +566,20 @@ describe(`${RPCServer.name}`, () => {
JsonRpcRequest<TestType>,
JsonRpcRequest<TestType>
> {
constructor() {
super(new AuthenticationTransformer());
constructor(short: (value: JsonRpcResponse<JSONValue>) => void) {
super(new AuthenticationTransformer(short));
}
}
rpcServer.registerForwardMiddleware(() => {
return (input) => {
return input.pipeThrough(new AuthenticationTransformerStream());
return (input, short) => {
return input.pipeThrough(new AuthenticationTransformerStream(short));
};
});
rpcServer.handleStream(readWriteStream, {} as ConnectionInfo);
await outputResult;
await rpcServer.destroy();
},
{ numRuns: 1 },
);

// TODO:
Expand Down
31 changes: 31 additions & 0 deletions tests/RPC/utils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,37 @@ describe('utils tests', () => {
{ numRuns: 1000 },
);

testProp(
'merging transformation stream',
[fc.array(fc.integer()), fc.array(fc.integer())],
async (set1, set2) => {
const [outputResult, outputWriterStream] =
rpcTestUtils.streamToArray<number>();
const { controllerP, controllerTransformStream } =
rpcUtils.controllerTransformationFactory<number>();
void controllerTransformStream.readable
.pipeTo(outputWriterStream)
.catch(() => {});
const writer = controllerTransformStream.writable.getWriter();
const controller = await controllerP;
const expectedResult: Array<number> = [];
for (let i = 0; i < Math.max(set1.length, set2.length); i++) {
if (set1[i] != null) {
await writer.write(set1[i]);
expectedResult.push(set1[i]);
}
if (set2[i] != null) {
controller.enqueue(set2[i]);
expectedResult.push(set2[i]);
}
}
await writer.close();

expect(await outputResult).toStrictEqual(expectedResult);
},
{ numRuns: 1000 },
);

// TODO:
// - Test for badly structured data
});

0 comments on commit 0e70a4c

Please sign in to comment.