Skip to content
This repository has been archived by the owner on Jun 19, 2023. It is now read-only.

Commit

Permalink
feat: restrict message sizes and buffered amount
Browse files Browse the repository at this point in the history
  • Loading branch information
marcus-pousette committed May 16, 2023
1 parent f0f4e7c commit a93a3e6
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 39 deletions.
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@
"multiformats": "^11.0.2",
"multihashes": "^4.0.3",
"p-defer": "^4.0.0",
"p-event": "^5.0.1",
"protons-runtime": "^5.0.0",
"uint8arraylist": "^2.4.3",
"uint8arrays": "^4.0.3"
Expand Down
9 changes: 9 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@ function webRTCDirect (): (components: WebRTCDirectTransportComponents) => Trans
return (components: WebRTCDirectTransportComponents) => new WebRTCDirectTransport(components)
}

/**
* @param {WebRTCTransportInit} init - WebRTC transport configuration
* @param {RTCConfiguration} init.rtcConfiguration - RTCConfiguration
* @param init.dataChannel - DataChannel configurations
* @param {number} init.dataChannel.maxMessageSize - Max message size that can be sent through the DataChannel. Larger messages will be chunked into smaller messages below this size (default 16kb)
* @param {number} init.dataChannel.maxBufferedAmount - Max buffered amount a DataChannel can have (default 16mb)
* @param {number} init.dataChannel.bufferedAmountLowEventTimeout - If max buffered amount is reached, this is the max time that is waited before the buffer is cleared (default 30 seconds)
* @returns
*/
function webRTC (init?: WebRTCTransportInit): (components: WebRTCTransportComponents) => Transport {
return (components: WebRTCTransportComponents) => new WebRTCTransport(components, init ?? {})
}
Expand Down
54 changes: 25 additions & 29 deletions src/muxer.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { WebRTCStream } from './stream.js'
import { type DataChannelOpts, WebRTCStream } from './stream.js'
import { nopSink, nopSource } from './util.js'
import type { Stream } from '@libp2p/interface-connection'
import type { CounterGroup } from '@libp2p/interface-metrics'
Expand All @@ -7,56 +7,55 @@ import type { Source, Sink } from 'it-stream-types'
import type { Uint8ArrayList } from 'uint8arraylist'

export interface DataChannelMuxerFactoryInit {
/**
* WebRTC Peer Connection
*/
peerConnection: RTCPeerConnection

/**
* Optional metrics for this data channel muxer
*/
metrics?: CounterGroup

/**
* Options data channel tasks
*/
dataChannelOptions?: Partial<DataChannelOpts>
}

