From 97941d5151410d173898481b236a6ddf714e06bc Mon Sep 17 00:00:00 2001 From: Kevin Caffrey Date: Wed, 26 Jul 2023 22:54:49 -0400 Subject: [PATCH] Partial rewrite of TWCC sender The previous implementation of the TWCC sender interceptor had some inconsistencies with libwebrtc. Namely, if there were missing packets between the last packet in the previous feedback interval and the first received packet in the following feedback interval, then those packets were never included as missing in any feedback. This is an issue because libwebrtc uses data about lost packets from TWCC feedback in their congestion controller. This means that bursts of loss could go unnoticed by libwebrtc, causing the application to possibly send more data than it would if the loss was properly reported. Another minor difference was that feedback packets with a single packet are in fact valid according to the RFC, but we were only returning feedback with at least two packets. libwebrtc has a check for minimum feedback size, which has now been added to the unit tests. Just about all of the rewrite here has been ported from libwebrtc code, so it should now match the behavior fairly well. --- pkg/twcc/arrival_time_map.go | 192 ++++++++++++++++++++ pkg/twcc/arrival_time_map_test.go | 288 ++++++++++++++++++++++++++++++ pkg/twcc/sender_interceptor.go | 2 +- pkg/twcc/twcc.go | 169 +++++++++++++----- pkg/twcc/twcc_test.go | 213 ++++++---------------- 5 files changed, 660 insertions(+), 204 deletions(-) create mode 100644 pkg/twcc/arrival_time_map.go create mode 100644 pkg/twcc/arrival_time_map_test.go diff --git a/pkg/twcc/arrival_time_map.go b/pkg/twcc/arrival_time_map.go new file mode 100644 index 00000000..abca1ff5 --- /dev/null +++ b/pkg/twcc/arrival_time_map.go @@ -0,0 +1,192 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package twcc + +const ( + minCapacity = 128 + maxNumberOfPackets = 1 << 15 +) + +// packetArrivalTimeMap is adapted from Chrome's implementation of TWCC, and keeps track +// of the arrival times of packets. It is used by the TWCC interceptor to build feedback +// packets. +// See https://source.chromium.org/chromium/chromium/src/+/refs/heads/main:third_party/webrtc/modules/remote_bitrate_estimator/packet_arrival_map.h;drc=b5cd13bb6d5d157a5fbe3628b2dd1c1e106203c6 +type packetArrivalTimeMap struct { + // arrivalTimes is a circular buffer, where the packet with sequence number sn is stored + // in slot sn % len(arrivalTimes) + arrivalTimes []int64 + + // The unwrapped sequence numbers for the range of valid sequence numbers in arrivalTimes. + // beginSequenceNumber is inclusive, and endSequenceNumber is exclusive. + beginSequenceNumber, endSequenceNumber int64 +} + +// AddPacket records the fact that the packet with sequence number sequenceNumber arrived +// at arrivalTime. +func (m *packetArrivalTimeMap) AddPacket(sequenceNumber int64, arrivalTime int64) { + if m.arrivalTimes == nil { + // First packet + m.reallocate(minCapacity) + m.beginSequenceNumber = sequenceNumber + m.endSequenceNumber = sequenceNumber + 1 + m.arrivalTimes[m.index(sequenceNumber)] = arrivalTime + return + } + + if sequenceNumber >= m.beginSequenceNumber && sequenceNumber < m.endSequenceNumber { + // The packet is within the buffer, no need to resize. + m.arrivalTimes[m.index(sequenceNumber)] = arrivalTime + return + } + + if sequenceNumber < m.beginSequenceNumber { + // The packet goes before the current buffer. Expand to add packet, + // but only if it fits within the maximum number of packets. + newSize := int(m.endSequenceNumber - sequenceNumber) + if newSize > maxNumberOfPackets { + // Don't expand the buffer back for this packet, as it would remove newer received + // packets. + return + } + m.adjustToSize(newSize) + m.arrivalTimes[m.index(sequenceNumber)] = arrivalTime + m.setNotReceived(sequenceNumber+1, m.beginSequenceNumber) + m.beginSequenceNumber = sequenceNumber + return + } + + // The packet goes after the buffer. + newEndSequenceNumber := sequenceNumber + 1 + + if newEndSequenceNumber >= m.endSequenceNumber+maxNumberOfPackets { + // All old packets have to be removed. + m.beginSequenceNumber = sequenceNumber + m.endSequenceNumber = newEndSequenceNumber + m.arrivalTimes[m.index(sequenceNumber)] = arrivalTime + return + } + + if m.beginSequenceNumber < newEndSequenceNumber-maxNumberOfPackets { + // Remove oldest entries. + m.beginSequenceNumber = newEndSequenceNumber - maxNumberOfPackets + } + + m.adjustToSize(int(newEndSequenceNumber - m.beginSequenceNumber)) + + // Packets can be received out of order. If this isn't the next expected packet, + // add enough placeholders to fill the gap. + m.setNotReceived(m.endSequenceNumber, sequenceNumber) + m.endSequenceNumber = newEndSequenceNumber + m.arrivalTimes[m.index(sequenceNumber)] = arrivalTime +} + +func (m *packetArrivalTimeMap) setNotReceived(startInclusive, endExclusive int64) { + for sn := startInclusive; sn < endExclusive; sn++ { + m.arrivalTimes[m.index(sn)] = -1 + } +} + +// BeginSequenceNumber returns the first valid sequence number in the map. +func (m *packetArrivalTimeMap) BeginSequenceNumber() int64 { + return m.beginSequenceNumber +} + +// EndSequenceNumber returns the first sequence number after the last valid sequence number in the map. +func (m *packetArrivalTimeMap) EndSequenceNumber() int64 { + return m.endSequenceNumber +} + +// FindNextAtOrAfter returns the sequence number and timestamp of the first received packet that has a sequence number +// greator or equal to sequenceNumber. +func (m *packetArrivalTimeMap) FindNextAtOrAfter(sequenceNumber int64) (foundSequenceNumber int64, arrivalTime int64, ok bool) { + for sequenceNumber = m.Clamp(sequenceNumber); sequenceNumber < m.endSequenceNumber; sequenceNumber++ { + if t := m.get(sequenceNumber); t >= 0 { + return sequenceNumber, t, true + } + } + return -1, -1, false +} + +// EraseTo erases all elements from the beginning of the map until sequenceNumber. +func (m *packetArrivalTimeMap) EraseTo(sequenceNumber int64) { + if sequenceNumber < m.beginSequenceNumber { + return + } + if sequenceNumber >= m.endSequenceNumber { + // Erase all. + m.beginSequenceNumber = m.endSequenceNumber + return + } + // Remove some + m.beginSequenceNumber = sequenceNumber + m.adjustToSize(int(m.endSequenceNumber - m.beginSequenceNumber)) +} + +// RemoveOldPackets removes packets from the beginning of the map as long as they are before +// sequenceNumber and with an age older than arrivalTimeLimit. +func (m *packetArrivalTimeMap) RemoveOldPackets(sequenceNumber int64, arrivalTimeLimit int64) { + checkTo := min64(sequenceNumber, m.endSequenceNumber) + for m.beginSequenceNumber < checkTo && m.get(m.beginSequenceNumber) <= arrivalTimeLimit { + m.beginSequenceNumber++ + } + m.adjustToSize(int(m.endSequenceNumber - m.beginSequenceNumber)) +} + +// HasReceived returns whether a packet with the sequence number has been received. +func (m *packetArrivalTimeMap) HasReceived(sequenceNumber int64) bool { + return m.get(sequenceNumber) >= 0 +} + +// Clamp returns sequenceNumber clamped to [beginSequenceNumber, endSequenceNumber] +func (m *packetArrivalTimeMap) Clamp(sequenceNumber int64) int64 { + if sequenceNumber < m.beginSequenceNumber { + return m.beginSequenceNumber + } + if m.endSequenceNumber < sequenceNumber { + return m.endSequenceNumber + } + return sequenceNumber +} + +func (m *packetArrivalTimeMap) get(sequenceNumber int64) int64 { + if sequenceNumber < m.beginSequenceNumber || sequenceNumber >= m.endSequenceNumber { + return -1 + } + return m.arrivalTimes[m.index(sequenceNumber)] +} + +func (m *packetArrivalTimeMap) index(sequenceNumber int64) int { + // Sequence number might be negative, and we always guarantee that arrivalTimes + // length is a power of 2, so it's easier to use "&" instead of "%" + return int(sequenceNumber & int64(m.capacity()-1)) +} + +func (m *packetArrivalTimeMap) adjustToSize(newSize int) { + if newSize > m.capacity() { + newCapacity := m.capacity() + for newCapacity < newSize { + newCapacity *= 2 + } + m.reallocate(newCapacity) + } + if m.capacity() > max(minCapacity, newSize*4) { + newCapacity := m.capacity() + for newCapacity >= 2*max(newSize, minCapacity) { + newCapacity /= 2 + } + m.reallocate(newCapacity) + } +} + +func (m *packetArrivalTimeMap) capacity() int { + return len(m.arrivalTimes) +} + +func (m *packetArrivalTimeMap) reallocate(newCapacity int) { + newBuffer := make([]int64, newCapacity) + for sn := m.beginSequenceNumber; sn < m.endSequenceNumber; sn++ { + newBuffer[int(sn&(int64(newCapacity-1)))] = m.get(sn) + } + m.arrivalTimes = newBuffer +} diff --git a/pkg/twcc/arrival_time_map_test.go b/pkg/twcc/arrival_time_map_test.go new file mode 100644 index 00000000..fe5e9cc6 --- /dev/null +++ b/pkg/twcc/arrival_time_map_test.go @@ -0,0 +1,288 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package twcc + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestArrivalTimeMap(t *testing.T) { + t.Run("consistent when empty", func(t *testing.T) { + var m packetArrivalTimeMap + assert.Equal(t, m.BeginSequenceNumber(), m.EndSequenceNumber()) + assert.False(t, m.HasReceived(0)) + assert.Equal(t, int64(0), m.Clamp(-5)) + assert.Equal(t, int64(0), m.Clamp(5)) + }) + + t.Run("inserts first item into map", func(t *testing.T) { + var m packetArrivalTimeMap + m.AddPacket(42, 10) + assert.Equal(t, int64(42), m.BeginSequenceNumber()) + assert.Equal(t, int64(43), m.EndSequenceNumber()) + + assert.False(t, m.HasReceived(41)) + assert.True(t, m.HasReceived(42)) + assert.False(t, m.HasReceived(43)) + assert.False(t, m.HasReceived(44)) + + assert.Equal(t, int64(42), m.Clamp(-100)) + assert.Equal(t, int64(42), m.Clamp(42)) + assert.Equal(t, int64(43), m.Clamp(100)) + }) + + t.Run("inserts with gaps", func(t *testing.T) { + var m packetArrivalTimeMap + m.AddPacket(42, 0) + m.AddPacket(45, 11) + assert.Equal(t, int64(42), m.BeginSequenceNumber()) + assert.Equal(t, int64(46), m.EndSequenceNumber()) + + assert.False(t, m.HasReceived(41)) + assert.True(t, m.HasReceived(42)) + assert.False(t, m.HasReceived(43)) + assert.False(t, m.HasReceived(44)) + assert.True(t, m.HasReceived(45)) + assert.False(t, m.HasReceived(46)) + + assert.Equal(t, int64(0), m.get(42)) + assert.Less(t, m.get(43), int64(0)) + assert.Less(t, m.get(44), int64(0)) + assert.Equal(t, int64(11), m.get(45)) + + assert.Equal(t, int64(42), m.Clamp(-100)) + assert.Equal(t, int64(44), m.Clamp(44)) + assert.Equal(t, int64(46), m.Clamp(100)) + }) + + t.Run("find next at or after with gaps", func(t *testing.T) { + var m packetArrivalTimeMap + m.AddPacket(42, 0) + m.AddPacket(45, 11) + + seq, ts, ok := m.FindNextAtOrAfter(42) + assert.Equal(t, int64(42), seq) + assert.Equal(t, int64(0), ts) + assert.True(t, ok) + + seq, ts, ok = m.FindNextAtOrAfter(43) + assert.Equal(t, int64(45), seq) + assert.Equal(t, int64(11), ts) + assert.True(t, ok) + }) + + t.Run("inserts within buffer", func(t *testing.T) { + var m packetArrivalTimeMap + m.AddPacket(42, 10) + m.AddPacket(45, 11) + + m.AddPacket(43, 12) + m.AddPacket(44, 13) + + assert.False(t, m.HasReceived(41)) + assert.True(t, m.HasReceived(42)) + assert.True(t, m.HasReceived(43)) + assert.True(t, m.HasReceived(44)) + assert.True(t, m.HasReceived(45)) + assert.False(t, m.HasReceived(46)) + + assert.Equal(t, int64(10), m.get(42)) + assert.Equal(t, int64(12), m.get(43)) + assert.Equal(t, int64(13), m.get(44)) + assert.Equal(t, int64(11), m.get(45)) + }) + + t.Run("grows buffer and removes old", func(t *testing.T) { + var m packetArrivalTimeMap + + var largeSeqNum int64 = 42 + maxNumberOfPackets + m.AddPacket(42, 10) + m.AddPacket(43, 11) + m.AddPacket(44, 12) + m.AddPacket(45, 13) + m.AddPacket(largeSeqNum, 12) + + assert.Equal(t, int64(43), m.BeginSequenceNumber()) + assert.Equal(t, largeSeqNum+1, m.EndSequenceNumber()) + + assert.False(t, m.HasReceived(41)) + assert.False(t, m.HasReceived(42)) + assert.True(t, m.HasReceived(43)) + assert.True(t, m.HasReceived(44)) + assert.True(t, m.HasReceived(45)) + assert.False(t, m.HasReceived(46)) + assert.True(t, m.HasReceived(largeSeqNum)) + assert.False(t, m.HasReceived(largeSeqNum+1)) + }) + + t.Run("sequence number jump deletes all", func(t *testing.T) { + var m packetArrivalTimeMap + + var largeSeqNum int64 = 42 + 2*maxNumberOfPackets + m.AddPacket(42, 10) + m.AddPacket(largeSeqNum, 12) + + assert.Equal(t, largeSeqNum, m.BeginSequenceNumber()) + assert.Equal(t, largeSeqNum+1, m.EndSequenceNumber()) + + assert.False(t, m.HasReceived(42)) + assert.True(t, m.HasReceived(largeSeqNum)) + assert.False(t, m.HasReceived(largeSeqNum+1)) + }) + + t.Run("expands before beginning", func(t *testing.T) { + var m packetArrivalTimeMap + m.AddPacket(42, 10) + m.AddPacket(-1000, 13) + assert.Equal(t, int64(-1000), m.BeginSequenceNumber()) + assert.Equal(t, int64(43), m.EndSequenceNumber()) + + assert.False(t, m.HasReceived(-1001)) + assert.True(t, m.HasReceived(-1000)) + assert.False(t, m.HasReceived(-999)) + assert.True(t, m.HasReceived(42)) + assert.False(t, m.HasReceived(43)) + }) + + t.Run("expanding before beginning keeps received", func(t *testing.T) { + var m packetArrivalTimeMap + + var smallSeqNum int64 = 42 - 2*maxNumberOfPackets + m.AddPacket(42, 10) + m.AddPacket(smallSeqNum, 13) + + assert.Equal(t, int64(42), m.BeginSequenceNumber()) + assert.Equal(t, int64(43), m.EndSequenceNumber()) + }) + + t.Run("erase to removes elements", func(t *testing.T) { + var m packetArrivalTimeMap + m.AddPacket(42, 10) + m.AddPacket(43, 11) + m.AddPacket(44, 12) + m.AddPacket(45, 13) + + m.EraseTo(44) + + assert.Equal(t, int64(44), m.BeginSequenceNumber()) + assert.Equal(t, int64(46), m.EndSequenceNumber()) + + assert.False(t, m.HasReceived(43)) + assert.True(t, m.HasReceived(44)) + assert.True(t, m.HasReceived(45)) + assert.False(t, m.HasReceived(46)) + }) + + t.Run("erases in empty map", func(t *testing.T) { + var m packetArrivalTimeMap + + assert.Equal(t, m.BeginSequenceNumber(), m.EndSequenceNumber()) + + m.EraseTo(m.EndSequenceNumber()) + assert.Equal(t, m.BeginSequenceNumber(), m.EndSequenceNumber()) + }) + + t.Run("is tolerant to wrong arguments for erase", func(t *testing.T) { + var m packetArrivalTimeMap + m.AddPacket(42, 10) + m.AddPacket(43, 11) + + m.EraseTo(1) + + assert.Equal(t, int64(42), m.BeginSequenceNumber()) + assert.Equal(t, int64(44), m.EndSequenceNumber()) + + m.EraseTo(100) + + assert.Equal(t, int64(44), m.BeginSequenceNumber()) + assert.Equal(t, int64(44), m.EndSequenceNumber()) + }) + + t.Run("erase all remembers beginning sequence number", func(t *testing.T) { + var m packetArrivalTimeMap + m.AddPacket(42, 10) + m.AddPacket(43, 11) + m.AddPacket(44, 12) + m.AddPacket(45, 13) + + m.EraseTo(46) + m.AddPacket(50, 10) + + assert.Equal(t, int64(46), m.BeginSequenceNumber()) + assert.Equal(t, int64(51), m.EndSequenceNumber()) + + assert.False(t, m.HasReceived(45)) + assert.False(t, m.HasReceived(46)) + assert.False(t, m.HasReceived(47)) + assert.False(t, m.HasReceived(48)) + assert.False(t, m.HasReceived(49)) + assert.True(t, m.HasReceived(50)) + assert.False(t, m.HasReceived(51)) + }) + + t.Run("erase to missing sequence number", func(t *testing.T) { + var m packetArrivalTimeMap + m.AddPacket(37, 10) + m.AddPacket(39, 11) + m.AddPacket(40, 12) + m.AddPacket(41, 13) + + m.EraseTo(38) + + m.AddPacket(42, 40) + + assert.Equal(t, int64(38), m.BeginSequenceNumber()) + assert.Equal(t, int64(43), m.EndSequenceNumber()) + + assert.False(t, m.HasReceived(37)) + assert.False(t, m.HasReceived(38)) + assert.True(t, m.HasReceived(39)) + assert.True(t, m.HasReceived(40)) + assert.True(t, m.HasReceived(41)) + assert.True(t, m.HasReceived(42)) + assert.False(t, m.HasReceived(43)) + }) + + t.Run("remove old packets", func(t *testing.T) { + var m packetArrivalTimeMap + m.AddPacket(37, 10) + m.AddPacket(39, 11) + m.AddPacket(40, 12) + m.AddPacket(41, 13) + + m.RemoveOldPackets(42, 11) + + assert.Equal(t, int64(40), m.BeginSequenceNumber()) + assert.Equal(t, int64(42), m.EndSequenceNumber()) + + assert.False(t, m.HasReceived(39)) + assert.True(t, m.HasReceived(40)) + assert.True(t, m.HasReceived(41)) + assert.False(t, m.HasReceived(42)) + }) + + t.Run("shrinks buffer when necessary", func(t *testing.T) { + var m packetArrivalTimeMap + var largeSeqNum int64 = 100 + maxNumberOfPackets - 1 + m.AddPacket(100, 10) + m.AddPacket(largeSeqNum, 11) + + m.EraseTo(largeSeqNum - 1) + + assert.Equal(t, largeSeqNum-1, m.BeginSequenceNumber()) + assert.Equal(t, largeSeqNum+1, m.EndSequenceNumber()) + + assert.Equal(t, minCapacity, m.capacity()) + }) + + t.Run("find next at or after with invalid sequence", func(t *testing.T) { + var m packetArrivalTimeMap + m.AddPacket(100, 10) + + _, _, ok := m.FindNextAtOrAfter(101) + assert.False(t, ok) + }) +} diff --git a/pkg/twcc/sender_interceptor.go b/pkg/twcc/sender_interceptor.go index 8706e451..d7906fc6 100644 --- a/pkg/twcc/sender_interceptor.go +++ b/pkg/twcc/sender_interceptor.go @@ -196,7 +196,7 @@ func (s *SenderInterceptor) loop(w interceptor.RTCPWriter) { case <-ticker.C: // build and send twcc pkts := s.recorder.BuildFeedbackPacket() - if pkts == nil { + if len(pkts) == 0 { continue } if _, err := w.Write(pkts, nil); err != nil { diff --git a/pkg/twcc/twcc.go b/pkg/twcc/twcc.go index 235f1f11..da938cc8 100644 --- a/pkg/twcc/twcc.go +++ b/pkg/twcc/twcc.go @@ -7,22 +7,26 @@ package twcc import ( "math" + "github.com/pion/interceptor/internal/sequencenumber" "github.com/pion/rtcp" ) -type pktInfo struct { - sequenceNumber uint32 - arrivalTime int64 -} +const ( + packetWindowMicroseconds = 500_000 + maxMissingSequenceNumbers = 0x7FFE +) // Recorder records incoming RTP packets and their delays and creates // transport wide congestion control feedback reports as specified in // https://datatracker.ietf.org/doc/html/draft-holmer-rmcat-transport-wide-cc-extensions-01 type Recorder struct { - receivedPackets []pktInfo + arrivalTimeMap packetArrivalTimeMap - cycles uint32 - lastSequenceNumber uint16 + sequenceUnwrapper sequencenumber.Unwrapper + + // startSequenceNumber is the first sequence number that will be included in the the + // next feedback packet. + startSequenceNumber *int64 senderSSRC uint32 mediaSSRC uint32 @@ -33,68 +37,118 @@ type Recorder struct { // feedback packets. func NewRecorder(senderSSRC uint32) *Recorder { return &Recorder{ - receivedPackets: []pktInfo{}, - senderSSRC: senderSSRC, + senderSSRC: senderSSRC, } } // Record marks a packet with mediaSSRC and a transport wide sequence number sequenceNumber as received at arrivalTime. func (r *Recorder) Record(mediaSSRC uint32, sequenceNumber uint16, arrivalTime int64) { r.mediaSSRC = mediaSSRC - if sequenceNumber < 0x0fff && (r.lastSequenceNumber&0xffff) > 0xf000 { - r.cycles += 1 << 16 + + // "Unwrap" the sequence number to get a monotonically increasing sequence number that + // won't wrap around after math.MaxUint16. + unwrappedSN := r.sequenceUnwrapper.Unwrap(sequenceNumber) + r.maybeCullOldPackets(unwrappedSN, arrivalTime) + if r.startSequenceNumber == nil || unwrappedSN < *r.startSequenceNumber { + r.startSequenceNumber = &unwrappedSN } - r.receivedPackets = insertSorted(r.receivedPackets, pktInfo{ - sequenceNumber: r.cycles | uint32(sequenceNumber), - arrivalTime: arrivalTime, - }) - r.lastSequenceNumber = sequenceNumber -} -func insertSorted(list []pktInfo, element pktInfo) []pktInfo { - if len(list) == 0 { - return append(list, element) + // We are only interested in the first time a packet is received. + if r.arrivalTimeMap.HasReceived(unwrappedSN) { + return } - for i := len(list) - 1; i >= 0; i-- { - if list[i].sequenceNumber < element.sequenceNumber { - list = append(list, pktInfo{}) - copy(list[i+2:], list[i+1:]) - list[i+1] = element - return list - } - if list[i].sequenceNumber == element.sequenceNumber { - list[i] = element - return list - } + + r.arrivalTimeMap.AddPacket(unwrappedSN, arrivalTime) + + // Limit the range of sequence numbers to send feedback for. + if *r.startSequenceNumber < r.arrivalTimeMap.BeginSequenceNumber() { + sn := r.arrivalTimeMap.BeginSequenceNumber() + r.startSequenceNumber = &sn + } +} + +func (r *Recorder) maybeCullOldPackets(sequenceNumber int64, arrivalTime int64) { + if r.startSequenceNumber != nil && *r.startSequenceNumber >= r.arrivalTimeMap.EndSequenceNumber() && arrivalTime >= packetWindowMicroseconds { + r.arrivalTimeMap.RemoveOldPackets(sequenceNumber, arrivalTime-packetWindowMicroseconds) } - // element.sequenceNumber is between 0 and first ever received sequenceNumber - return append([]pktInfo{element}, list...) } // BuildFeedbackPacket creates a new RTCP packet containing a TWCC feedback report. func (r *Recorder) BuildFeedbackPacket() []rtcp.Packet { - if len(r.receivedPackets) < 2 { + if r.startSequenceNumber == nil { return nil } - feedback := newFeedback(r.senderSSRC, r.mediaSSRC, r.fbPktCnt) - r.fbPktCnt++ - feedback.setBase(uint16(r.receivedPackets[0].sequenceNumber&0xffff), r.receivedPackets[0].arrivalTime) + endSN := r.arrivalTimeMap.EndSequenceNumber() + var feedbacks []rtcp.Packet + for *r.startSequenceNumber < endSN { + feedback := r.maybeBuildFeedbackPacket(*r.startSequenceNumber, endSN) + if feedback == nil { + break + } + feedbacks = append(feedbacks, feedback.getRTCP()) + + // NOTE: we don't erase packets from the history in case they need to be resent + // after a reordering. They will be removed instead in Record when they get too + // old. + } + return feedbacks +} + +// maybeBuildFeedbackPacket builds a feedback packet starting from startSN (inclusive) until +// endSN (exclusive). +func (r *Recorder) maybeBuildFeedbackPacket(beginSeqNumInclusive, endSeqNumExclusive int64) *feedback { + // NOTE: The logic of this method is inspired by the implementation in Chrome. + // See https://source.chromium.org/chromium/chromium/src/+/refs/heads/main:third_party/webrtc/modules/remote_bitrate_estimator/remote_estimator_proxy.cc;l=276;drc=b5cd13bb6d5d157a5fbe3628b2dd1c1e106203c6 + startSN, endSN := r.arrivalTimeMap.Clamp(beginSeqNumInclusive), r.arrivalTimeMap.Clamp(endSeqNumExclusive) + + // Create feedback on demand, as we don't yet know if there are packets in the range that have been + // received. + var fb *feedback + + nextSequenceNumber := beginSeqNumInclusive - var pkts []rtcp.Packet - for _, pkt := range r.receivedPackets { - ok := feedback.addReceived(uint16(pkt.sequenceNumber&0xffff), pkt.arrivalTime) - if !ok { - pkts = append(pkts, feedback.getRTCP()) - feedback = newFeedback(r.senderSSRC, r.mediaSSRC, r.fbPktCnt) + for seq := startSN; seq < endSN; seq++ { + foundSeq, arrivalTime, ok := r.arrivalTimeMap.FindNextAtOrAfter(seq) + seq = foundSeq + if !ok || seq >= endSN { + break + } + + if fb == nil { + fb = newFeedback(r.senderSSRC, r.mediaSSRC, r.fbPktCnt) r.fbPktCnt++ - feedback.addReceived(uint16(pkt.sequenceNumber&0xffff), pkt.arrivalTime) + + // It should be possible to add seq to this new packet. + // If the difference between seq and beginSeqNumInclusive is too large, discard + // reporting too old missing packets. + baseSequenceNumber := max64(beginSeqNumInclusive, seq-maxMissingSequenceNumbers) + + // baseSequenceNumber is the expected first sequence number. This is known, + // but we may not have actually received it, so the base time should be the time + // of the first received packet in the feedback. + fb.setBase(uint16(baseSequenceNumber), arrivalTime) + + if !fb.addReceived(uint16(seq), arrivalTime) { + // Could not add a single received packet to the feedback. + // This is unexpected to actually occur, but if it does, we'll + // try again after skipping any missing packets. + // NOTE: It's fine that we already incremented fbPktCnt, as in essence + // we did actually "skip" a feedback (and this matches Chrome's behavior). + r.startSequenceNumber = &seq + return nil + } + } else if !fb.addReceived(uint16(seq), arrivalTime) { + // Could not add timestamp. Packet may be full. Return + // and try again with a fresh packet. + break } + + nextSequenceNumber = seq + 1 } - r.receivedPackets = []pktInfo{} - pkts = append(pkts, feedback.getRTCP()) - return pkts + r.startSequenceNumber = &nextSequenceNumber + return fb } type feedback struct { @@ -268,9 +322,30 @@ func (c *chunk) reset() { c.hasDifferentTypes = false } +func max(a, b int) int { + if a > b { + return a + } + return b +} + func min(a, b int) int { if a < b { return a } return b } + +func max64(a, b int64) int64 { + if a > b { + return a + } + return b +} + +func min64(a, b int64) int64 { + if a < b { + return a + } + return b +} diff --git a/pkg/twcc/twcc_test.go b/pkg/twcc/twcc_test.go index 871d9908..cf9d97b0 100644 --- a/pkg/twcc/twcc_test.go +++ b/pkg/twcc/twcc_test.go @@ -4,11 +4,11 @@ package twcc import ( - "fmt" "testing" "github.com/pion/rtcp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func rtcpToTwcc(t *testing.T, in []rtcp.Packet) []*rtcp.TransportLayerCC { @@ -373,8 +373,12 @@ func increaseTime(arrivalTime *int64, increaseAmount int64) int64 { func marshalAll(t *testing.T, pkts []rtcp.Packet) { for _, pkt := range pkts { - _, err := pkt.Marshal() + marshaled, err := pkt.Marshal() assert.NoError(t, err) + + // Chrome expects feedback packets to always be 18 bytes or more. + // https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/modules/rtp_rtcp/source/rtcp_packet/transport_feedback.cc;l=423?q=transport_feedback.cc&ss=chromium%2Fchromium%2Fsrc + assert.GreaterOrEqual(t, len(marshaled), 18) } } @@ -447,7 +451,7 @@ func TestBuildFeedbackPacket_Rolling(t *testing.T) { rtcpPackets := r.BuildFeedbackPacket() assert.Equal(t, 1, len(rtcpPackets)) - addRun(t, r, []uint16{4, 8, 9, 10}, []int64{ + addRun(t, r, []uint16{0, 4, 5, 6}, []int64{ increaseTime(&arrivalTime, rtcp.TypeTCCDeltaScaleFactor), increaseTime(&arrivalTime, rtcp.TypeTCCDeltaScaleFactor), increaseTime(&arrivalTime, rtcp.TypeTCCDeltaScaleFactor), @@ -466,7 +470,7 @@ func TestBuildFeedbackPacket_Rolling(t *testing.T) { }, SenderSSRC: 5000, MediaSSRC: 5000, - BaseSequenceNumber: 4, + BaseSequenceNumber: 0, ReferenceTime: 1, FbPktCount: 1, PacketStatusCount: 7, @@ -504,42 +508,78 @@ func TestBuildFeedbackPacket_MinInput(t *testing.T) { }) pkts := r.BuildFeedbackPacket() - assert.Nil(t, pkts) - - addRun(t, r, []uint16{1}, []int64{ - increaseTime(&arrivalTime, rtcp.TypeTCCDeltaScaleFactor), - }) - - pkts = r.BuildFeedbackPacket() assert.Equal(t, 1, len(pkts)) assert.Equal(t, &rtcp.TransportLayerCC{ Header: rtcp.Header{ - Count: rtcp.FormatTCC, - Type: rtcp.TypeTransportSpecificFeedback, - Length: 5, + Count: rtcp.FormatTCC, + Type: rtcp.TypeTransportSpecificFeedback, + Length: 5, + Padding: true, }, SenderSSRC: 5000, MediaSSRC: 5000, BaseSequenceNumber: 0, ReferenceTime: 1, FbPktCount: 0, - PacketStatusCount: 2, + PacketStatusCount: 1, PacketChunks: []rtcp.PacketStatusChunk{ &rtcp.RunLengthChunk{ PacketStatusSymbol: 1, Type: rtcp.TypeTCCRunLengthChunk, - RunLength: 2, + RunLength: 1, }, }, RecvDeltas: []*rtcp.RecvDelta{ {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: 0}, - {Type: rtcp.TypeTCCPacketReceivedSmallDelta, Delta: rtcp.TypeTCCDeltaScaleFactor}, }, }, rtcpToTwcc(t, pkts)[0]) marshalAll(t, pkts) } +func TestBuildFeedbackPacket_MissingPacketsBetweenFeedbacks(t *testing.T) { + r := NewRecorder(5000) + + // Create a run of received packets. + arrivalTime := int64(scaleFactorReferenceTime) + addRun(t, r, []uint16{0, 1, 2, 3}, []int64{ + scaleFactorReferenceTime, + increaseTime(&arrivalTime, rtcp.TypeTCCDeltaScaleFactor), + increaseTime(&arrivalTime, rtcp.TypeTCCDeltaScaleFactor), + increaseTime(&arrivalTime, rtcp.TypeTCCDeltaScaleFactor), + }) + rtcpPackets := r.BuildFeedbackPacket() + assert.Equal(t, 1, len(rtcpPackets)) + + // Now create another run of received packets, but with a gap. + addRun(t, r, []uint16{7, 8, 9}, []int64{ + increaseTime(&arrivalTime, rtcp.TypeTCCDeltaScaleFactor*256), + increaseTime(&arrivalTime, rtcp.TypeTCCDeltaScaleFactor), + increaseTime(&arrivalTime, rtcp.TypeTCCDeltaScaleFactor), + }) + rtcpPackets = r.BuildFeedbackPacket() + require.Equal(t, 1, len(rtcpPackets)) + twccPacket := rtcpToTwcc(t, rtcpPackets)[0] + assert.Equal(t, uint16(4), twccPacket.BaseSequenceNumber, "Base sequence should be one after the end of the previous feedback") + assert.Equal(t, uint16(6), twccPacket.PacketStatusCount, "Feedback should include status for both the lost and received packets") + expectedPacketChunks := []rtcp.PacketStatusChunk{ + &rtcp.StatusVectorChunk{ + Type: rtcp.TypeTCCRunLengthChunk, + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + }, + }, + } + assert.Equal(t, expectedPacketChunks, twccPacket.PacketChunks) + marshalAll(t, rtcpPackets) +} + func TestBuildFeedbackPacketCount(t *testing.T) { r := NewRecorder(5000) @@ -801,142 +841,3 @@ func TestReorderedPackets(t *testing.T) { }, unmarshalled) marshalAll(t, rtcpPackets) } - -func TestInsertSorted(t *testing.T) { - cases := []struct { - l []pktInfo - e pktInfo - expected []pktInfo - }{ - { - l: []pktInfo{}, - e: pktInfo{}, - expected: []pktInfo{{ - sequenceNumber: 0, - arrivalTime: 0, - }}, - }, - { - l: []pktInfo{ - { - sequenceNumber: 0, - arrivalTime: 0, - }, - { - sequenceNumber: 1, - arrivalTime: 0, - }, - }, - e: pktInfo{ - sequenceNumber: 2, - arrivalTime: 0, - }, - expected: []pktInfo{ - { - sequenceNumber: 0, - arrivalTime: 0, - }, - { - sequenceNumber: 1, - arrivalTime: 0, - }, - { - sequenceNumber: 2, - arrivalTime: 0, - }, - }, - }, - { - l: []pktInfo{ - { - sequenceNumber: 0, - arrivalTime: 0, - }, - { - sequenceNumber: 2, - arrivalTime: 0, - }, - }, - e: pktInfo{ - sequenceNumber: 1, - arrivalTime: 0, - }, - expected: []pktInfo{ - { - sequenceNumber: 0, - arrivalTime: 0, - }, - { - sequenceNumber: 1, - arrivalTime: 0, - }, - { - sequenceNumber: 2, - arrivalTime: 0, - }, - }, - }, - { - l: []pktInfo{ - { - sequenceNumber: 0, - arrivalTime: 0, - }, - { - sequenceNumber: 1, - arrivalTime: 0, - }, - { - sequenceNumber: 2, - arrivalTime: 0, - }, - }, - e: pktInfo{ - sequenceNumber: 1, - arrivalTime: 0, - }, - expected: []pktInfo{ - { - sequenceNumber: 0, - arrivalTime: 0, - }, - { - sequenceNumber: 1, - arrivalTime: 0, - }, - { - sequenceNumber: 2, - arrivalTime: 0, - }, - }, - }, - { - l: []pktInfo{ - { - sequenceNumber: 10, - arrivalTime: 0, - }, - }, - e: pktInfo{ - sequenceNumber: 9, - arrivalTime: 0, - }, - expected: []pktInfo{ - { - sequenceNumber: 9, - arrivalTime: 0, - }, - { - sequenceNumber: 10, - arrivalTime: 0, - }, - }, - }, - } - for i, c := range cases { - c := c - t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { - assert.Equal(t, c.expected, insertSorted(c.l, c.e)) - }) - } -}