Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: close early WebRTC streams properly #2200

Merged
merged 1 commit into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions packages/transport-webrtc/src/muxer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,22 @@ export interface DataChannelMuxerFactoryInit {
dataChannelOptions?: DataChannelOptions
}

interface BufferedStream {
stream: Stream
channel: RTCDataChannel
onEnd(err?: Error): void
}

let streamIndex = 0

export class DataChannelMuxerFactory implements StreamMuxerFactory {
public readonly protocol: string

/**
* WebRTC Peer Connection
*/
private readonly peerConnection: RTCPeerConnection
private streamBuffer: Stream[] = []
private bufferedStreams: BufferedStream[] = []
private readonly metrics?: CounterGroup
private readonly dataChannelOptions?: DataChannelOptions

Expand All @@ -51,15 +59,25 @@ export class DataChannelMuxerFactory implements StreamMuxerFactory {

// store any datachannels opened before upgrade has been completed
this.peerConnection.ondatachannel = ({ channel }) => {
// @ts-expect-error fields are set below
const bufferedStream: BufferedStream = {}

const stream = createStream({
channel,
direction: 'inbound',
onEnd: () => {
this.streamBuffer = this.streamBuffer.filter(s => s.id !== stream.id)
onEnd: (err) => {
bufferedStream.onEnd(err)
},
...this.dataChannelOptions
})
this.streamBuffer.push(stream)

bufferedStream.stream = stream
bufferedStream.channel = channel
bufferedStream.onEnd = () => {
this.bufferedStreams = this.bufferedStreams.filter(s => s.stream.id !== stream.id)
}

this.bufferedStreams.push(bufferedStream)
}
}

Expand All @@ -69,14 +87,14 @@ export class DataChannelMuxerFactory implements StreamMuxerFactory {
peerConnection: this.peerConnection,
dataChannelOptions: this.dataChannelOptions,
metrics: this.metrics,
streams: this.streamBuffer,
streams: this.bufferedStreams,
protocol: this.protocol
})
}
}

export interface DataChannelMuxerInit extends DataChannelMuxerFactoryInit, StreamMuxerInit {
streams: Stream[]
streams: BufferedStream[]
}

/**
Expand All @@ -94,7 +112,7 @@ export class DataChannelMuxer implements StreamMuxer {
private readonly metrics?: CounterGroup

constructor (readonly init: DataChannelMuxerInit) {
this.streams = init.streams
this.streams = init.streams.map(s => s.stream)
this.peerConnection = init.peerConnection
this.protocol = init.protocol ?? PROTOCOL
this.metrics = init.metrics
Expand All @@ -111,11 +129,7 @@ export class DataChannelMuxer implements StreamMuxer {
channel,
direction: 'inbound',
onEnd: () => {
log.trace('stream %s %s %s onEnd', stream.direction, stream.id, stream.protocol)
drainAndClose(channel, `inbound ${stream.id} ${stream.protocol}`, this.dataChannelOptions.drainTimeout)
this.streams = this.streams.filter(s => s.id !== stream.id)
this.metrics?.increment({ stream_end: true })
init?.onStreamEnd?.(stream)
this.#onStreamEnd(stream, channel)
},
...this.dataChannelOptions
})
Expand All @@ -125,10 +139,22 @@ export class DataChannelMuxer implements StreamMuxer {
init?.onIncomingStream?.(stream)
}

const onIncomingStream = init?.onIncomingStream
if (onIncomingStream != null) {
this.streams.forEach(s => { onIncomingStream(s) })
}
this.init.streams.forEach(bufferedStream => {
bufferedStream.onEnd = () => {
this.#onStreamEnd(bufferedStream.stream, bufferedStream.channel)
}

this.metrics?.increment({ incoming_stream: true })
this.init?.onIncomingStream?.(bufferedStream.stream)
})
}

#onStreamEnd (stream: Stream, channel: RTCDataChannel): void {
log.trace('stream %s %s %s onEnd', stream.direction, stream.id, stream.protocol)
drainAndClose(channel, `${stream.direction} ${stream.id} ${stream.protocol}`, this.dataChannelOptions.drainTimeout)
this.streams = this.streams.filter(s => s.id !== stream.id)
this.metrics?.increment({ stream_end: true })
this.init?.onStreamEnd?.(stream)
}

/**
Expand Down Expand Up @@ -164,17 +190,15 @@ export class DataChannelMuxer implements StreamMuxer {
sink: Sink<Source<Uint8Array | Uint8ArrayList>, Promise<void>> = nopSink

newStream (): Stream {
streamIndex++

// The spec says the label SHOULD be an empty string: https://github.com/libp2p/specs/blob/master/webrtc/README.md#rtcdatachannel-label
const channel = this.peerConnection.createDataChannel('')
const channel = this.peerConnection.createDataChannel(`stream-${streamIndex}`)
const stream = createStream({
channel,
direction: 'outbound',
onEnd: () => {
log.trace('stream %s %s %s onEnd', stream.direction, stream.id, stream.protocol)
drainAndClose(channel, `outbound ${stream.id} ${stream.protocol}`, this.dataChannelOptions.drainTimeout)
this.streams = this.streams.filter(s => s.id !== stream.id)
this.metrics?.increment({ stream_end: true })
this.init?.onStreamEnd?.(stream)
this.#onStreamEnd(stream, channel)
},
...this.dataChannelOptions
})
Expand Down
12 changes: 5 additions & 7 deletions packages/transport-webrtc/src/private-to-private/transport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,6 @@ export class WebRTCTransport implements Transport, Startable {
log.trace('dialing address: %a', ma)

const peerConnection = new RTCPeerConnection(this.init.rtcConfiguration)
const muxerFactory = new DataChannelMuxerFactory({
peerConnection,
dataChannelOptions: this.init.dataChannel
})

const { remoteAddress } = await initiateConnection({
peerConnection,
Expand All @@ -145,7 +141,10 @@ export class WebRTCTransport implements Transport, Startable {
const connection = await options.upgrader.upgradeOutbound(webRTCConn, {
skipProtection: true,
skipEncryption: true,
muxerFactory
muxerFactory: new DataChannelMuxerFactory({
peerConnection,
dataChannelOptions: this.init.dataChannel
})
})

// close the connection on shut down
Expand All @@ -157,7 +156,6 @@ export class WebRTCTransport implements Transport, Startable {
async _onProtocol ({ connection, stream }: IncomingStreamData): Promise<void> {
const signal = AbortSignal.timeout(this.init.inboundConnectionTimeout ?? INBOUND_CONNECTION_TIMEOUT)
const peerConnection = new RTCPeerConnection(this.init.rtcConfiguration)
const muxerFactory = new DataChannelMuxerFactory({ peerConnection, dataChannelOptions: this.init.dataChannel })

try {
const { remoteAddress } = await handleIncomingStream({
Expand All @@ -180,7 +178,7 @@ export class WebRTCTransport implements Transport, Startable {
await this.components.upgrader.upgradeInbound(webRTCConn, {
skipEncryption: true,
skipProtection: true,
muxerFactory
muxerFactory: new DataChannelMuxerFactory({ peerConnection, dataChannelOptions: this.init.dataChannel })
})

// close the stream if SDP messages have been exchanged successfully
Expand Down
1 change: 0 additions & 1 deletion packages/transport-webrtc/src/stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ export class WebRTCStream extends AbstractStream {
*/
private readonly receiveFinAck: DeferredPromise<void>
private readonly finAckTimeout: number
// private sentFinAck: boolean

constructor (init: WebRTCStreamInit) {
// override onEnd to send/receive FIN_ACK before closing the stream
Expand Down