export class DataChannelMuxerFactory implements StreamMuxerFactory {
/**
* WebRTC Peer Connection
*/
private readonly peerConnection: RTCPeerConnection
private streamBuffer: WebRTCStream[] = []
private readonly metrics?: CounterGroup

constructor (peerConnection: RTCPeerConnection, metrics?: CounterGroup, readonly protocol = '/webrtc') {
this.peerConnection = peerConnection
constructor (readonly init: DataChannelMuxerFactoryInit, readonly protocol = '/webrtc') {
// store any datachannels opened before upgrade has been completed
this.peerConnection.ondatachannel = ({ channel }) => {
this.init.peerConnection.ondatachannel = ({ channel }) => {
const stream = new WebRTCStream({
channel,
stat: {
direction: 'inbound',
timeline: { open: 0 }
},
dataChannelOptions: init.dataChannelOptions,
closeCb: (_stream) => {
this.streamBuffer = this.streamBuffer.filter(s => !_stream.eq(s))
}
})
this.streamBuffer.push(stream)
}
this.metrics = metrics
}

createStreamMuxer (init?: StreamMuxerInit | undefined): StreamMuxer {
return new DataChannelMuxer(this.peerConnection, this.streamBuffer, this.protocol, init, this.metrics)
return new DataChannelMuxer(this.init, this.streamBuffer, this.protocol, init)
}
}

/**
* A libp2p data channel stream muxer
*/
export class DataChannelMuxer implements StreamMuxer {
/**
* WebRTC Peer Connection
*/
private readonly peerConnection: RTCPeerConnection

/**
* Optional metrics for this data channel muxer
*/
private readonly metrics?: CounterGroup

/**
* Array of streams in the data channel
*/
Expand All @@ -82,24 +81,19 @@ export class DataChannelMuxer implements StreamMuxer {
*/
sink: Sink<Source<Uint8Array | Uint8ArrayList>, Promise<void>> = nopSink

constructor (peerConnection: RTCPeerConnection, streams: Stream[], readonly protocol: string = '/webrtc', init?: StreamMuxerInit, metrics?: CounterGroup) {
constructor (readonly dataChannelMuxer: DataChannelMuxerFactoryInit, streams: Stream[], readonly protocol: string = '/webrtc', init?: StreamMuxerInit) {
/**
* Initialized stream muxer
*/
this.init = init

/**
* WebRTC Peer Connection
*/
this.peerConnection = peerConnection

/**
* Fired when a data channel has been added to the connection has been
* added by the remote peer.
*
* {@link https://developer.mozilla.org/en-US/docs/Web/API/RTCPeerConnection/datachannel_event}
*/
this.peerConnection.ondatachannel = ({ channel }) => {
this.dataChannelMuxer.peerConnection.ondatachannel = ({ channel }) => {
const stream = new WebRTCStream({
channel,
stat: {
Expand All @@ -108,12 +102,13 @@ export class DataChannelMuxer implements StreamMuxer {
open: 0
}
},
dataChannelOptions: dataChannelMuxer.dataChannelOptions,
closeCb: this.wrapStreamEnd(init?.onIncomingStream)
})

this.streams.push(stream)
if ((init?.onIncomingStream) != null) {
this.metrics?.increment({ incoming_stream: true })
this.dataChannelMuxer.metrics?.increment({ incoming_stream: true })
init.onIncomingStream(stream)
}
}
Expand All @@ -133,9 +128,9 @@ export class DataChannelMuxer implements StreamMuxer {

newStream (): Stream {
// The spec says the label SHOULD be an empty string: https://github.com/libp2p/specs/blob/master/webrtc/README.md#rtcdatachannel-label
const channel = this.peerConnection.createDataChannel('')
const channel = this.dataChannelMuxer.peerConnection.createDataChannel('')
const closeCb = (stream: Stream): void => {
this.metrics?.increment({ stream_end: true })
this.dataChannelMuxer.metrics?.increment({ stream_end: true })
this.init?.onStreamEnd?.(stream)
}
const stream = new WebRTCStream({
Expand All @@ -146,10 +141,11 @@ export class DataChannelMuxer implements StreamMuxer {
open: 0
}
},
dataChannelOptions: this.dataChannelMuxer.dataChannelOptions,
closeCb: this.wrapStreamEnd(closeCb)
})
this.streams.push(stream)
this.metrics?.increment({ outgoing_stream: true })
this.dataChannelMuxer.metrics?.increment({ outgoing_stream: true })

return stream
}
Expand Down
13 changes: 7 additions & 6 deletions src/private-to-private/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import pDefer, { type DeferredPromise } from 'p-defer'
import { DataChannelMuxerFactory } from '../muxer.js'
import { Message } from './pb/message.js'
import { readCandidatesUntilConnected, resolveOnConnected } from './util.js'
import type { DataChannelOpts } from '../stream.js'
import type { Stream } from '@libp2p/interface-connection'
import type { IncomingStreamData } from '@libp2p/interface-registrar'
import type { StreamMuxerFactory } from '@libp2p/interface-stream-muxer'
Expand All @@ -13,14 +14,13 @@ const DEFAULT_TIMEOUT = 30 * 1000

const log = logger('libp2p:webrtc:peer')

export type IncomingStreamOpts = { rtcConfiguration?: RTCConfiguration } & IncomingStreamData
export type IncomingStreamOpts = { rtcConfiguration?: RTCConfiguration, dataChannelOptions?: Partial<DataChannelOpts> } & IncomingStreamData

export async function handleIncomingStream ({ rtcConfiguration, stream: rawStream }: IncomingStreamOpts): Promise<{ pc: RTCPeerConnection, muxerFactory: StreamMuxerFactory, remoteAddress: string }> {
export async function handleIncomingStream ({ rtcConfiguration, dataChannelOptions, stream: rawStream }: IncomingStreamOpts): Promise<{ pc: RTCPeerConnection, muxerFactory: StreamMuxerFactory, remoteAddress: string }> {
const signal = AbortSignal.timeout(DEFAULT_TIMEOUT)
const stream = pbStream(abortableDuplex(rawStream, signal)).pb(Message)
const pc = new RTCPeerConnection(rtcConfiguration)
const muxerFactory = new DataChannelMuxerFactory(pc)

const muxerFactory = new DataChannelMuxerFactory({ peerConnection: pc, dataChannelOptions })
const connectedPromise: DeferredPromise<void> = pDefer()
const answerSentPromise: DeferredPromise<void> = pDefer()

Expand Down Expand Up @@ -86,13 +86,14 @@ export interface ConnectOptions {
stream: Stream
signal: AbortSignal
rtcConfiguration?: RTCConfiguration
dataChannelOptions?: Partial<DataChannelOpts>
}

export async function initiateConnection ({ rtcConfiguration, signal, stream: rawStream }: ConnectOptions): Promise<{ pc: RTCPeerConnection, muxerFactory: StreamMuxerFactory, remoteAddress: string }> {
export async function initiateConnection ({ rtcConfiguration, dataChannelOptions, signal, stream: rawStream }: ConnectOptions): Promise<{ pc: RTCPeerConnection, muxerFactory: StreamMuxerFactory, remoteAddress: string }> {
const stream = pbStream(abortableDuplex(rawStream, signal)).pb(Message)
// setup peer connection
const pc = new RTCPeerConnection(rtcConfiguration)
const muxerFactory = new DataChannelMuxerFactory(pc)
const muxerFactory = new DataChannelMuxerFactory({ peerConnection: pc, dataChannelOptions })

const connectedPromise: DeferredPromise<void> = pDefer()
resolveOnConnected(pc, connectedPromise)
Expand Down
6 changes: 5 additions & 1 deletion src/private-to-private/transport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { codes } from '../error.js'
import { WebRTCMultiaddrConnection } from '../maconn.js'
import { initiateConnection, handleIncomingStream } from './handler.js'
import { WebRTCPeerListener } from './listener.js'
import type { DataChannelOpts } from '../stream.js'
import type { Connection } from '@libp2p/interface-connection'
import type { Libp2pEvents } from '@libp2p/interface-libp2p'
import type { PeerId } from '@libp2p/interface-peer-id'
Expand All @@ -23,6 +24,7 @@ const WEBRTC_CODE = protocols('webrtc').code

export interface WebRTCTransportInit {
rtcConfiguration?: RTCConfiguration
dataChannel?: Partial<DataChannelOpts>
}

export interface WebRTCTransportComponents {
Expand Down Expand Up @@ -126,6 +128,7 @@ export class WebRTCTransport implements Transport, Startable {
const { pc, muxerFactory, remoteAddress } = await initiateConnection({
stream: signalingStream,
rtcConfiguration: this.init.rtcConfiguration,
dataChannelOptions: this.init.dataChannel,
signal: options.signal
})

Expand Down Expand Up @@ -157,7 +160,8 @@ export class WebRTCTransport implements Transport, Startable {
const { pc, muxerFactory, remoteAddress } = await handleIncomingStream({
rtcConfiguration: this.init.rtcConfiguration,
connection,
stream
stream,
dataChannelOptions: this.init.dataChannel
})

await this.components.upgrader.upgradeInbound(new WebRTCMultiaddrConnection({
Expand Down
2 changes: 1 addition & 1 deletion src/private-to-public/transport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ export class WebRTCDirectTransport implements Transport {
// Track opened peer connection
this.metrics?.dialerEvents.increment({ peer_connection: true })

const muxerFactory = new DataChannelMuxerFactory(peerConnection, this.metrics?.dialerEvents)
const muxerFactory = new DataChannelMuxerFactory({ peerConnection, metrics: this.metrics?.dialerEvents })

// For outbound connections, the remote is expected to start the noise handshake.
// Therefore, we need to secure an inbound noise connection from the remote.
Expand Down
52 changes: 50 additions & 2 deletions src/stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import merge from 'it-merge'
import { pipe } from 'it-pipe'
import { pushable } from 'it-pushable'
import defer, { type DeferredPromise } from 'p-defer'
import { pEvent } from 'p-event'
import { Uint8ArrayList } from 'uint8arraylist'
import { Message } from './pb/message.js'
import type { Stream, StreamStat, Direction } from '@libp2p/interface-connection'
Expand All @@ -24,6 +25,12 @@ export function defaultStat (dir: Direction): StreamStat {
}
}

export interface DataChannelOpts {
maxMessageSize: number
maxBufferedAmount: number
bufferedAmountLowEventTimeout: number
}

interface StreamInitOpts {
/**
* The network channel used for bidirectional peer-to-peer transfers of
Expand All @@ -47,6 +54,11 @@ interface StreamInitOpts {
* Callback to invoke when the stream is closed.
*/
closeCb?: (stream: WebRTCStream) => void

/**
* Data channel options
*/
dataChannelOptions?: Partial<DataChannelOpts>
}

/*
Expand Down Expand Up @@ -151,6 +163,15 @@ class StreamState {
}
}

// Max message size that can be sent to the DataChannel
const MAX_MESSAGE_SIZE = 16 * 1024

// How much can be buffered to the DataChannel at once
const MAX_BUFFERED_AMOUNT = 16 * 1024 * 1024

// How long time we wait for the 'bufferedamountlow' event to be emitted
const BUFFERED_AMOUNT_LOW_TIMEOUT = 30 * 1000

export class WebRTCStream implements Stream {
/**
* Unique identifier for a stream
Expand All @@ -177,6 +198,12 @@ export class WebRTCStream implements Stream {
*/
streamState = new StreamState()

/**
* DataChannel contraints
*/

dataChannelOptions: DataChannelOpts

/**
* Read unwrapped protobuf data from the underlying datachannel.
* _src is exposed to the user via the `source` getter to .
Expand Down Expand Up @@ -214,8 +241,14 @@ export class WebRTCStream implements Stream {
this.channel = opts.channel
this.channel.binaryType = 'arraybuffer'
this.id = this.channel.label

this.stat = opts.stat
this.dataChannelOptions = {
bufferedAmountLowEventTimeout: opts.dataChannelOptions?.bufferedAmountLowEventTimeout ?? BUFFERED_AMOUNT_LOW_TIMEOUT,
maxBufferedAmount: opts.dataChannelOptions?.maxBufferedAmount ?? MAX_BUFFERED_AMOUNT,
maxMessageSize: opts.dataChannelOptions?.maxMessageSize ?? MAX_MESSAGE_SIZE
}
this.closeCb = opts.closeCb

switch (this.channel.readyState) {
case 'open':
this.opened.resolve()
Expand Down Expand Up @@ -313,10 +346,25 @@ export class WebRTCStream implements Stream {
if (this.streamState.isWriteClosed()) {
return
}

if (this.channel.bufferedAmount > this.dataChannelOptions.maxBufferedAmount) {
await pEvent(this.channel, 'bufferedamountlow', { timeout: this.dataChannelOptions.bufferedAmountLowEventTimeout }).catch((e) => {
this.close()
throw new Error('Timed out waiting for DataChannel buffer to clear')
})
}

const msgbuf = Message.encode({ message: buf.subarray() })
const sendbuf = lengthPrefixed.encode.single(msgbuf)

this.channel.send(sendbuf.subarray())
while (sendbuf.length > 0) {
if (sendbuf.length <= this.dataChannelOptions.maxMessageSize) {
this.channel.send(sendbuf.subarray())
break
}
this.channel.send(sendbuf.subarray(0, this.dataChannelOptions.maxMessageSize))
sendbuf.consume(this.dataChannelOptions.maxMessageSize)
}
}
}

Expand Down
Loading

0 comments on commit a93a3e6

Please sign in to comment.