Skip to content
This repository has been archived by the owner on Jun 19, 2023. It is now read-only.

Commit

Permalink
add maxMsgSize option
Browse files Browse the repository at this point in the history
  • Loading branch information
marcus-pousette committed Apr 28, 2023
1 parent 2f8faa3 commit f7b8a12
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 26 deletions.
26 changes: 12 additions & 14 deletions src/muxer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,23 @@ import type { Source, Sink } from 'it-stream-types'
import { WebRTCStream } from './stream.js'
import { nopSink, nopSource } from './util.js'

const WEBRTC_PROTOCOL_NAME = '/webrtc'
export interface MessageSizeOption { maxMsgSize?: number }
export class DataChannelMuxerFactory implements StreamMuxerFactory {
/**
* WebRTC Peer Connection
*/
private readonly peerConnection: RTCPeerConnection
private streamBuffer: WebRTCStream[] = []

constructor (peerConnection: RTCPeerConnection, readonly protocol = '/webrtc') {
constructor (peerConnection: RTCPeerConnection, readonly options?: MessageSizeOption, readonly protocol = WEBRTC_PROTOCOL_NAME) {
this.peerConnection = peerConnection
this.options = options
// store any datachannels opened before upgrade has been completed
this.peerConnection.ondatachannel = ({ channel }) => {
const stream = new WebRTCStream({
channel,
maxMsgSize: options?.maxMsgSize,
stat: {
direction: 'inbound',
timeline: { open: 0 }
Expand All @@ -30,11 +34,13 @@ export class DataChannelMuxerFactory implements StreamMuxerFactory {
}
}

createStreamMuxer (init?: StreamMuxerInit | undefined): StreamMuxer {
return new DataChannelMuxer(this.peerConnection, this.streamBuffer, this.protocol, init)
createStreamMuxer (init?: DataChannelMuxerInit): StreamMuxer {
return new DataChannelMuxer(this.peerConnection, this.streamBuffer, { ...init, ...this.options }, this.protocol)
}
}

export type DataChannelMuxerInit = StreamMuxerInit & MessageSizeOption

/**
* A libp2p data channel stream muxer
*/
Expand All @@ -49,11 +55,6 @@ export class DataChannelMuxer implements StreamMuxer {
*/
streams: Stream[] = []

/**
* Initialized stream muxer
*/
init?: StreamMuxerInit

/**
* Close or abort all tracked streams and stop the muxer
*/
Expand All @@ -69,12 +70,7 @@ export class DataChannelMuxer implements StreamMuxer {
*/
sink: Sink<Uint8Array, Promise<void>> = nopSink

constructor (peerConnection: RTCPeerConnection, streams: Stream[], readonly protocol = '/webrtc', init?: StreamMuxerInit) {
/**
* Initialized stream muxer
*/
this.init = init

constructor (peerConnection: RTCPeerConnection, streams: Stream[], readonly init: DataChannelMuxerInit, readonly protocol = WEBRTC_PROTOCOL_NAME) {
/**
* WebRTC Peer Connection
*/
Expand All @@ -89,6 +85,7 @@ export class DataChannelMuxer implements StreamMuxer {
this.peerConnection.ondatachannel = ({ channel }) => {
const stream = new WebRTCStream({
channel,
maxMsgSize: init?.maxMsgSize,
stat: {
direction: 'inbound',
timeline: {
Expand Down Expand Up @@ -122,6 +119,7 @@ export class DataChannelMuxer implements StreamMuxer {
const channel = this.peerConnection.createDataChannel('')
const stream = new WebRTCStream({
channel,
maxMsgSize: this.init?.maxMsgSize,
stat: {
direction: 'outbound',
timeline: {
Expand Down
11 changes: 6 additions & 5 deletions src/peer_transport/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ const DEFAULT_TIMEOUT = 30 * 1000

const log = logger('libp2p:webrtc:peer')

export type IncomingStreamOpts = { rtcConfiguration?: RTCConfiguration } & IncomingStreamData
export type IncomingStreamOpts = { rtcConfiguration?: RTCConfiguration, maxMsgSize?: number } & IncomingStreamData

export async function handleIncomingStream ({ rtcConfiguration, stream: rawStream }: IncomingStreamOpts): Promise<[RTCPeerConnection, StreamMuxerFactory]> {
export async function handleIncomingStream ({ rtcConfiguration, stream: rawStream, maxMsgSize }: IncomingStreamOpts): Promise<[RTCPeerConnection, StreamMuxerFactory]> {
const timeoutController = new TimeoutController(DEFAULT_TIMEOUT)
const signal = timeoutController.signal
const stream = pbStream(abortableDuplex(rawStream, timeoutController.signal)).pb(pb.Message)
const pc = new RTCPeerConnection(rtcConfiguration)
const muxerFactory = new DataChannelMuxerFactory(pc)
const muxerFactory = new DataChannelMuxerFactory(pc, { maxMsgSize })

const connectedPromise: DeferredPromise<void> = pDefer()
const answerSentPromise: DeferredPromise<void> = pDefer()
Expand Down Expand Up @@ -85,14 +85,15 @@ export interface ConnectOptions {
stream: Stream
signal: AbortSignal
rtcConfiguration?: RTCConfiguration
maxMsgSize?: number
}

export async function initiateConnection ({ rtcConfiguration, signal, stream: rawStream }: ConnectOptions): Promise<[RTCPeerConnection, StreamMuxerFactory]> {
export async function initiateConnection ({ rtcConfiguration, signal, stream: rawStream, maxMsgSize }: ConnectOptions): Promise<[RTCPeerConnection, StreamMuxerFactory]> {
const stream = pbStream(abortableDuplex(rawStream, signal)).pb(pb.Message)

// setup peer connection
const pc = new RTCPeerConnection(rtcConfiguration)
const muxerFactory = new DataChannelMuxerFactory(pc)
const muxerFactory = new DataChannelMuxerFactory(pc, { maxMsgSize })

const connectedPromise: DeferredPromise<void> = pDefer()
resolveOnConnected(pc, connectedPromise)
Expand Down
7 changes: 5 additions & 2 deletions src/peer_transport/transport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ export const CODE = protocols('webrtc').code

export interface WebRTCTransportInit {
rtcConfiguration?: RTCConfiguration
maxMsgSize?: number
}

export interface WebRTCTransportComponents {
Expand Down Expand Up @@ -125,7 +126,8 @@ export class WebRTCTransport implements Transport, Startable {
const [pc, muxerFactory] = await initiateConnection({
stream: rawStream,
rtcConfiguration: this.init.rtcConfiguration,
signal: options.signal
signal: options.signal,
maxMsgSize: this.init.maxMsgSize
})
const webrtcMultiaddr = baseAddr.encapsulate(`${TRANSPORT}/p2p/${peerId.toString()}`)
const result = await options.upgrader.upgradeOutbound(
Expand Down Expand Up @@ -156,7 +158,8 @@ export class WebRTCTransport implements Transport, Startable {
const [pc, muxerFactory] = await handleIncomingStream({
rtcConfiguration: this.init.rtcConfiguration,
connection,
stream
stream,
maxMsgSize: this.init.maxMsgSize
})
const remotePeerId = connection.remoteAddr.getPeerId()
const webrtcMultiaddr = connection.remoteAddr.encapsulate(`${TRANSPORT}/p2p/${remotePeerId}`)
Expand Down
88 changes: 83 additions & 5 deletions src/stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ interface StreamInitOpts {
* Callback to invoke when the stream is closed.
*/
closeCb?: (stream: WebRTCStream) => void

/**
* Max allowed message size to be sent
*/
maxMsgSize?: number
}

/*
Expand Down Expand Up @@ -152,6 +157,74 @@ class StreamState {
}
}

interface MessageProcessor {
send: (bytes: Uint8ArrayList) => void
close?: () => void
}

const createMessageProcessor = (channel: RTCDataChannel, maxMsgSize?: number): MessageProcessor => {
if (maxMsgSize != null) {
/**
* Don't allow channel.bufferedAmount to exceed maxMsgSize
*/
let sendPaused: boolean = false
let sendMessageQueue: Uint8Array[] = []
const processMessageQueue = (): void => {
sendPaused = false
let message = sendMessageQueue.shift()
while (message != null) {
if (channel.bufferedAmount > maxMsgSize) {
sendPaused = true
sendMessageQueue.unshift(message)

const listener = (): void => {
channel.removeEventListener('bufferedamountlow', listener)
processMessageQueue()
}

channel.addEventListener('bufferedamountlow', listener)
return
}

try {
channel.send(message)
message = sendMessageQueue.shift()
} catch (error: any) {
throw new globalThis.Error(`Error send message, reason: ${error.name} - ${error.message}`)
}
}
}

return {
send: (sendbuf: Uint8ArrayList) => {
/**
* Don't allow individual messages to exceed maxMsgSize
*/
let from = 0
let to = Math.min(sendbuf.length, maxMsgSize)
while (to !== from) {
sendMessageQueue.push(sendbuf.subarray(from, to))
from = to
to = Math.min(to + maxMsgSize, sendbuf.length)
}
if (sendPaused) {
return
}

processMessageQueue()
},

close: () => {
sendMessageQueue = []
}
}
} else {
return {
send: (sendbuf: Uint8ArrayList) => { channel.send(sendbuf.subarray()) }
}
}
}

export class WebRTCStream implements Stream {
/**
* Unique identifier for a stream
Expand Down Expand Up @@ -211,10 +284,14 @@ export class WebRTCStream implements Stream {
*/
closeCb?: (stream: WebRTCStream) => void

/**
* Processor for messages that allows throttling if necessary
*/
messageProcessor: MessageProcessor

constructor (opts: StreamInitOpts) {
this.channel = opts.channel
this.id = this.channel.label

this.stat = opts.stat
switch (this.channel.readyState) {
case 'open':
Expand Down Expand Up @@ -254,6 +331,8 @@ export class WebRTCStream implements Stream {
this.abort(err)
}

this.messageProcessor = createMessageProcessor(this.channel, opts.maxMsgSize)

const self = this

// reader pipe
Expand Down Expand Up @@ -311,12 +390,11 @@ export class WebRTCStream implements Stream {
const closeWrite = this._closeWriteIterable()
for await (const buf of merge(closeWrite, src)) {
if (this.streamState.isWriteClosed()) {
this.messageProcessor.close?.()
return
}
const msgbuf = pb.Message.toBinary({ message: buf.subarray() })
const sendbuf = lengthPrefixed.encode.single(msgbuf)

this.channel.send(sendbuf.subarray())
this.messageProcessor.send(lengthPrefixed.encode.single(msgbuf))
}
}

Expand Down Expand Up @@ -435,7 +513,7 @@ export class WebRTCStream implements Stream {
try {
log.trace('Sending flag: %s', flag.toString())
const msgbuf = pb.Message.toBinary({ flag })
this.channel.send(lengthPrefixed.encode.single(msgbuf).subarray())
this.messageProcessor.send(lengthPrefixed.encode.single(msgbuf))
} catch (err) {
if (err instanceof Error) {
log.error(`Exception while sending flag ${flag}: ${err.message}`)
Expand Down
30 changes: 30 additions & 0 deletions test/stream.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@

import * as underTest from '../src/stream'
import { expect } from 'aegir/chai'
import { Uint8ArrayList } from 'uint8arraylist'

const setup = (cb: { send: (bytes: Uint8Array) => void }, maxMsgSize?: number): underTest.WebRTCStream => {
const datachannel = {
readyState: 'open',
send: cb.send

}
return new underTest.WebRTCStream({ channel: datachannel as RTCDataChannel, stat: underTest.defaultStat('outbound'), maxMsgSize })
}

describe('MessageProcessor', () => {
it('handles unconstrained message', () => {
const sent: Uint8Array[] = []
const webrtcStream = setup({ send: (bytes) => sent.push(bytes) })
webrtcStream.messageProcessor.send(new Uint8ArrayList(new Uint8Array(1)))
expect(sent).to.deep.equals([new Uint8Array(1)])
})

it('handles bounded by message size', () => {
const sent: Uint8Array[] = []
const maxMsgSize = 1
const webrtcStream = setup({ send: (bytes) => sent.push(bytes) }, maxMsgSize)
webrtcStream.messageProcessor.send(new Uint8ArrayList(new Uint8Array(2)))
expect(sent).to.deep.equals([new Uint8Array(1), new Uint8Array(1)])
})
})

0 comments on commit f7b8a12

Please sign in to comment.