From 13df7aeec79d8060178aa50a5bf40e4190738386 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antoine=20Bach=C3=A9?= Date: Sat, 26 Jun 2021 14:04:12 +0200 Subject: [PATCH] Implement DTLS/SRTP/SCTP restart Fixes #1636 --- datachannel.go | 20 ++- dtlstransport.go | 25 ++++ peerconnection.go | 56 ++++++- peerconnection_media_test.go | 282 +++++++++++++++++++++++++++++++++++ sctptransport.go | 65 +++++++- 5 files changed, 440 insertions(+), 8 deletions(-) diff --git a/datachannel.go b/datachannel.go index 1d9489546c9..4bcd9438312 100644 --- a/datachannel.go +++ b/datachannel.go @@ -69,7 +69,7 @@ func (api *API) NewDataChannel(transport *SCTPTransport, params *DataChannelPara return nil, err } - err = d.open(transport) + err = d.open(transport, false) if err != nil { return nil, err } @@ -103,14 +103,14 @@ func (api *API) newDataChannel(params *DataChannelParameters, log logging.Levele } // open opens the datachannel over the sctp transport -func (d *DataChannel) open(sctpTransport *SCTPTransport) error { +func (d *DataChannel) open(sctpTransport *SCTPTransport, restart bool) error { association := sctpTransport.association() if association == nil { return errSCTPNotEstablished } d.mu.Lock() - if d.sctpTransport != nil { // already open + if d.sctpTransport != nil && !restart { // already open & not restarting d.mu.Unlock() return nil } @@ -164,6 +164,11 @@ func (d *DataChannel) open(sctpTransport *SCTPTransport) error { return err } + // If restarting, the `Open` event should be triggered again, once. + if restart { + d.openHandlerOnce = sync.Once{} + } + // bufferedAmountLowThreshold and onBufferedAmountLow might be set earlier dc.SetBufferedAmountLowThreshold(d.bufferedAmountLowThreshold) dc.OnBufferedAmountLow(d.onBufferedAmountLow) @@ -309,11 +314,18 @@ func (d *DataChannel) readLoop() { n, isString, err := d.dataChannel.ReadDataChannel(buffer) if err != nil { rlBufPool.Put(buffer) // nolint:staticcheck + + previousState := d.ReadyState() d.setReadyState(DataChannelStateClosed) + if err != io.EOF { d.onError(err) } - d.onClose() + + // https://www.w3.org/TR/webrtc/#announcing-a-data-channel-as-closed + if previousState != DataChannelStateClosed { + d.onClose() + } return } diff --git a/dtlstransport.go b/dtlstransport.go index 9b69b2d700f..a1a7000faf9 100644 --- a/dtlstransport.go +++ b/dtlstransport.go @@ -213,6 +213,31 @@ func (t *DTLSTransport) startSRTP() error { return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err) } + isAlreadyRunning := func() bool { + select { + case <-t.srtpReady: + return true + default: + return false + } + }() + + if isAlreadyRunning { + if sess, ok := t.srtpSession.Load().(*srtp.SessionSRTP); ok { + if updateErr := sess.UpdateContext(srtpConfig); updateErr != nil { + return updateErr + } + } + + if sess, ok := t.srtcpSession.Load().(*srtp.SessionSRTCP); ok { + if updateErr := sess.UpdateContext(srtpConfig); updateErr != nil { + return updateErr + } + } + + return nil + } + srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig) if err != nil { return fmt.Errorf("%w: %v", errFailedToStartSRTP, err) diff --git a/peerconnection.go b/peerconnection.go index fbc334e9cca..b3e5c6642ba 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -1108,7 +1108,59 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error { pc.ops.Enqueue(func() { pc.startRTP(true, &desc, currentTransceivers) }) + } else if pc.dtlsTransport.State() != DTLSTransportStateNew { + fingerprint, fingerprintHash, fErr := extractFingerprint(desc.parsed) + if fErr != nil { + return fErr + } + + fingerPrintDidChange := true + + for _, fp := range pc.dtlsTransport.remoteParameters.Fingerprints { + if fingerprint == fp.Value && fingerprintHash == fp.Algorithm { + fingerPrintDidChange = false + break + } + } + + if fingerPrintDidChange { + pc.ops.Enqueue(func() { + // SCTP uses DTLS, so prevent any use, by locking, while + // DTLS is restarting. + pc.sctpTransport.lock.Lock() + defer pc.sctpTransport.lock.Unlock() + + if dErr := pc.dtlsTransport.Stop(); dErr != nil { + pc.log.Warnf("Failed to stop DTLS: %s", dErr) + } + + // libwebrtc switches the connection back to `new`. + pc.dtlsTransport.lock.Lock() + pc.dtlsTransport.onStateChange(DTLSTransportStateNew) + pc.dtlsTransport.lock.Unlock() + + // Restart the dtls transport with updated fingerprints + err = pc.dtlsTransport.Start(DTLSParameters{ + Role: dtlsRoleFromRemoteSDP(desc.parsed), + Fingerprints: []DTLSFingerprint{{Algorithm: fingerprintHash, Value: fingerprint}}, + }) + pc.updateConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State()) + if err != nil { + pc.log.Warnf("Failed to restart DTLS: %s", err) + return + } + + // If SCTP was enabled, restart it with the new DTLS transport. + if pc.sctpTransport.isStarted { + if dErr := pc.sctpTransport.restart(pc.dtlsTransport.conn); dErr != nil { + pc.log.Warnf("Failed to restart SCTP: %s", dErr) + return + } + } + }) + } } + return nil } @@ -1317,7 +1369,7 @@ func (pc *PeerConnection) startSCTP() { var openedDCCount uint32 for _, d := range dataChannels { if d.ReadyState() == DataChannelStateConnecting { - err := d.open(pc.sctpTransport) + err := d.open(pc.sctpTransport, false) if err != nil { pc.log.Warnf("failed to open data channel: %s", err) continue @@ -1775,7 +1827,7 @@ func (pc *PeerConnection) CreateDataChannel(label string, options *DataChannelIn // If SCTP already connected open all the channels if pc.sctpTransport.State() == SCTPTransportStateConnected { - if err = d.open(pc.sctpTransport); err != nil { + if err = d.open(pc.sctpTransport, false); err != nil { return nil, err } } diff --git a/peerconnection_media_test.go b/peerconnection_media_test.go index 0ec4eeb059e..1dd3b048079 100644 --- a/peerconnection_media_test.go +++ b/peerconnection_media_test.go @@ -1052,3 +1052,285 @@ func TestPeerConnection_RaceReplaceTrack(t *testing.T) { assert.NoError(t, pc.Close()) } + +// Issue #1636 +func TestPeerConnection_DTLS_Restart_MediaAndDataChannel(t *testing.T) { + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() + + makeClient := func() (*PeerConnection, *TrackLocalStaticSample, *DataChannel, <-chan string) { + pc, cliErr := NewPeerConnection(Configuration{}) + assert.NoError(t, cliErr) + + track, cliErr := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeOpus}, "audio", "test-client") + assert.NoError(t, cliErr) + + dc, cliErr := pc.CreateDataChannel("data", nil) + assert.NoError(t, cliErr) + + dataChannelMessages := make(chan string, 1) + pc.OnDataChannel(func(channel *DataChannel) { + channel.OnMessage(func(msg DataChannelMessage) { + fmt.Printf("received %v\n", string(msg.Data)) + dataChannelMessages <- string(msg.Data) + }) + }) + + _, cliErr = pc.AddTrack(track) + assert.NoError(t, cliErr) + + return pc, track, dc, dataChannelMessages + } + + pcA1, _, _, msgsA1 := makeClient() + defer func() { _ = pcA1.Close() }() + pcA2, _, _, msgsA2 := makeClient() + defer func() { _ = pcA2.Close() }() + pcB, outputTrackB, dataChannelB, _ := makeClient() + defer func() { _ = pcB.Close() }() + + triggerMedia := func() { + assert.NoError(t, outputTrackB.WriteSample(media.Sample{ + Data: []byte{0xbb}, + Timestamp: time.Now(), + Duration: 20 * time.Millisecond, + })) + // Somehow, if we only send 1 packet, the OnTrack event is never fired. + time.Sleep(20 * time.Millisecond) + assert.NoError(t, outputTrackB.WriteSample(media.Sample{ + Data: []byte{0xbb}, + Timestamp: time.Now(), + Duration: 20 * time.Millisecond, + })) + } + + gatherCompletePromiseA1 := GatheringCompletePromise(pcA1) + offerA1, err := pcA1.CreateOffer(nil) + assert.NoError(t, err) + assert.NoError(t, pcA1.SetLocalDescription(offerA1)) + <-gatherCompletePromiseA1 + + assert.NoError(t, pcB.SetRemoteDescription(*pcA1.LocalDescription())) + + gatherCompletePromiseB := GatheringCompletePromise(pcB) + answerB, err := pcB.CreateAnswer(nil) + assert.NoError(t, err) + assert.NoError(t, pcB.SetLocalDescription(answerB)) + <-gatherCompletePromiseB + + pcA1Connected := make(chan struct{}, 1) + pcA1.OnICEConnectionStateChange(func(s ICEConnectionState) { + if s == ICEConnectionStateConnected { + pcA1Connected <- struct{}{} + } + }) + + incomingTracksA1 := make(chan *TrackRemote, 1) + pcA1.OnTrack(func(remote *TrackRemote, receiver *RTPReceiver) { + incomingTracksA1 <- remote + }) + + assert.NoError(t, pcA1.SetRemoteDescription(answerB)) + + <-pcA1Connected + + triggerMedia() + assert.NoError(t, dataChannelB.SendText("HelloWorld")) + + incomingTrackA1 := <-incomingTracksA1 + + pkt, _, err := incomingTrackA1.ReadRTP() + assert.NotNil(t, pkt) + assert.NoError(t, err) + + <-msgsA1 + assert.Empty(t, msgsA1) + + // ClientA2 connects to ClientB + + gatherCompletePromiseA2 := GatheringCompletePromise(pcA2) + // We can't do an ICE Restart here, since it's a different PeerConnection + offerA2, err := pcA2.CreateOffer(nil) + assert.NoError(t, err) + assert.NoError(t, pcA2.SetLocalDescription(offerA2)) + <-gatherCompletePromiseA2 + + assert.NoError(t, pcB.SetRemoteDescription(*pcA2.LocalDescription())) + + gatherCompletePromiseB = GatheringCompletePromise(pcB) + answerB, err = pcB.CreateAnswer(nil) + assert.NoError(t, err) + assert.NoError(t, pcB.SetLocalDescription(answerB)) + <-gatherCompletePromiseB + + pcA2Connected := make(chan struct{}, 1) + pcA2.OnICEConnectionStateChange(func(s ICEConnectionState) { + if s == ICEConnectionStateConnected { + pcA2Connected <- struct{}{} + } + }) + + incomingTracksA2 := make(chan *TrackRemote, 1) + pcA2.OnTrack(func(remote *TrackRemote, receiver *RTPReceiver) { + incomingTracksA2 <- remote + }) + + assert.NoError(t, pcA2.SetRemoteDescription(answerB)) + + // Wait for connection + <-pcA2Connected + + triggerMedia() + assert.NoError(t, dataChannelB.SendText("HelloWorld")) + + // Make sure A1 doesn't receive anything + assert.Empty(t, incomingTracksA1) + assert.NoError(t, incomingTrackA1.SetReadDeadline(time.Now().Add(100*time.Millisecond))) + pkt, _, err = incomingTrackA1.ReadRTP() + assert.Nil(t, pkt) + assert.Error(t, err) + + // Needed in case of `-race`?? + triggerMedia() + + // Make sure A2 receives media + incomingTrackA2 := <-incomingTracksA2 + pkt, _, err = incomingTrackA2.ReadRTP() + assert.NotNil(t, pkt) + assert.NoError(t, err) + + <-msgsA2 + assert.Empty(t, msgsA2) + assert.Empty(t, msgsA1) +} + +// Issue #1636 +func TestPeerConnection_DTLS_Restart_MediaOnly(t *testing.T) { + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() + + makeClient := func() (*PeerConnection, *TrackLocalStaticSample) { + pc, cliErr := NewPeerConnection(Configuration{}) + assert.NoError(t, cliErr) + + track, cliErr := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeOpus}, "audio", "test-client") + assert.NoError(t, cliErr) + + _, cliErr = pc.AddTrack(track) + assert.NoError(t, cliErr) + + return pc, track + } + + pcA1, _ := makeClient() + defer func() { _ = pcA1.Close() }() + pcA2, _ := makeClient() + defer func() { _ = pcA2.Close() }() + pcB, outputTrackB := makeClient() + defer func() { _ = pcB.Close() }() + + triggerMedia := func() { + assert.NoError(t, outputTrackB.WriteSample(media.Sample{ + Data: []byte{0xbb}, + Timestamp: time.Now(), + Duration: 20 * time.Millisecond, + })) + // Somehow, if we only send 1 packet, the OnTrack event is never fired. + time.Sleep(20 * time.Millisecond) + assert.NoError(t, outputTrackB.WriteSample(media.Sample{ + Data: []byte{0xbb}, + Timestamp: time.Now(), + Duration: 20 * time.Millisecond, + })) + } + + gatherCompletePromiseA1 := GatheringCompletePromise(pcA1) + offerA1, err := pcA1.CreateOffer(nil) + assert.NoError(t, err) + assert.NoError(t, pcA1.SetLocalDescription(offerA1)) + <-gatherCompletePromiseA1 + + assert.NoError(t, pcB.SetRemoteDescription(*pcA1.LocalDescription())) + + gatherCompletePromiseB := GatheringCompletePromise(pcB) + answerB, err := pcB.CreateAnswer(nil) + assert.NoError(t, err) + assert.NoError(t, pcB.SetLocalDescription(answerB)) + <-gatherCompletePromiseB + + pcA1Connected := make(chan struct{}, 1) + pcA1.OnICEConnectionStateChange(func(s ICEConnectionState) { + if s == ICEConnectionStateConnected { + pcA1Connected <- struct{}{} + } + }) + + incomingTracksA1 := make(chan *TrackRemote, 1) + pcA1.OnTrack(func(remote *TrackRemote, receiver *RTPReceiver) { + incomingTracksA1 <- remote + }) + + assert.NoError(t, pcA1.SetRemoteDescription(answerB)) + + <-pcA1Connected + + triggerMedia() + + incomingTrackA1 := <-incomingTracksA1 + + pkt, _, err := incomingTrackA1.ReadRTP() + assert.NotNil(t, pkt) + assert.NoError(t, err) + + // ClientA2 connects to ClientB + + gatherCompletePromiseA2 := GatheringCompletePromise(pcA2) + // We can't do an ICE Restart here, since it's a different PeerConnection + offerA2, err := pcA2.CreateOffer(nil) + assert.NoError(t, err) + assert.NoError(t, pcA2.SetLocalDescription(offerA2)) + <-gatherCompletePromiseA2 + + assert.NoError(t, pcB.SetRemoteDescription(*pcA2.LocalDescription())) + + gatherCompletePromiseB = GatheringCompletePromise(pcB) + answerB, err = pcB.CreateAnswer(nil) + assert.NoError(t, err) + assert.NoError(t, pcB.SetLocalDescription(answerB)) + <-gatherCompletePromiseB + + pcA2Connected := make(chan struct{}, 1) + pcA2.OnICEConnectionStateChange(func(s ICEConnectionState) { + if s == ICEConnectionStateConnected { + pcA2Connected <- struct{}{} + } + }) + + incomingTracksA2 := make(chan *TrackRemote, 1) + pcA2.OnTrack(func(remote *TrackRemote, receiver *RTPReceiver) { + incomingTracksA2 <- remote + }) + + assert.NoError(t, pcA2.SetRemoteDescription(answerB)) + + // Wait for connection + <-pcA2Connected + + triggerMedia() + + // Make sure A1 doesn't receive anything + assert.Empty(t, incomingTracksA1) + assert.NoError(t, incomingTrackA1.SetReadDeadline(time.Now().Add(500*time.Millisecond))) + pkt, _, err = incomingTrackA1.ReadRTP() + assert.Nil(t, pkt) + assert.Error(t, err) + + // This is needed in case of `-race`, somehow... + triggerMedia() + + // Make sure A2 receives media + incomingTrackA2 := <-incomingTracksA2 + pkt, _, err = incomingTrackA2.ReadRTP() + assert.NotNil(t, pkt) + assert.NoError(t, err) +} diff --git a/sctptransport.go b/sctptransport.go index 6718a615040..3c61c1e267f 100644 --- a/sctptransport.go +++ b/sctptransport.go @@ -6,9 +6,11 @@ import ( "io" "math" "sync" + "sync/atomic" "time" "github.com/pion/datachannel" + "github.com/pion/dtls/v2" "github.com/pion/logging" "github.com/pion/sctp" "github.com/pion/webrtc/v3/pkg/rtcerr" @@ -29,6 +31,8 @@ type SCTPTransport struct { // so we need a dedicated field isStarted bool + isAcceptLoopRunning uint32 // Used as a bool + // MaxMessageSize represents the maximum size of data that can be passed to // DataChannel's send() method. maxMessageSize float64 @@ -115,7 +119,45 @@ func (r *SCTPTransport) Start(remoteCaps SCTPCapabilities) error { r.sctpAssociation = sctpAssociation r.state = SCTPTransportStateConnected - go r.acceptDataChannels(sctpAssociation) + go r.acceptDataChannels() + + return nil +} + +// Caller must hold lock +func (r *SCTPTransport) restart(dtlsConn *dtls.Conn) error { + sctpAssociation, err := sctp.Client(sctp.Config{ + NetConn: dtlsConn, + LoggerFactory: r.api.settingEngine.LoggerFactory, + }) + if err != nil { + return err + } + + r.sctpAssociation = sctpAssociation + + // Snapshots the DataChannels to process them asynchronously, safely. + // If a DataChannel is closed, it was most likely closed because of the + // reconnection & needs to be restarted. + dataChannelsCpy := make([]*DataChannel, len(r.dataChannels)) + copy(dataChannelsCpy, r.dataChannels) + + go func(dataChannels []*DataChannel) { + for _, d := range dataChannels { + if d.ReadyState() == DataChannelStateClosed { + err := d.open(r, true) + if err != nil { + r.log.Warnf("failed to re-open data channel: %s", err) + continue + } + } + } + }(dataChannelsCpy) + + if swapped := atomic.CompareAndSwapUint32(&r.isAcceptLoopRunning, 0, 1); swapped { + // AcceptLoop wasn't running, restart it + go r.acceptDataChannels() + } return nil } @@ -138,13 +180,32 @@ func (r *SCTPTransport) Stop() error { return nil } -func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) { +func (r *SCTPTransport) acceptDataChannels() { + atomic.StoreUint32(&r.isAcceptLoopRunning, 1) + defer atomic.StoreUint32(&r.isAcceptLoopRunning, 0) + for { + r.lock.RLock() + a := r.sctpAssociation + r.lock.RUnlock() + dc, err := datachannel.Accept(a, &datachannel.Config{ LoggerFactory: r.api.settingEngine.LoggerFactory, }) if err != nil { if err != io.EOF { + r.lock.RLock() + didRestart := r.sctpAssociation != a + r.lock.RUnlock() + if didRestart { + // During a restart, `Accept` will return an `EOF`. + // If the association is different, it means the SCTPTransport just + // performed a `restart`. In that case, just keep looping. + // + // This is safe since restarts are performed holding the `lock`. + continue + } + r.log.Errorf("Failed to accept data channel: %v", err) r.onError(err) }