diff --git a/association.go b/association.go index b632fbde..aba18ea8 100644 --- a/association.go +++ b/association.go @@ -137,6 +137,8 @@ func getAssociationStateString(a uint32) string { // // Note: No "CLOSED" state is illustrated since if a // association is "CLOSED" its TCB SHOULD be removed. +// Note: By nature of an Association being constructed with one net.Conn, +// it is not a multi-home supporting implementation of SCTP. type Association struct { bytesReceived uint64 bytesSent uint64 @@ -305,11 +307,17 @@ func createAssociation(config Config) *Association { tsn := globalMathRandomGenerator.Uint32() a := &Association{ - netConn: config.NetConn, - maxReceiveBufferSize: maxReceiveBufferSize, - maxMessageSize: maxMessageSize, + netConn: config.NetConn, + maxReceiveBufferSize: maxReceiveBufferSize, + maxMessageSize: maxMessageSize, + + // These two max values have us not need to follow + // 5.1.1 where this peer may be incapable of supporting + // the requested amount of outbound streams from the other + // peer. myMaxNumOutboundStreams: math.MaxUint16, myMaxNumInboundStreams: math.MaxUint16, + payloadQueue: newPayloadQueue(), inflightQueue: newPayloadQueue(), pendingQueue: newPendingQueue(), @@ -480,8 +488,11 @@ func (a *Association) Close() error { <-a.readLoopCloseCh a.log.Debugf("[%s] association closed", a.name) + a.log.Debugf("[%s] stats nPackets (in) : %d", a.name, a.stats.getNumPacketsReceived()) + a.log.Debugf("[%s] stats nPackets (out) : %d", a.name, a.stats.getNumPacketsSent()) a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs()) - a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKs()) + a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKsReceived()) + a.log.Debugf("[%s] stats nSACKs (out) : %d\n", a.name, a.stats.getNumSACKsSent()) a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts()) a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts()) a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans()) @@ -551,7 +562,7 @@ func (a *Association) readLoop() { a.log.Debugf("[%s] association closed", a.name) a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs()) - a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKs()) + a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKsReceived()) a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts()) a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts()) a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans()) @@ -600,6 +611,7 @@ loop: break loop } atomic.AddUint64(&a.bytesSent, uint64(len(raw))) + a.stats.incPacketsSent() } if !ok { @@ -674,7 +686,7 @@ func (a *Association) handleInbound(raw []byte) error { return nil } - a.handleChunkStart() + a.handleChunksStart() for _, c := range p.chunks { if err := a.handleChunk(p, c); err != nil { @@ -682,7 +694,7 @@ func (a *Association) handleInbound(raw []byte) error { } } - a.handleChunkEnd() + a.handleChunksEnd() return nil } @@ -829,6 +841,7 @@ func (a *Association) gatherOutboundSackPackets(rawPackets [][]byte) [][]byte { if a.ackState == ackStateImmediate { a.ackState = ackStateIdle sack := a.createSelectiveAckChunk() + a.stats.incSACKsSent() a.log.Debugf("[%s] sending SACK: %s", a.name, sack) raw, err := a.marshalPacket(a.createPacket([]chunk{sack})) if err != nil { @@ -1122,7 +1135,10 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { return nil, fmt.Errorf("%w: %s", ErrHandleInitState, getAssociationStateString(state)) } - // Should we be setting any of these permanently until we've ACKed further? + // NOTE: Setting these prior to a reception of a COOKIE ECHO chunk containing + // our cookie is not compliant with https://www.rfc-editor.org/rfc/rfc9260#section-5.1-2.2.3. + // It makes us more vulnerable to resource attacks, albeit minimally so. + // https://www.rfc-editor.org/rfc/rfc9260#sec_handle_stream_parameters a.myMaxNumInboundStreams = min16(i.numInboundStreams, a.myMaxNumInboundStreams) a.myMaxNumOutboundStreams = min16(i.numOutboundStreams, a.myMaxNumOutboundStreams) a.peerVerificationTag = i.initiateTag @@ -1168,6 +1184,8 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { if a.myCookie == nil { var err error + // NOTE: This generation process is not compliant with + // 5.1.3. Generating State Cookie (https://www.rfc-editor.org/rfc/rfc4960#section-5.1.3) if a.myCookie, err = newRandomStateCookie(); err != nil { return nil, err } @@ -1307,6 +1325,8 @@ func (a *Association) handleCookieEcho(c *chunkCookieEcho) []*packet { return nil } + // RFC wise, these do not seem to belong here, but removing them + // causes TestCookieEchoRetransmission to break a.t1Init.stop() a.storedInit = nil @@ -1314,6 +1334,7 @@ func (a *Association) handleCookieEcho(c *chunkCookieEcho) []*packet { a.storedCookieEcho = nil a.setState(established) + // Note: This is a future place where the user could be notified (COMMUNICATION UP) a.handshakeCompletedCh <- nil } @@ -1342,6 +1363,7 @@ func (a *Association) handleCookieAck() { a.storedCookieEcho = nil a.setState(established) + // Note: This is a future place where the user could be notified (COMMUNICATION UP) a.handshakeCompletedCh <- nil } @@ -1355,9 +1377,9 @@ func (a *Association) handleData(d *chunkPayloadData) []*packet { if canPush { s := a.getOrCreateStream(d.streamIdentifier, true, PayloadTypeUnknown) if s == nil { - // silentely discard the data. (sender will retry on T3-rtx timeout) + // silently discard the data. (sender will retry on T3-rtx timeout) // see pion/sctp#30 - a.log.Debugf("discard %d", d.streamSequenceNumber) + a.log.Debugf("[%s] discard %d", a.name, d.streamSequenceNumber) return nil } @@ -1722,7 +1744,7 @@ func (a *Association) handleSack(d *chunkSelectiveAck) error { return nil } - a.stats.incSACKs() + a.stats.incSACKsReceived() if sna32GT(a.cumulativeTSNAckPoint, d.cumulativeTSNAck) { // RFC 4960 sec 6.2.1. Processing a Received SACK @@ -2381,15 +2403,17 @@ func pack(p *packet) []*packet { return []*packet{p} } -func (a *Association) handleChunkStart() { +func (a *Association) handleChunksStart() { a.lock.Lock() defer a.lock.Unlock() + a.stats.incPacketsReceived() + a.delayedAckTriggered = false a.immediateAckTriggered = false } -func (a *Association) handleChunkEnd() { +func (a *Association) handleChunksEnd() { a.lock.Lock() defer a.lock.Unlock() @@ -2412,13 +2436,18 @@ func (a *Association) handleChunk(p *packet, c chunk) error { var err error if _, err = c.check(); err != nil { - a.log.Errorf("[ %s ] failed validating chunk: %s ", a.name, err) + a.log.Errorf("[%s] failed validating chunk: %s ", a.name, err) return nil } isAbort := false switch c := c.(type) { + // Note: We do not do the following for chunkInit, chunkInitAck, and chunkCookieEcho: + // If an endpoint receives an INIT, INIT ACK, or COOKIE ECHO chunk but decides not to establish the + // new association due to missing mandatory parameters in the received INIT or INIT ACK chunk, invalid + // parameter values, or lack of local resources, it SHOULD respond with an ABORT chunk. + case *chunkInit: packets, err = a.handleInit(p, c) @@ -2436,6 +2465,7 @@ func (a *Association) handleChunk(p *packet, c chunk) error { } a.log.Debugf("[%s] Error chunk, with following errors: %s", a.name, errStr) + // Note: chunkHeartbeatAck not handled? case *chunkHeartbeat: packets = a.handleHeartbeat(c) diff --git a/association_stats.go b/association_stats.go index 60883c47..0e4e581b 100644 --- a/association_stats.go +++ b/association_stats.go @@ -8,11 +8,30 @@ import ( ) type associationStats struct { - nDATAs uint64 - nSACKs uint64 - nT3Timeouts uint64 - nAckTimeouts uint64 - nFastRetrans uint64 + nPacketsReceived uint64 + nPacketsSent uint64 + nDATAs uint64 + nSACKsReceived uint64 + nSACKsSent uint64 + nT3Timeouts uint64 + nAckTimeouts uint64 + nFastRetrans uint64 +} + +func (s *associationStats) incPacketsReceived() { + atomic.AddUint64(&s.nPacketsReceived, 1) +} + +func (s *associationStats) getNumPacketsReceived() uint64 { + return atomic.LoadUint64(&s.nPacketsReceived) +} + +func (s *associationStats) incPacketsSent() { + atomic.AddUint64(&s.nPacketsSent, 1) +} + +func (s *associationStats) getNumPacketsSent() uint64 { + return atomic.LoadUint64(&s.nPacketsSent) } func (s *associationStats) incDATAs() { @@ -23,12 +42,20 @@ func (s *associationStats) getNumDATAs() uint64 { return atomic.LoadUint64(&s.nDATAs) } -func (s *associationStats) incSACKs() { - atomic.AddUint64(&s.nSACKs, 1) +func (s *associationStats) incSACKsReceived() { + atomic.AddUint64(&s.nSACKsReceived, 1) +} + +func (s *associationStats) getNumSACKsReceived() uint64 { + return atomic.LoadUint64(&s.nSACKsReceived) +} + +func (s *associationStats) incSACKsSent() { + atomic.AddUint64(&s.nSACKsSent, 1) } -func (s *associationStats) getNumSACKs() uint64 { - return atomic.LoadUint64(&s.nSACKs) +func (s *associationStats) getNumSACKsSent() uint64 { + return atomic.LoadUint64(&s.nSACKsSent) } func (s *associationStats) incT3Timeouts() { @@ -56,8 +83,11 @@ func (s *associationStats) getNumFastRetrans() uint64 { } func (s *associationStats) reset() { + atomic.StoreUint64(&s.nPacketsReceived, 0) + atomic.StoreUint64(&s.nPacketsSent, 0) atomic.StoreUint64(&s.nDATAs, 0) - atomic.StoreUint64(&s.nSACKs, 0) + atomic.StoreUint64(&s.nSACKsReceived, 0) + atomic.StoreUint64(&s.nSACKsSent, 0) atomic.StoreUint64(&s.nT3Timeouts, 0) atomic.StoreUint64(&s.nAckTimeouts, 0) atomic.StoreUint64(&s.nFastRetrans, 0) diff --git a/association_test.go b/association_test.go index a7636e1b..ed53ed71 100644 --- a/association_test.go +++ b/association_test.go @@ -413,8 +413,8 @@ func establishSessionPair(br *test.Bridge, a0, a1 *Association, si uint16) (*Str } func TestAssocReliable(t *testing.T) { - // sbuf - small enogh not to be fragmented - // large enobh not to be bundled + // sbuf - small enough not to be fragmented + // large enough not to be bundled sbuf := make([]byte, 1000) for i := 0; i < len(sbuf); i++ { sbuf[i] = byte(i & 0xff) @@ -422,8 +422,8 @@ func TestAssocReliable(t *testing.T) { rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(sbuf), func(i, j int) { sbuf[i], sbuf[j] = sbuf[j], sbuf[i] }) - // sbufL - large enogh to be fragmented into two chunks and each chunks are - // large enobh not to be bundled + // sbufL - large enough to be fragmented into two chunks and each chunks are + // large enough not to be bundled sbufL := make([]byte, 2000) for i := 0; i < len(sbufL); i++ { sbufL[i] = byte(i & 0xff) @@ -823,8 +823,8 @@ func TestAssocReliable(t *testing.T) { func TestAssocUnreliable(t *testing.T) { // sbuf1, sbuf2: - // large enogh to be fragmented into two chunks and each chunks are - // large enobh not to be bundled + // large enough to be fragmented into two chunks and each chunks are + // large enough not to be bundled sbuf1 := make([]byte, 2000) sbuf2 := make([]byte, 2000) for i := 0; i < len(sbuf1); i++ { @@ -838,8 +838,8 @@ func TestAssocUnreliable(t *testing.T) { rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(sbuf2), func(i, j int) { sbuf2[i], sbuf2[j] = sbuf2[j], sbuf2[i] }) - // sbuf - small enogh not to be fragmented - // large enobh not to be bundled + // sbuf - small enough not to be fragmented + // large enough not to be bundled sbuf := make([]byte, 1000) for i := 0; i < len(sbuf); i++ { sbuf[i] = byte(i & 0xff) @@ -1754,7 +1754,7 @@ func TestAssocT3RtxTimer(t *testing.T) { } func TestAssocCongestionControl(t *testing.T) { - // sbuf - large enobh not to be bundled + // sbuf - large enough not to be bundled sbuf := make([]byte, 1000) for i := 0; i < len(sbuf); i++ { sbuf[i] = byte(i & 0xcc) @@ -1825,7 +1825,7 @@ func TestAssocCongestionControl(t *testing.T) { assert.False(t, inFastRecovery, "should not be in fast-recovery") t.Logf("nDATAs : %d\n", a1.stats.getNumDATAs()) - t.Logf("nSACKs : %d\n", a0.stats.getNumSACKs()) + t.Logf("nSACKs : %d\n", a0.stats.getNumSACKsReceived()) t.Logf("nAckTimeouts: %d\n", a1.stats.getNumAckTimeouts()) t.Logf("nFastRetrans: %d\n", a0.stats.getNumFastRetrans()) @@ -1909,11 +1909,11 @@ func TestAssocCongestionControl(t *testing.T) { assert.Equal(t, 0, s1.getNumBytesInReassemblyQueue(), "reassembly queue should be empty") t.Logf("nDATAs : %d\n", a1.stats.getNumDATAs()) - t.Logf("nSACKs : %d\n", a0.stats.getNumSACKs()) + t.Logf("nSACKs : %d\n", a0.stats.getNumSACKsReceived()) t.Logf("nT3Timeouts : %d\n", a0.stats.getNumT3Timeouts()) assert.Equal(t, uint64(nPacketsToSend), a1.stats.getNumDATAs(), "packet count mismatch") - assert.True(t, a0.stats.getNumSACKs() <= nPacketsToSend/2, "too many sacks") + assert.True(t, a0.stats.getNumSACKsReceived() <= nPacketsToSend/2, "too many sacks") assert.Equal(t, uint64(0), a0.stats.getNumT3Timeouts(), "should be no retransmit") closeAssociationPair(br, a0, a1) @@ -2004,7 +2004,7 @@ func TestAssocCongestionControl(t *testing.T) { assert.Equal(t, 0, s1.getNumBytesInReassemblyQueue(), "reassembly queue should be empty") t.Logf("nDATAs : %d\n", a1.stats.getNumDATAs()) - t.Logf("nSACKs : %d\n", a0.stats.getNumSACKs()) + t.Logf("nSACKs : %d\n", a0.stats.getNumSACKsReceived()) t.Logf("nAckTimeouts: %d\n", a1.stats.getNumAckTimeouts()) closeAssociationPair(br, a0, a1) @@ -2083,11 +2083,11 @@ func TestAssocDelayedAck(t *testing.T) { assert.Equal(t, 0, s1.getNumBytesInReassemblyQueue(), "reassembly queue should be empty") t.Logf("nDATAs : %d\n", a1.stats.getNumDATAs()) - t.Logf("nSACKs : %d\n", a0.stats.getNumSACKs()) + t.Logf("nSACKs : %d\n", a0.stats.getNumSACKsReceived()) t.Logf("nAckTimeouts: %d\n", a1.stats.getNumAckTimeouts()) assert.Equal(t, uint64(1), a1.stats.getNumDATAs(), "DATA chunk count mismatch") - assert.Equal(t, a0.stats.getNumSACKs(), a1.stats.getNumDATAs(), "sack count should be equal to the number of data chunks") + assert.Equal(t, a0.stats.getNumSACKsReceived(), a1.stats.getNumDATAs(), "sack count should be equal to the number of data chunks") assert.Equal(t, uint64(1), a1.stats.getNumAckTimeouts(), "ackTimeout count mismatch") assert.Equal(t, uint64(0), a0.stats.getNumT3Timeouts(), "should be no retransmit") diff --git a/go.mod b/go.mod index 30f238aa..e59082c0 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ require ( github.com/pion/logging v0.2.2 github.com/pion/randutil v0.1.0 github.com/pion/transport/v3 v3.0.1 - github.com/stretchr/testify v1.8.4 + github.com/stretchr/testify v1.9.0 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect ) diff --git a/go.sum b/go.sum index 3de44250..dec487ee 100644 --- a/go.sum +++ b/go.sum @@ -17,10 +17,12 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= diff --git a/rtx_timer.go b/rtx_timer.go index 823186c3..42848abc 100644 --- a/rtx_timer.go +++ b/rtx_timer.go @@ -10,14 +10,28 @@ import ( ) const ( - rtoInitial float64 = 1.0 * 1000 // msec - rtoMin float64 = 1.0 * 1000 // msec - defaultRTOMax float64 = 60.0 * 1000 // msec - rtoAlpha float64 = 0.125 - rtoBeta float64 = 0.25 - maxInitRetrans uint = 8 - pathMaxRetrans uint = 5 - noMaxRetrans uint = 0 + // RTO.Initial in msec + rtoInitial float64 = 1.0 * 1000 + + // RTO.Min in msec + rtoMin float64 = 1.0 * 1000 + + // RTO.Max in msec + defaultRTOMax float64 = 60.0 * 1000 + + // RTO.Alpha + rtoAlpha float64 = 0.125 + + // RTO.Beta + rtoBeta float64 = 0.25 + + // Max.Init.Retransmits: + maxInitRetrans uint = 8 + + // Path.Max.Retrans + pathMaxRetrans uint = 5 + + noMaxRetrans uint = 0 ) // rtoManager manages Rtx timeout values.