Skip to content

Commit

Permalink
Simplify moving average loss controller
Browse files Browse the repository at this point in the history
  • Loading branch information
mengelbart committed Nov 19, 2021
1 parent 62459b1 commit 221d52b
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 97 deletions.
34 changes: 17 additions & 17 deletions pkg/cc/feedback_adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestFeedbackAdapterTWCC(t *testing.T) {
for i := uint16(0); i < 22; i++ {
pkt := getPacketWithTransportCCExt(t, i)
headers = append(headers, pkt.Header)
assert.NoError(t, adapter.OnSent(t0, pkt, interceptor.Attributes{twccExtension: hdrExtID}))
assert.NoError(t, adapter.OnSent(t0, &pkt.Header, interceptor.Attributes{twccExtension: hdrExtID}))
}
results, err := adapter.OnIncomingTransportCC(&rtcp.TransportLayerCC{
Header: rtcp.Header{},
Expand Down Expand Up @@ -122,7 +122,7 @@ func TestFeedbackAdapterTWCC(t *testing.T) {
assert.Contains(t, results, types.PacketResult{
SentPacket: types.SentPacket{
SendTime: t0,
Header: headers[0],
Header: &headers[0],
},
ReceiveTime: t0.Add(time.Millisecond),
Received: true,
Expand All @@ -131,7 +131,7 @@ func TestFeedbackAdapterTWCC(t *testing.T) {
assert.Contains(t, results, types.PacketResult{
SentPacket: types.SentPacket{
SendTime: t0,
Header: headers[1],
Header: &headers[1],
},
ReceiveTime: t0.Add(101 * time.Millisecond),
Received: true,
Expand All @@ -141,7 +141,7 @@ func TestFeedbackAdapterTWCC(t *testing.T) {
assert.Contains(t, results, types.PacketResult{
SentPacket: types.SentPacket{
SendTime: t0,
Header: headers[i],
Header: &headers[i],
},
ReceiveTime: time.Time{},
Received: false,
Expand All @@ -151,7 +151,7 @@ func TestFeedbackAdapterTWCC(t *testing.T) {
assert.Contains(t, results, types.PacketResult{
SentPacket: types.SentPacket{
SendTime: t0,
Header: headers[7],
Header: &headers[7],
},
ReceiveTime: t0.Add(104 * time.Millisecond),
Received: true,
Expand All @@ -161,7 +161,7 @@ func TestFeedbackAdapterTWCC(t *testing.T) {
assert.Contains(t, results, types.PacketResult{
SentPacket: types.SentPacket{
SendTime: t0,
Header: headers[i],
Header: &headers[i],
},
ReceiveTime: time.Time{},
Received: false,
Expand All @@ -171,7 +171,7 @@ func TestFeedbackAdapterTWCC(t *testing.T) {
assert.Contains(t, results, types.PacketResult{
SentPacket: types.SentPacket{
SendTime: t0,
Header: headers[21],
Header: &headers[21],
},
ReceiveTime: t0.Add(105 * time.Millisecond),
Received: true,
Expand Down Expand Up @@ -221,8 +221,8 @@ func TestFeedbackAdapterTWCC(t *testing.T) {
adapter := NewFeedbackAdapter()
pkt65535 := getPacketWithTransportCCExt(t, 65535)
pkt0 := getPacketWithTransportCCExt(t, 0)
assert.NoError(t, adapter.OnSent(t0, pkt65535, interceptor.Attributes{twccExtension: hdrExtID}))
assert.NoError(t, adapter.OnSent(t0, pkt0, interceptor.Attributes{twccExtension: hdrExtID}))
assert.NoError(t, adapter.OnSent(t0, &pkt65535.Header, interceptor.Attributes{twccExtension: hdrExtID}))
assert.NoError(t, adapter.OnSent(t0, &pkt0.Header, interceptor.Attributes{twccExtension: hdrExtID}))

results, err := adapter.OnIncomingTransportCC(&rtcp.TransportLayerCC{
Header: rtcp.Header{},
Expand Down Expand Up @@ -266,15 +266,15 @@ func TestFeedbackAdapterTWCC(t *testing.T) {
assert.Contains(t, results, types.PacketResult{
SentPacket: types.SentPacket{
SendTime: t0,
Header: pkt65535.Header,
Header: &pkt65535.Header,
},
ReceiveTime: t0.Add(1 * time.Millisecond),
Received: true,
})
assert.Contains(t, results, types.PacketResult{
SentPacket: types.SentPacket{
SendTime: t0,
Header: pkt0.Header,
Header: &pkt0.Header,
},
ReceiveTime: t0.Add(2 * time.Millisecond),
Received: true,
Expand All @@ -288,7 +288,7 @@ func TestFeedbackAdapterTWCC(t *testing.T) {
for i := uint16(0); i < 8; i++ {
pkt := getPacketWithTransportCCExt(t, i)
headers = append(headers, pkt.Header)
assert.NoError(t, adapter.OnSent(t0, pkt, interceptor.Attributes{twccExtension: hdrExtID}))
assert.NoError(t, adapter.OnSent(t0, &pkt.Header, interceptor.Attributes{twccExtension: hdrExtID}))
}

results, err := adapter.OnIncomingTransportCC(&rtcp.TransportLayerCC{
Expand Down Expand Up @@ -336,7 +336,7 @@ func TestFeedbackAdapterTWCC(t *testing.T) {
assert.Contains(t, results, types.PacketResult{
SentPacket: types.SentPacket{
SendTime: t0,
Header: headers[i],
Header: &headers[i],
},
ReceiveTime: t0.Add(time.Duration(i+1) * time.Millisecond),
Received: true,
Expand All @@ -349,7 +349,7 @@ func TestFeedbackAdapterTWCC(t *testing.T) {
t0 := time.Time{}
for i := uint16(0); i < 20; i++ {
pkt := getPacketWithTransportCCExt(t, i)
assert.NoError(t, adapter.OnSent(t0, pkt, interceptor.Attributes{twccExtension: hdrExtID}))
assert.NoError(t, adapter.OnSent(t0, &pkt.Header, interceptor.Attributes{twccExtension: hdrExtID}))
}
packets, err := adapter.OnIncomingTransportCC(&rtcp.TransportLayerCC{
Header: rtcp.Header{},
Expand Down Expand Up @@ -390,7 +390,7 @@ func TestFeedbackAdapterTWCC(t *testing.T) {
t0 := time.Time{}
for i := uint16(0); i < 20; i++ {
pkt := getPacketWithTransportCCExt(t, i)
assert.NoError(t, adapter.OnSent(t0, pkt, interceptor.Attributes{twccExtension: hdrExtID}))
assert.NoError(t, adapter.OnSent(t0, &pkt.Header, interceptor.Attributes{twccExtension: hdrExtID}))
}
packets, err := adapter.OnIncomingTransportCC(&rtcp.TransportLayerCC{
Header: rtcp.Header{},
Expand Down Expand Up @@ -447,7 +447,7 @@ func TestFeedbackAdapterTWCC(t *testing.T) {
t0 := time.Time{}
for i := uint16(0); i < 20; i++ {
pkt := getPacketWithTransportCCExt(t, i)
assert.NoError(t, adapter.OnSent(t0, pkt, interceptor.Attributes{twccExtension: hdrExtID}))
assert.NoError(t, adapter.OnSent(t0, &pkt.Header, interceptor.Attributes{twccExtension: hdrExtID}))
}

packets, err := adapter.OnIncomingTransportCC(&rtcp.TransportLayerCC{
Expand Down Expand Up @@ -514,7 +514,7 @@ func TestFeedbackAdapterTWCC(t *testing.T) {
t0 := time.Time{}
for i := uint16(1008); i < 1030; i++ {
pkt := getPacketWithTransportCCExt(t, i)
assert.NoError(t, adapter.OnSent(t0, pkt, interceptor.Attributes{twccExtension: hdrExtID}))
assert.NoError(t, adapter.OnSent(t0, &pkt.Header, interceptor.Attributes{twccExtension: hdrExtID}))
}

assert.NotPanics(t, func() {
Expand Down
6 changes: 3 additions & 3 deletions pkg/cc/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type BandwidthEstimatorFactory func() BandwidthEstimator
// BandwidthEstimator is the interface of a bandwidth estimator
type BandwidthEstimator interface {
OnPacketSent(ts time.Time, sizeInBytes int)
OnFeedback(time.Time, []types.PacketResult)
OnFeedback([]types.PacketResult)
GetBandwidthEstimation() types.DataRate
}

Expand Down Expand Up @@ -220,15 +220,15 @@ func (c *ControllerInterceptor) loop() {
if err != nil {
// TODO(mathis): handle error
}
c.OnFeedback(feedback.ts, packetResult)
c.OnFeedback(packetResult)
c.pacer.SetTargetBitrate(c.GetBandwidthEstimation())

case feedback := <-c.rfc8888FeedbackChan:
packetResult, err := c.OnIncomingRFC8888(feedback.RawPacket)
if err != nil {
// TODO(mathis): handle error
}
c.OnFeedback(feedback.ts, packetResult)
c.OnFeedback(packetResult)
c.pacer.SetTargetBitrate(c.GetBandwidthEstimation())
}
}
Expand Down
93 changes: 18 additions & 75 deletions pkg/gcc/loss_based_bwe.go
Original file line number Diff line number Diff line change
@@ -1,107 +1,50 @@
package gcc

import (
"fmt"
"math"
"time"

"github.com/pion/interceptor/internal/types"
)

type lossBasedBWEConfig struct {
lossWindow time.Duration
maxLossWindow time.Duration
maxAcknowledgedRateWindow time.Duration
}

type lossBasedBandwidthEstimator struct {
config lossBasedBWEConfig

bitrate types.DataRate
averageLoss float64
maxAverageLoss float64
averageLossMax float64
lastLossReport time.Time

maxAcknowledgedRate types.DataRate
lastAcknowledgedRateReport time.Time
bitrate types.DataRate
averageLoss float64
inertia float64
decay float64
}

func newLossBasedBWE() *lossBasedBandwidthEstimator {
return &lossBasedBandwidthEstimator{
config: lossBasedBWEConfig{
lossWindow: 800 * time.Millisecond,
maxLossWindow: 800 * time.Millisecond,
maxAcknowledgedRateWindow: 800 * time.Millisecond,
},
bitrate: 0,

averageLoss: 0,
maxAverageLoss: 0,
averageLossMax: 0,
lastLossReport: time.Time{},

maxAcknowledgedRate: 0,
lastAcknowledgedRateReport: time.Time{},
inertia: 0.5,
decay: 0.5,
bitrate: 0,
averageLoss: 0,
}
}

func (e *lossBasedBandwidthEstimator) getEstimate(wantedRate types.DataRate) types.DataRate {
if e.bitrate == 0 {
if e.bitrate <= 0 {
e.bitrate = wantedRate
}

fmt.Printf("maxAverageLoss=%v\n", e.maxAverageLoss)
// Naive implementation using constants from IETF Draft
// TODO(mathis): Make this more smart and configurable. (Smart here means
// don't decrease too often and such things, see libwebrtc)
if e.maxAverageLoss < 0.02 {
e.bitrate = types.DataRate(1.05 * float64(e.bitrate))
} else if e.maxAverageLoss > 0.1 {
e.bitrate = types.DataRate(float64(e.bitrate) * (1 - 0.5*e.maxAverageLoss))
}

return e.bitrate
}

func (e *lossBasedBandwidthEstimator) updateLossStats(now time.Time, results []types.PacketResult) {
func (e *lossBasedBandwidthEstimator) updateLossStats(results []types.PacketResult) {
packetsLost := 0
for _, p := range results {
if !p.Received {
packetsLost++
}
}
fmt.Printf("lost %v packets\n", packetsLost)

lossRatio := float64(packetsLost) / float64(len(results))
delta := deltaOrDefault(e.lastLossReport, now, time.Second)
e.lastLossReport = now
e.averageLoss = e.inertia*lossRatio + e.decay*(1-e.inertia)*e.averageLoss

e.averageLoss += exponentialUpdate(delta, e.config.lossWindow) * (lossRatio - e.averageLoss)
if e.averageLoss > e.maxAverageLoss {
e.maxAverageLoss = e.averageLoss
} else {
e.maxAverageLoss += exponentialUpdate(delta, e.config.maxLossWindow) * (e.averageLoss - e.maxAverageLoss)
}
}

func (e *lossBasedBandwidthEstimator) updateAcknowledgedBitrate(now time.Time, acknowledgedRate types.DataRate) {
delta := deltaOrDefault(e.lastAcknowledgedRateReport, now, time.Second)
if acknowledgedRate > e.maxAcknowledgedRate {
e.maxAcknowledgedRate = acknowledgedRate
} else {
// TODO(mathis): Double check these type conversions
e.maxAcknowledgedRate -= types.DataRate(exponentialUpdate(delta, e.config.maxAcknowledgedRateWindow) * float64(e.maxAcknowledgedRate-acknowledgedRate))
}
}

func exponentialUpdate(delta, window time.Duration) float64 {
return (1.0 - math.Exp(float64(delta.Milliseconds())/-float64(window.Milliseconds())))
}

func deltaOrDefault(last, now time.Time, defaultVal time.Duration) time.Duration {
if last.IsZero() {
return defaultVal
// Naive implementation using constants from IETF Draft
// TODO(mathis): Make this more smart and configurable. (Smart here means
// don't decrease too often and such things, see libwebrtc)
if e.averageLoss < 0.02 {
e.bitrate = types.DataRate(1.05 * float64(e.bitrate))
} else if e.averageLoss > 0.1 {
e.bitrate = types.DataRate(float64(e.bitrate) * (1 - 0.5*e.averageLoss))
}
return now.Sub(last)
}
4 changes: 2 additions & 2 deletions pkg/gcc/send_side_bwe.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ func (g *SendSideBandwidthEstimator) OnPacketSent(ts time.Time, sizeInBytes int)
}

// OnFeedback updates the GCC statistics from the incoming feedback.
func (g *SendSideBandwidthEstimator) OnFeedback(ts time.Time, feedback []types.PacketResult) {
g.lossBased.updateLossStats(ts, feedback)
func (g *SendSideBandwidthEstimator) OnFeedback(feedback []types.PacketResult) {
g.lossBased.updateLossStats(feedback)
}

// GetBandwidthEstimation returns the estimated bandwidth available
Expand Down

0 comments on commit 221d52b

Please sign in to comment.