diff --git a/src/muxer.ts b/src/muxer.ts index b1b55be..924c812 100644 --- a/src/muxer.ts +++ b/src/muxer.ts @@ -5,6 +5,8 @@ import type { Source, Sink } from 'it-stream-types' import { WebRTCStream } from './stream.js' import { nopSink, nopSource } from './util.js' +const WEBRTC_PROTOCOL_NAME = '/webrtc' +export interface MessageSizeOption { maxMsgSize?: number } export class DataChannelMuxerFactory implements StreamMuxerFactory { /** * WebRTC Peer Connection @@ -12,12 +14,14 @@ export class DataChannelMuxerFactory implements StreamMuxerFactory { private readonly peerConnection: RTCPeerConnection private streamBuffer: WebRTCStream[] = [] - constructor (peerConnection: RTCPeerConnection, readonly protocol = '/webrtc') { + constructor (peerConnection: RTCPeerConnection, readonly options?: MessageSizeOption, readonly protocol = WEBRTC_PROTOCOL_NAME) { this.peerConnection = peerConnection + this.options = options // store any datachannels opened before upgrade has been completed this.peerConnection.ondatachannel = ({ channel }) => { const stream = new WebRTCStream({ channel, + maxMsgSize: options?.maxMsgSize, stat: { direction: 'inbound', timeline: { open: 0 } @@ -30,11 +34,13 @@ export class DataChannelMuxerFactory implements StreamMuxerFactory { } } - createStreamMuxer (init?: StreamMuxerInit | undefined): StreamMuxer { - return new DataChannelMuxer(this.peerConnection, this.streamBuffer, this.protocol, init) + createStreamMuxer (init?: DataChannelMuxerInit): StreamMuxer { + return new DataChannelMuxer(this.peerConnection, this.streamBuffer, { ...init, ...this.options }, this.protocol) } } +export type DataChannelMuxerInit = StreamMuxerInit & MessageSizeOption + /** * A libp2p data channel stream muxer */ @@ -49,11 +55,6 @@ export class DataChannelMuxer implements StreamMuxer { */ streams: Stream[] = [] - /** - * Initialized stream muxer - */ - init?: StreamMuxerInit - /** * Close or abort all tracked streams and stop the muxer */ @@ -69,12 +70,7 @@ export class DataChannelMuxer implements StreamMuxer { */ sink: Sink> = nopSink - constructor (peerConnection: RTCPeerConnection, streams: Stream[], readonly protocol = '/webrtc', init?: StreamMuxerInit) { - /** - * Initialized stream muxer - */ - this.init = init - + constructor (peerConnection: RTCPeerConnection, streams: Stream[], readonly init: DataChannelMuxerInit, readonly protocol = WEBRTC_PROTOCOL_NAME) { /** * WebRTC Peer Connection */ @@ -89,6 +85,7 @@ export class DataChannelMuxer implements StreamMuxer { this.peerConnection.ondatachannel = ({ channel }) => { const stream = new WebRTCStream({ channel, + maxMsgSize: init?.maxMsgSize, stat: { direction: 'inbound', timeline: { @@ -122,6 +119,7 @@ export class DataChannelMuxer implements StreamMuxer { const channel = this.peerConnection.createDataChannel('') const stream = new WebRTCStream({ channel, + maxMsgSize: this.init?.maxMsgSize, stat: { direction: 'outbound', timeline: { diff --git a/src/peer_transport/handler.ts b/src/peer_transport/handler.ts index 659a847..3b25f29 100644 --- a/src/peer_transport/handler.ts +++ b/src/peer_transport/handler.ts @@ -14,14 +14,14 @@ const DEFAULT_TIMEOUT = 30 * 1000 const log = logger('libp2p:webrtc:peer') -export type IncomingStreamOpts = { rtcConfiguration?: RTCConfiguration } & IncomingStreamData +export type IncomingStreamOpts = { rtcConfiguration?: RTCConfiguration, maxMsgSize?: number } & IncomingStreamData -export async function handleIncomingStream ({ rtcConfiguration, stream: rawStream }: IncomingStreamOpts): Promise<[RTCPeerConnection, StreamMuxerFactory]> { +export async function handleIncomingStream ({ rtcConfiguration, stream: rawStream, maxMsgSize }: IncomingStreamOpts): Promise<[RTCPeerConnection, StreamMuxerFactory]> { const timeoutController = new TimeoutController(DEFAULT_TIMEOUT) const signal = timeoutController.signal const stream = pbStream(abortableDuplex(rawStream, timeoutController.signal)).pb(pb.Message) const pc = new RTCPeerConnection(rtcConfiguration) - const muxerFactory = new DataChannelMuxerFactory(pc) + const muxerFactory = new DataChannelMuxerFactory(pc, { maxMsgSize }) const connectedPromise: DeferredPromise = pDefer() const answerSentPromise: DeferredPromise = pDefer() @@ -85,14 +85,15 @@ export interface ConnectOptions { stream: Stream signal: AbortSignal rtcConfiguration?: RTCConfiguration + maxMsgSize?: number } -export async function initiateConnection ({ rtcConfiguration, signal, stream: rawStream }: ConnectOptions): Promise<[RTCPeerConnection, StreamMuxerFactory]> { +export async function initiateConnection ({ rtcConfiguration, signal, stream: rawStream, maxMsgSize }: ConnectOptions): Promise<[RTCPeerConnection, StreamMuxerFactory]> { const stream = pbStream(abortableDuplex(rawStream, signal)).pb(pb.Message) // setup peer connection const pc = new RTCPeerConnection(rtcConfiguration) - const muxerFactory = new DataChannelMuxerFactory(pc) + const muxerFactory = new DataChannelMuxerFactory(pc, { maxMsgSize }) const connectedPromise: DeferredPromise = pDefer() resolveOnConnected(pc, connectedPromise) diff --git a/src/peer_transport/transport.ts b/src/peer_transport/transport.ts index 98a132b..7c7f185 100644 --- a/src/peer_transport/transport.ts +++ b/src/peer_transport/transport.ts @@ -22,6 +22,7 @@ export const CODE = protocols('webrtc').code export interface WebRTCTransportInit { rtcConfiguration?: RTCConfiguration + maxMsgSize?: number } export interface WebRTCTransportComponents { @@ -125,7 +126,8 @@ export class WebRTCTransport implements Transport, Startable { const [pc, muxerFactory] = await initiateConnection({ stream: rawStream, rtcConfiguration: this.init.rtcConfiguration, - signal: options.signal + signal: options.signal, + maxMsgSize: this.init.maxMsgSize }) const webrtcMultiaddr = baseAddr.encapsulate(`${TRANSPORT}/p2p/${peerId.toString()}`) const result = await options.upgrader.upgradeOutbound( @@ -156,7 +158,8 @@ export class WebRTCTransport implements Transport, Startable { const [pc, muxerFactory] = await handleIncomingStream({ rtcConfiguration: this.init.rtcConfiguration, connection, - stream + stream, + maxMsgSize: this.init.maxMsgSize }) const remotePeerId = connection.remoteAddr.getPeerId() const webrtcMultiaddr = connection.remoteAddr.encapsulate(`${TRANSPORT}/p2p/${remotePeerId}`) diff --git a/src/stream.ts b/src/stream.ts index 2f61de9..b63cf12 100644 --- a/src/stream.ts +++ b/src/stream.ts @@ -48,6 +48,11 @@ interface StreamInitOpts { * Callback to invoke when the stream is closed. */ closeCb?: (stream: WebRTCStream) => void + + /** + * Max allowed message size to be sent + */ + maxMsgSize?: number } /* @@ -152,6 +157,74 @@ class StreamState { } } +interface MessageProcessor { + send: (bytes: Uint8ArrayList) => void + close?: () => void +} + +const createMessageProcessor = (channel: RTCDataChannel, maxMsgSize?: number): MessageProcessor => { + if (maxMsgSize != null) { + /** + * Don't allow channel.bufferedAmount to exceed maxMsgSize + */ + let sendPaused: boolean = false + let sendMessageQueue: Uint8Array[] = [] + const processMessageQueue = (): void => { + sendPaused = false + let message = sendMessageQueue.shift() + while (message != null) { + if (channel.bufferedAmount > maxMsgSize) { + sendPaused = true + sendMessageQueue.unshift(message) + + const listener = (): void => { + channel.removeEventListener('bufferedamountlow', listener) + processMessageQueue() + } + + channel.addEventListener('bufferedamountlow', listener) + return + } + + try { + channel.send(message) + message = sendMessageQueue.shift() + } catch (error: any) { + throw new globalThis.Error(`Error send message, reason: ${error.name} - ${error.message}`) + } + } + } + + return { + send: (sendbuf: Uint8ArrayList) => { + /** + * Don't allow individual messages to exceed maxMsgSize + */ + let from = 0 + let to = Math.min(sendbuf.length, maxMsgSize) + while (to !== from) { + sendMessageQueue.push(sendbuf.subarray(from, to)) + from = to + to = Math.min(to + maxMsgSize, sendbuf.length) + } + if (sendPaused) { + return + } + + processMessageQueue() + }, + + close: () => { + sendMessageQueue = [] + } + } + } else { + return { + send: (sendbuf: Uint8ArrayList) => { channel.send(sendbuf.subarray()) } + } + } +} + export class WebRTCStream implements Stream { /** * Unique identifier for a stream @@ -211,10 +284,14 @@ export class WebRTCStream implements Stream { */ closeCb?: (stream: WebRTCStream) => void + /** + * Processor for messages that allows throttling if necessary + */ + messageProcessor: MessageProcessor + constructor (opts: StreamInitOpts) { this.channel = opts.channel this.id = this.channel.label - this.stat = opts.stat switch (this.channel.readyState) { case 'open': @@ -254,6 +331,8 @@ export class WebRTCStream implements Stream { this.abort(err) } + this.messageProcessor = createMessageProcessor(this.channel, opts.maxMsgSize) + const self = this // reader pipe @@ -311,12 +390,11 @@ export class WebRTCStream implements Stream { const closeWrite = this._closeWriteIterable() for await (const buf of merge(closeWrite, src)) { if (this.streamState.isWriteClosed()) { + this.messageProcessor.close?.() return } const msgbuf = pb.Message.toBinary({ message: buf.subarray() }) - const sendbuf = lengthPrefixed.encode.single(msgbuf) - - this.channel.send(sendbuf.subarray()) + this.messageProcessor.send(lengthPrefixed.encode.single(msgbuf)) } } @@ -435,7 +513,7 @@ export class WebRTCStream implements Stream { try { log.trace('Sending flag: %s', flag.toString()) const msgbuf = pb.Message.toBinary({ flag }) - this.channel.send(lengthPrefixed.encode.single(msgbuf).subarray()) + this.messageProcessor.send(lengthPrefixed.encode.single(msgbuf)) } catch (err) { if (err instanceof Error) { log.error(`Exception while sending flag ${flag}: ${err.message}`) diff --git a/test/stream.spec.ts b/test/stream.spec.ts new file mode 100644 index 0000000..4bcc0b6 --- /dev/null +++ b/test/stream.spec.ts @@ -0,0 +1,30 @@ + +import * as underTest from '../src/stream' +import { expect } from 'aegir/chai' +import { Uint8ArrayList } from 'uint8arraylist' + +const setup = (cb: { send: (bytes: Uint8Array) => void }, maxMsgSize?: number): underTest.WebRTCStream => { + const datachannel = { + readyState: 'open', + send: cb.send + + } + return new underTest.WebRTCStream({ channel: datachannel as RTCDataChannel, stat: underTest.defaultStat('outbound'), maxMsgSize }) +} + +describe('MessageProcessor', () => { + it('handles unconstrained message', () => { + const sent: Uint8Array[] = [] + const webrtcStream = setup({ send: (bytes) => sent.push(bytes) }) + webrtcStream.messageProcessor.send(new Uint8ArrayList(new Uint8Array(1))) + expect(sent).to.deep.equals([new Uint8Array(1)]) + }) + + it('handles bounded by message size', () => { + const sent: Uint8Array[] = [] + const maxMsgSize = 1 + const webrtcStream = setup({ send: (bytes) => sent.push(bytes) }, maxMsgSize) + webrtcStream.messageProcessor.send(new Uint8ArrayList(new Uint8Array(2))) + expect(sent).to.deep.equals([new Uint8Array(1), new Uint8Array(1)]) + }) +})