From 279ad47517ae3d4bc99ab499bf1fd9ef67dbb74b Mon Sep 17 00:00:00 2001 From: Alex Potsides Date: Thu, 24 Nov 2022 09:48:46 +0000 Subject: [PATCH] fix: apply message size limit before decoding message (#231) * fix: apply message size limit before decoding message If we apply the message size limit after decoding the message it's too late as we've already processed the bad message. Instead, if the buffer full of unprocessed messages grows to be large than the max message size (e.g. we have not recieved a complete message under the size limit), throw an error which will cause the stream to be reset. * fix: add implementation --- src/decode.ts | 25 ++++++++++++++++++------- src/mplex.ts | 4 +--- src/restrict-size.ts | 36 ------------------------------------ src/stream.ts | 2 +- test/coder.spec.ts | 15 ++++++--------- test/mplex.spec.ts | 16 ++++++++-------- test/restrict-size.spec.ts | 15 +++++++++------ 7 files changed, 43 insertions(+), 70 deletions(-) delete mode 100644 src/restrict-size.ts diff --git a/src/decode.ts b/src/decode.ts index 94aaf40cad..6c4e5ab276 100644 --- a/src/decode.ts +++ b/src/decode.ts @@ -3,6 +3,8 @@ import { Uint8ArrayList } from 'uint8arraylist' import type { Source } from 'it-stream-types' import type { Message } from './message-types.js' +export const MAX_MSG_SIZE = 1 << 20 // 1MB + interface MessageHeader { id: number type: keyof typeof MessageTypeNames @@ -13,10 +15,12 @@ interface MessageHeader { class Decoder { private readonly _buffer: Uint8ArrayList private _headerInfo: MessageHeader | null + private readonly _maxMessageSize: number - constructor () { + constructor (maxMessageSize: number = MAX_MSG_SIZE) { this._buffer = new Uint8ArrayList() this._headerInfo = null + this._maxMessageSize = maxMessageSize } write (chunk: Uint8Array) { @@ -25,6 +29,11 @@ class Decoder { } this._buffer.append(chunk) + + if (this._buffer.byteLength > this._maxMessageSize) { + throw Object.assign(new Error('message size too large!'), { code: 'ERR_MSG_TOO_BIG' }) + } + const msgs: Message[] = [] while (this._buffer.length !== 0) { @@ -119,14 +128,16 @@ function readVarInt (buf: Uint8ArrayList, offset: number = 0) { /** * Decode a chunk and yield an _array_ of decoded messages */ -export async function * decode (source: Source) { - const decoder = new Decoder() +export function decode (maxMessageSize: number = MAX_MSG_SIZE) { + return async function * decodeMessages (source: Source): Source { + const decoder = new Decoder(maxMessageSize) - for await (const chunk of source) { - const msgs = decoder.write(chunk) + for await (const chunk of source) { + const msgs = decoder.write(chunk) - if (msgs.length > 0) { - yield msgs + if (msgs.length > 0) { + yield * msgs + } } } } diff --git a/src/mplex.ts b/src/mplex.ts index 2d44172d97..945a64f3d5 100644 --- a/src/mplex.ts +++ b/src/mplex.ts @@ -3,7 +3,6 @@ import { pushableV } from 'it-pushable' import { abortableSource } from 'abortable-iterator' import { encode } from './encode.js' import { decode } from './decode.js' -import { restrictSize } from './restrict-size.js' import { MessageTypes, MessageTypeNames, Message } from './message-types.js' import { createStream } from './stream.js' import { toString as uint8ArrayToString } from 'uint8arrays' @@ -204,8 +203,7 @@ export class MplexStreamMuxer implements StreamMuxer { try { await pipe( source, - decode, - restrictSize(this._init.maxMsgSize), + decode(this._init.maxMsgSize), async source => { for await (const msg of source) { await this._handleIncoming(msg) diff --git a/src/restrict-size.ts b/src/restrict-size.ts deleted file mode 100644 index e91d9c9af4..0000000000 --- a/src/restrict-size.ts +++ /dev/null @@ -1,36 +0,0 @@ -import { Message, MessageTypes } from './message-types.js' -import type { Source, Transform } from 'it-stream-types' - -export const MAX_MSG_SIZE = 1 << 20 // 1MB - -/** - * Creates an iterable transform that restricts message sizes to - * the given maximum size. - */ -export function restrictSize (max?: number): Transform { - const maxSize = max ?? MAX_MSG_SIZE - - const checkSize = (msg: Message) => { - if (msg.type !== MessageTypes.NEW_STREAM && msg.type !== MessageTypes.MESSAGE_INITIATOR && msg.type !== MessageTypes.MESSAGE_RECEIVER) { - return - } - - if (msg.data.byteLength > maxSize) { - throw Object.assign(new Error('message size too large!'), { code: 'ERR_MSG_TOO_BIG' }) - } - } - - return (source: Source) => { - return (async function * restrictSize () { - for await (const msg of source) { - if (Array.isArray(msg)) { - msg.forEach(checkSize) - yield * msg - } else { - checkSize(msg) - yield msg - } - } - })() - } -} diff --git a/src/stream.ts b/src/stream.ts index bfc29cfac3..1fd9e5f996 100644 --- a/src/stream.ts +++ b/src/stream.ts @@ -1,7 +1,7 @@ import { abortableSource } from 'abortable-iterator' import { pushable } from 'it-pushable' import errCode from 'err-code' -import { MAX_MSG_SIZE } from './restrict-size.js' +import { MAX_MSG_SIZE } from './decode.js' import { anySignal } from 'any-signal' import { InitiatorMessageTypes, ReceiverMessageTypes } from './message-types.js' import { fromString as uint8ArrayFromString } from 'uint8arrays/from-string' diff --git a/test/coder.spec.ts b/test/coder.spec.ts index 9a96f5a98d..45f1264c75 100644 --- a/test/coder.spec.ts +++ b/test/coder.spec.ts @@ -23,10 +23,8 @@ describe('coder', () => { it('should decode header', async () => { const source = [uint8ArrayFromString('8801023137', 'base16')] - for await (const msgs of decode(source)) { - expect(msgs.length).to.equal(1) - - expect(messageWithBytes(msgs[0])).to.be.deep.equal({ id: 17, type: 0, data: uint8ArrayFromString('17') }) + for await (const msg of decode()(source)) { + expect(messageWithBytes(msg)).to.be.deep.equal({ id: 17, type: 0, data: uint8ArrayFromString('17') }) } }) @@ -67,8 +65,8 @@ describe('coder', () => { const source = [uint8ArrayFromString('88010231379801023139a801023231', 'base16')] const res = [] - for await (const msgs of decode(source)) { - res.push(...msgs) + for await (const msg of decode()(source)) { + res.push(msg) } expect(res.map(messageWithBytes)).to.deep.equal([ @@ -89,9 +87,8 @@ describe('coder', () => { it('should decode zero length body msg', async () => { const source = [uint8ArrayFromString('880100', 'base16')] - for await (const msgs of decode(source)) { - expect(msgs.length).to.equal(1) - expect(messageWithBytes(msgs[0])).to.be.eql({ id: 17, type: 0, data: new Uint8Array(0) }) + for await (const msg of decode()(source)) { + expect(messageWithBytes(msg)).to.be.eql({ id: 17, type: 0, data: new Uint8Array(0) }) } }) }) diff --git a/test/mplex.spec.ts b/test/mplex.spec.ts index 4032f871ce..f2e58b2835 100644 --- a/test/mplex.spec.ts +++ b/test/mplex.spec.ts @@ -76,10 +76,10 @@ describe('mplex', () => { await muxer.sink(stream) - const messages = await all(decode(bufs)) + const messages = await all(decode()(bufs)) - expect(messages).to.have.nested.property('[0][0].id', 11, 'Did not specify the correct stream id') - expect(messages).to.have.nested.property('[0][0].type', MessageTypes.RESET_RECEIVER, 'Did not reset the stream that tipped us over the inbound stream limit') + expect(messages).to.have.nested.property('[0].id', 11, 'Did not specify the correct stream id') + expect(messages).to.have.nested.property('[0].type', MessageTypes.RESET_RECEIVER, 'Did not reset the stream that tipped us over the inbound stream limit') }) it('should reset a stream that fills the message buffer', async () => { @@ -103,7 +103,7 @@ describe('mplex', () => { const dataMessage: MessageInitiatorMessage = { id, type: MessageTypes.MESSAGE_INITIATOR, - data: new Uint8ArrayList(new Uint8Array(1024 * 1024)) + data: new Uint8ArrayList(new Uint8Array(1024 * 1000)) } yield dataMessage @@ -144,9 +144,9 @@ describe('mplex', () => { // collect outgoing mplex messages const muxerFinished = pDefer() - let messages: Message[][] = [] + let messages: Message[] = [] void Promise.resolve().then(async () => { - messages = await all(decode(muxer.source)) + messages = await all(decode()(muxer.source)) muxerFinished.resolve() }) @@ -159,7 +159,7 @@ describe('mplex', () => { // should have sent reset message to peer for this stream await muxerFinished.promise - expect(messages).to.have.nested.property('[0][0].id', id) - expect(messages).to.have.nested.property('[0][0].type', MessageTypes.RESET_RECEIVER) + expect(messages).to.have.nested.property('[0].id', id) + expect(messages).to.have.nested.property('[0].type', MessageTypes.RESET_RECEIVER) }) }) diff --git a/test/restrict-size.spec.ts b/test/restrict-size.spec.ts index b744207825..aacb620cda 100644 --- a/test/restrict-size.spec.ts +++ b/test/restrict-size.spec.ts @@ -7,18 +7,19 @@ import all from 'it-all' import drain from 'it-drain' import each from 'it-foreach' import { Message, MessageTypes } from '../src/message-types.js' -import { restrictSize } from '../src/restrict-size.js' +import { encode } from '../src/encode.js' +import { decode } from '../src/decode.js' import { Uint8ArrayList } from 'uint8arraylist' -describe('restrict-size', () => { +describe('restrict size', () => { it('should throw when size is too big', async () => { const maxSize = 32 const input: Message[] = [ { id: 0, type: 1, data: new Uint8ArrayList(randomBytes(8)) }, + { id: 0, type: 1, data: new Uint8ArrayList(randomBytes(16)) }, { id: 0, type: 1, data: new Uint8ArrayList(randomBytes(maxSize)) }, - { id: 0, type: 1, data: new Uint8ArrayList(randomBytes(64)) }, - { id: 0, type: 1, data: new Uint8ArrayList(randomBytes(16)) } + { id: 0, type: 1, data: new Uint8ArrayList(randomBytes(64)) } ] const output: Message[] = [] @@ -26,7 +27,8 @@ describe('restrict-size', () => { try { await pipe( input, - restrictSize(maxSize), + encode, + decode(maxSize), (source) => each(source, chunk => { output.push(chunk) }), @@ -51,7 +53,8 @@ describe('restrict-size', () => { const output = await pipe( input, - restrictSize(32), + encode, + decode(32), async (source) => await all(source) ) expect(output).to.deep.equal(input)