diff --git a/datachannel.go b/datachannel.go index 4af5ac940fa..11d09e19283 100644 --- a/datachannel.go +++ b/datachannel.go @@ -70,7 +70,7 @@ func (api *API) NewDataChannel(transport *SCTPTransport, params *DataChannelPara return nil, err } - err = d.open(transport) + err = d.open(transport, false) if err != nil { return nil, err } @@ -104,14 +104,14 @@ func (api *API) newDataChannel(params *DataChannelParameters, log logging.Levele } // open opens the datachannel over the sctp transport -func (d *DataChannel) open(sctpTransport *SCTPTransport) error { +func (d *DataChannel) open(sctpTransport *SCTPTransport, restart bool) error { association := sctpTransport.association() if association == nil { return errSCTPNotEstablished } d.mu.Lock() - if d.sctpTransport != nil { // already open + if d.sctpTransport != nil && !restart { // already open & not restarting d.mu.Unlock() return nil } @@ -170,6 +170,11 @@ func (d *DataChannel) open(sctpTransport *SCTPTransport) error { return err } + // If restarting, the `Open` event should be triggered again, once. + if restart { + d.openHandlerOnce = sync.Once{} + } + // bufferedAmountLowThreshold and onBufferedAmountLow might be set earlier dc.SetBufferedAmountLowThreshold(d.bufferedAmountLowThreshold) dc.OnBufferedAmountLow(d.onBufferedAmountLow) @@ -325,11 +330,18 @@ func (d *DataChannel) readLoop() { n, isString, err := d.dataChannel.ReadDataChannel(buffer) if err != nil { rlBufPool.Put(buffer) // nolint:staticcheck + + previousState := d.ReadyState() d.setReadyState(DataChannelStateClosed) + if err != io.EOF { d.onError(err) } - d.onClose() + + // https://www.w3.org/TR/webrtc/#announcing-a-data-channel-as-closed + if previousState != DataChannelStateClosed { + d.onClose() + } return } diff --git a/dtlstransport.go b/dtlstransport.go index 560eb4b84e4..c9ef0d538d7 100644 --- a/dtlstransport.go +++ b/dtlstransport.go @@ -215,6 +215,31 @@ func (t *DTLSTransport) startSRTP() error { return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err) } + isAlreadyRunning := func() bool { + select { + case <-t.srtpReady: + return true + default: + return false + } + }() + + if isAlreadyRunning { + if sess, ok := t.srtpSession.Load().(*srtp.SessionSRTP); ok { + if updateErr := sess.UpdateContext(srtpConfig); updateErr != nil { + return updateErr + } + } + + if sess, ok := t.srtcpSession.Load().(*srtp.SessionSRTCP); ok { + if updateErr := sess.UpdateContext(srtpConfig); updateErr != nil { + return updateErr + } + } + + return nil + } + srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig) if err != nil { return fmt.Errorf("%w: %v", errFailedToStartSRTP, err) diff --git a/go.mod b/go.mod index 0ddea8ff553..054639fc0f5 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/pion/rtp v1.7.4 github.com/pion/sctp v1.8.2 github.com/pion/sdp/v3 v3.0.4 - github.com/pion/srtp/v2 v2.0.5 + github.com/pion/srtp/v2 v2.0.6-0.20220304062923-d55e443f8e15 github.com/pion/transport v0.13.0 github.com/sclevine/agouti v3.0.0+incompatible github.com/stretchr/testify v1.7.0 diff --git a/go.sum b/go.sum index a71303dc9f2..2b9596c7fb8 100644 --- a/go.sum +++ b/go.sum @@ -54,10 +54,8 @@ github.com/pion/mdns v0.0.5 h1:Q2oj/JB3NqfzY9xGZ1fPzZzK7sDSD8rZPOvcIQ10BCw= github.com/pion/mdns v0.0.5/go.mod h1:UgssrvdD3mxpi8tMxAXbsppL3vJ4Jipw1mTCW+al01g= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= -github.com/pion/rtcp v1.2.6/go.mod h1:52rMNPWFsjr39z9B9MhnkqhPLoeHTv1aN63o/42bWE0= github.com/pion/rtcp v1.2.9 h1:1ujStwg++IOLIEoOiIQ2s+qBuJ1VN81KW+9pMPsif+U= github.com/pion/rtcp v1.2.9/go.mod h1:qVPhiCzAm4D/rxb6XzKeyZiQK69yJpbUDJSF7TgrqNo= -github.com/pion/rtp v1.7.0/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko= github.com/pion/rtp v1.7.4 h1:4dMbjb1SuynU5OpA3kz1zHK+u+eOCQjW3MAeVHf1ODA= github.com/pion/rtp v1.7.4/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko= github.com/pion/sctp v1.8.0/go.mod h1:xFe9cLMZ5Vj6eOzpyiKjT9SwGM4KpK/8Jbw5//jc+0s= @@ -65,8 +63,8 @@ github.com/pion/sctp v1.8.2 h1:yBBCIrUMJ4yFICL3RIvR4eh/H2BTTvlligmSTy+3kiA= github.com/pion/sctp v1.8.2/go.mod h1:xFe9cLMZ5Vj6eOzpyiKjT9SwGM4KpK/8Jbw5//jc+0s= github.com/pion/sdp/v3 v3.0.4 h1:2Kf+dgrzJflNCSw3TV5v2VLeI0s/qkzy2r5jlR0wzf8= github.com/pion/sdp/v3 v3.0.4/go.mod h1:bNiSknmJE0HYBprTHXKPQ3+JjacTv5uap92ueJZKsRk= -github.com/pion/srtp/v2 v2.0.5 h1:ks3wcTvIUE/GHndO3FAvROQ9opy0uLELpwHJaQ1yqhQ= -github.com/pion/srtp/v2 v2.0.5/go.mod h1:8k6AJlal740mrZ6WYxc4Dg6qDqqhxoRG2GSjlUhDF0A= +github.com/pion/srtp/v2 v2.0.6-0.20220304062923-d55e443f8e15 h1:qFdF9b185eGmlBr1OyizpP4a8RnS8tCT8MztCKofuQ8= +github.com/pion/srtp/v2 v2.0.6-0.20220304062923-d55e443f8e15/go.mod h1:Kp632EOcOX2wtB6njSY+oRamReUfEYINuaGmKIMHVlA= github.com/pion/stun v0.3.5 h1:uLUCBCkQby4S1cf6CGuR9QrVOKcvUwFeemaC865QHDg= github.com/pion/stun v0.3.5/go.mod h1:gDMim+47EeEtfWogA37n6qXZS88L5V6LqFcf+DZA2UA= github.com/pion/transport v0.12.2/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q= diff --git a/peerconnection.go b/peerconnection.go index 1794e8d41ef..6b64c9eea79 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -1144,7 +1144,59 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error { pc.ops.Enqueue(func() { pc.startRTP(true, &desc, currentTransceivers) }) + } else if pc.dtlsTransport.State() != DTLSTransportStateNew { + fingerprint, fingerprintHash, fErr := extractFingerprint(desc.parsed) + if fErr != nil { + return fErr + } + + fingerPrintDidChange := true + + for _, fp := range pc.dtlsTransport.remoteParameters.Fingerprints { + if fingerprint == fp.Value && fingerprintHash == fp.Algorithm { + fingerPrintDidChange = false + break + } + } + + if fingerPrintDidChange { + pc.ops.Enqueue(func() { + // SCTP uses DTLS, so prevent any use, by locking, while + // DTLS is restarting. + pc.sctpTransport.lock.Lock() + defer pc.sctpTransport.lock.Unlock() + + if dErr := pc.dtlsTransport.Stop(); dErr != nil { + pc.log.Warnf("Failed to stop DTLS: %s", dErr) + } + + // libwebrtc switches the connection back to `new`. + pc.dtlsTransport.lock.Lock() + pc.dtlsTransport.onStateChange(DTLSTransportStateNew) + pc.dtlsTransport.lock.Unlock() + + // Restart the dtls transport with updated fingerprints + err = pc.dtlsTransport.Start(DTLSParameters{ + Role: dtlsRoleFromRemoteSDP(desc.parsed), + Fingerprints: []DTLSFingerprint{{Algorithm: fingerprintHash, Value: fingerprint}}, + }) + pc.updateConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State()) + if err != nil { + pc.log.Warnf("Failed to restart DTLS: %s", err) + return + } + + // If SCTP was enabled, restart it with the new DTLS transport. + if pc.sctpTransport.isStarted { + if dErr := pc.sctpTransport.restart(pc.dtlsTransport.conn); dErr != nil { + pc.log.Warnf("Failed to restart SCTP: %s", dErr) + return + } + } + }) + } } + return nil } @@ -1904,7 +1956,7 @@ func (pc *PeerConnection) CreateDataChannel(label string, options *DataChannelIn // If SCTP already connected open all the channels if pc.sctpTransport.State() == SCTPTransportStateConnected { - if err = d.open(pc.sctpTransport); err != nil { + if err = d.open(pc.sctpTransport, false); err != nil { return nil, err } } diff --git a/peerconnection_media_test.go b/peerconnection_media_test.go index 6390abfab8e..1af132c95ab 100644 --- a/peerconnection_media_test.go +++ b/peerconnection_media_test.go @@ -1336,3 +1336,285 @@ func TestPeerConnection_Simulcast(t *testing.T) { closePairNow(t, pcOffer, pcAnswer) }) } + +// Issue #1636 +func TestPeerConnection_DTLS_Restart_MediaAndDataChannel(t *testing.T) { + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() + + makeClient := func() (*PeerConnection, *TrackLocalStaticSample, *DataChannel, <-chan string) { + pc, cliErr := NewPeerConnection(Configuration{}) + assert.NoError(t, cliErr) + + track, cliErr := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeOpus}, "audio", "test-client") + assert.NoError(t, cliErr) + + dc, cliErr := pc.CreateDataChannel("data", nil) + assert.NoError(t, cliErr) + + dataChannelMessages := make(chan string, 1) + pc.OnDataChannel(func(channel *DataChannel) { + channel.OnMessage(func(msg DataChannelMessage) { + fmt.Printf("received %v\n", string(msg.Data)) + dataChannelMessages <- string(msg.Data) + }) + }) + + _, cliErr = pc.AddTrack(track) + assert.NoError(t, cliErr) + + return pc, track, dc, dataChannelMessages + } + + pcA1, _, _, msgsA1 := makeClient() + defer func() { _ = pcA1.Close() }() + pcA2, _, _, msgsA2 := makeClient() + defer func() { _ = pcA2.Close() }() + pcB, outputTrackB, dataChannelB, _ := makeClient() + defer func() { _ = pcB.Close() }() + + triggerMedia := func() { + assert.NoError(t, outputTrackB.WriteSample(media.Sample{ + Data: []byte{0xbb}, + Timestamp: time.Now(), + Duration: 20 * time.Millisecond, + })) + // Somehow, if we only send 1 packet, the OnTrack event is never fired. + time.Sleep(20 * time.Millisecond) + assert.NoError(t, outputTrackB.WriteSample(media.Sample{ + Data: []byte{0xbb}, + Timestamp: time.Now(), + Duration: 20 * time.Millisecond, + })) + } + + gatherCompletePromiseA1 := GatheringCompletePromise(pcA1) + offerA1, err := pcA1.CreateOffer(nil) + assert.NoError(t, err) + assert.NoError(t, pcA1.SetLocalDescription(offerA1)) + <-gatherCompletePromiseA1 + + assert.NoError(t, pcB.SetRemoteDescription(*pcA1.LocalDescription())) + + gatherCompletePromiseB := GatheringCompletePromise(pcB) + answerB, err := pcB.CreateAnswer(nil) + assert.NoError(t, err) + assert.NoError(t, pcB.SetLocalDescription(answerB)) + <-gatherCompletePromiseB + + pcA1Connected := make(chan struct{}, 1) + pcA1.OnICEConnectionStateChange(func(s ICEConnectionState) { + if s == ICEConnectionStateConnected { + pcA1Connected <- struct{}{} + } + }) + + incomingTracksA1 := make(chan *TrackRemote, 1) + pcA1.OnTrack(func(remote *TrackRemote, receiver *RTPReceiver) { + incomingTracksA1 <- remote + }) + + assert.NoError(t, pcA1.SetRemoteDescription(answerB)) + + <-pcA1Connected + + triggerMedia() + assert.NoError(t, dataChannelB.SendText("HelloWorld")) + + incomingTrackA1 := <-incomingTracksA1 + + pkt, _, err := incomingTrackA1.ReadRTP() + assert.NotNil(t, pkt) + assert.NoError(t, err) + + <-msgsA1 + assert.Empty(t, msgsA1) + + // ClientA2 connects to ClientB + + gatherCompletePromiseA2 := GatheringCompletePromise(pcA2) + // We can't do an ICE Restart here, since it's a different PeerConnection + offerA2, err := pcA2.CreateOffer(nil) + assert.NoError(t, err) + assert.NoError(t, pcA2.SetLocalDescription(offerA2)) + <-gatherCompletePromiseA2 + + assert.NoError(t, pcB.SetRemoteDescription(*pcA2.LocalDescription())) + + gatherCompletePromiseB = GatheringCompletePromise(pcB) + answerB, err = pcB.CreateAnswer(nil) + assert.NoError(t, err) + assert.NoError(t, pcB.SetLocalDescription(answerB)) + <-gatherCompletePromiseB + + pcA2Connected := make(chan struct{}, 1) + pcA2.OnICEConnectionStateChange(func(s ICEConnectionState) { + if s == ICEConnectionStateConnected { + pcA2Connected <- struct{}{} + } + }) + + incomingTracksA2 := make(chan *TrackRemote, 1) + pcA2.OnTrack(func(remote *TrackRemote, receiver *RTPReceiver) { + incomingTracksA2 <- remote + }) + + assert.NoError(t, pcA2.SetRemoteDescription(answerB)) + + // Wait for connection + <-pcA2Connected + + triggerMedia() + assert.NoError(t, dataChannelB.SendText("HelloWorld")) + + // Make sure A1 doesn't receive anything + assert.Empty(t, incomingTracksA1) + assert.NoError(t, incomingTrackA1.SetReadDeadline(time.Now().Add(100*time.Millisecond))) + pkt, _, err = incomingTrackA1.ReadRTP() + assert.Nil(t, pkt) + assert.Error(t, err) + + // Needed in case of `-race`?? + triggerMedia() + + // Make sure A2 receives media + incomingTrackA2 := <-incomingTracksA2 + pkt, _, err = incomingTrackA2.ReadRTP() + assert.NotNil(t, pkt) + assert.NoError(t, err) + + <-msgsA2 + assert.Empty(t, msgsA2) + assert.Empty(t, msgsA1) +} + +// Issue #1636 +func TestPeerConnection_DTLS_Restart_MediaOnly(t *testing.T) { + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() + + makeClient := func() (*PeerConnection, *TrackLocalStaticSample) { + pc, cliErr := NewPeerConnection(Configuration{}) + assert.NoError(t, cliErr) + + track, cliErr := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeOpus}, "audio", "test-client") + assert.NoError(t, cliErr) + + _, cliErr = pc.AddTrack(track) + assert.NoError(t, cliErr) + + return pc, track + } + + pcA1, _ := makeClient() + defer func() { _ = pcA1.Close() }() + pcA2, _ := makeClient() + defer func() { _ = pcA2.Close() }() + pcB, outputTrackB := makeClient() + defer func() { _ = pcB.Close() }() + + triggerMedia := func() { + assert.NoError(t, outputTrackB.WriteSample(media.Sample{ + Data: []byte{0xbb}, + Timestamp: time.Now(), + Duration: 20 * time.Millisecond, + })) + // Somehow, if we only send 1 packet, the OnTrack event is never fired. + time.Sleep(20 * time.Millisecond) + assert.NoError(t, outputTrackB.WriteSample(media.Sample{ + Data: []byte{0xbb}, + Timestamp: time.Now(), + Duration: 20 * time.Millisecond, + })) + } + + gatherCompletePromiseA1 := GatheringCompletePromise(pcA1) + offerA1, err := pcA1.CreateOffer(nil) + assert.NoError(t, err) + assert.NoError(t, pcA1.SetLocalDescription(offerA1)) + <-gatherCompletePromiseA1 + + assert.NoError(t, pcB.SetRemoteDescription(*pcA1.LocalDescription())) + + gatherCompletePromiseB := GatheringCompletePromise(pcB) + answerB, err := pcB.CreateAnswer(nil) + assert.NoError(t, err) + assert.NoError(t, pcB.SetLocalDescription(answerB)) + <-gatherCompletePromiseB + + pcA1Connected := make(chan struct{}, 1) + pcA1.OnICEConnectionStateChange(func(s ICEConnectionState) { + if s == ICEConnectionStateConnected { + pcA1Connected <- struct{}{} + } + }) + + incomingTracksA1 := make(chan *TrackRemote, 1) + pcA1.OnTrack(func(remote *TrackRemote, receiver *RTPReceiver) { + incomingTracksA1 <- remote + }) + + assert.NoError(t, pcA1.SetRemoteDescription(answerB)) + + <-pcA1Connected + + triggerMedia() + + incomingTrackA1 := <-incomingTracksA1 + + pkt, _, err := incomingTrackA1.ReadRTP() + assert.NotNil(t, pkt) + assert.NoError(t, err) + + // ClientA2 connects to ClientB + + gatherCompletePromiseA2 := GatheringCompletePromise(pcA2) + // We can't do an ICE Restart here, since it's a different PeerConnection + offerA2, err := pcA2.CreateOffer(nil) + assert.NoError(t, err) + assert.NoError(t, pcA2.SetLocalDescription(offerA2)) + <-gatherCompletePromiseA2 + + assert.NoError(t, pcB.SetRemoteDescription(*pcA2.LocalDescription())) + + gatherCompletePromiseB = GatheringCompletePromise(pcB) + answerB, err = pcB.CreateAnswer(nil) + assert.NoError(t, err) + assert.NoError(t, pcB.SetLocalDescription(answerB)) + <-gatherCompletePromiseB + + pcA2Connected := make(chan struct{}, 1) + pcA2.OnICEConnectionStateChange(func(s ICEConnectionState) { + if s == ICEConnectionStateConnected { + pcA2Connected <- struct{}{} + } + }) + + incomingTracksA2 := make(chan *TrackRemote, 1) + pcA2.OnTrack(func(remote *TrackRemote, receiver *RTPReceiver) { + incomingTracksA2 <- remote + }) + + assert.NoError(t, pcA2.SetRemoteDescription(answerB)) + + // Wait for connection + <-pcA2Connected + + triggerMedia() + + // Make sure A1 doesn't receive anything + assert.Empty(t, incomingTracksA1) + assert.NoError(t, incomingTrackA1.SetReadDeadline(time.Now().Add(500*time.Millisecond))) + pkt, _, err = incomingTrackA1.ReadRTP() + assert.Nil(t, pkt) + assert.Error(t, err) + + // This is needed in case of `-race`, somehow... + triggerMedia() + + // Make sure A2 receives media + incomingTrackA2 := <-incomingTracksA2 + pkt, _, err = incomingTrackA2.ReadRTP() + assert.NotNil(t, pkt) + assert.NoError(t, err) +} diff --git a/sctptransport.go b/sctptransport.go index 27e02ea65de..66bfc8efecd 100644 --- a/sctptransport.go +++ b/sctptransport.go @@ -7,9 +7,11 @@ import ( "io" "math" "sync" + "sync/atomic" "time" "github.com/pion/datachannel" + "github.com/pion/dtls/v2" "github.com/pion/logging" "github.com/pion/sctp" "github.com/pion/webrtc/v3/pkg/rtcerr" @@ -30,6 +32,8 @@ type SCTPTransport struct { // so we need a dedicated field isStarted bool + isAcceptLoopRunning uint32 // Used as a bool + // MaxMessageSize represents the maximum size of data that can be passed to // DataChannel's send() method. maxMessageSize float64 @@ -119,7 +123,7 @@ func (r *SCTPTransport) Start(remoteCaps SCTPCapabilities) error { var openedDCCount uint32 for _, d := range dataChannels { if d.ReadyState() == DataChannelStateConnecting { - err := d.open(r) + err := d.open(r, false) if err != nil { r.log.Warnf("failed to open data channel: %s", err) continue @@ -132,7 +136,45 @@ func (r *SCTPTransport) Start(remoteCaps SCTPCapabilities) error { r.dataChannelsOpened += openedDCCount r.lock.Unlock() - go r.acceptDataChannels(sctpAssociation) + go r.acceptDataChannels() + + return nil +} + +// Caller must hold lock +func (r *SCTPTransport) restart(dtlsConn *dtls.Conn) error { + sctpAssociation, err := sctp.Client(sctp.Config{ + NetConn: dtlsConn, + LoggerFactory: r.api.settingEngine.LoggerFactory, + }) + if err != nil { + return err + } + + r.sctpAssociation = sctpAssociation + + // Snapshots the DataChannels to process them asynchronously, safely. + // If a DataChannel is closed, it was most likely closed because of the + // reconnection & needs to be restarted. + dataChannelsCpy := make([]*DataChannel, len(r.dataChannels)) + copy(dataChannelsCpy, r.dataChannels) + + go func(dataChannels []*DataChannel) { + for _, d := range dataChannels { + if d.ReadyState() == DataChannelStateClosed { + err := d.open(r, true) + if err != nil { + r.log.Warnf("failed to re-open data channel: %s", err) + continue + } + } + } + }(dataChannelsCpy) + + if swapped := atomic.CompareAndSwapUint32(&r.isAcceptLoopRunning, 0, 1); swapped { + // AcceptLoop wasn't running, restart it + go r.acceptDataChannels() + } return nil } @@ -155,26 +197,44 @@ func (r *SCTPTransport) Stop() error { return nil } -func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) { +func (r *SCTPTransport) acceptDataChannels() { + atomic.StoreUint32(&r.isAcceptLoopRunning, 1) + defer atomic.StoreUint32(&r.isAcceptLoopRunning, 0) + r.lock.RLock() dataChannels := make([]*datachannel.DataChannel, 0, len(r.dataChannels)) for _, dc := range r.dataChannels { dc.mu.Lock() - isNil := dc.dataChannel == nil + scopedDataChannel := dc.dataChannel dc.mu.Unlock() - if isNil { - continue + + if scopedDataChannel != nil { + dataChannels = append(dataChannels, scopedDataChannel) } - dataChannels = append(dataChannels, dc.dataChannel) } r.lock.RUnlock() + ACCEPT: for { + // Safely access the most recent association + a := r.association() + dc, err := datachannel.Accept(a, &datachannel.Config{ LoggerFactory: r.api.settingEngine.LoggerFactory, }, dataChannels...) if err != nil { if err != io.EOF { + didRestart := r.association() != a + + if didRestart { + // During a restart, `Accept` will return an `EOF`. + // If the association is different, it means the SCTPTransport just + // performed a `restart`. In that case, just keep looping. + // + // This is safe since restarts are performed holding the `lock`. + continue + } + r.log.Errorf("Failed to accept data channel: %v", err) r.onError(err) }