diff --git a/gbn/gbn_client.go b/gbn/gbn_client.go index 6cb43b2..b673f98 100644 --- a/gbn/gbn_client.go +++ b/gbn/gbn_client.go @@ -21,12 +21,7 @@ func NewClientConn(ctx context.Context, n uint8, sendFunc sendBytesFunc, math.MaxUint8) } - conn := newGoBackNConn(ctx, sendFunc, receiveFunc, false, n) - - // Apply functional options - for _, o := range opts { - o(conn) - } + conn := newGoBackNConn(ctx, sendFunc, receiveFunc, false, n, opts...) if err := conn.clientHandshake(); err != nil { if err := conn.Close(); err != nil { @@ -41,12 +36,16 @@ func NewClientConn(ctx context.Context, n uint8, sendFunc sendBytesFunc, // clientHandshake initiates the client side GBN handshake. // The handshake sequence from the client side is as follows: -// 1. The client sends SYN to the server along with the N value that the +// +// 1. The client sends SYN to the server along with the N value that the // client wishes to use for the connection. -// 2. The client then waits for the server to respond with SYN. -// 3a. If the client receives SYN from the server then the client sends back -// SYNACK. -// 3b. If the client does not receive SYN from the server within a given +// +// 2. The client then waits for the server to respond with SYN. +// +// 3. 3.1 If the client receives SYN from the server then the client sends +// back a SYNACK. +// +// 3.2 If the client does not receive SYN from the server within a given // timeout, then the client restarts the handshake from step 1. func (g *GoBackNConn) clientHandshake() error { // Spin off the recv function in a goroutine so that we can use diff --git a/gbn/gbn_conn.go b/gbn/gbn_conn.go index 1cd8e36..699f1f8 100644 --- a/gbn/gbn_conn.go +++ b/gbn/gbn_conn.go @@ -114,20 +114,17 @@ type GoBackNConn struct { // newGoBackNConn creates a GoBackNConn instance with all the members which // are common between client and server initialised. func newGoBackNConn(ctx context.Context, sendFunc sendBytesFunc, - recvFunc recvBytesFunc, isServer bool, n uint8) *GoBackNConn { + recvFunc recvBytesFunc, isServer bool, n uint8, + opts ...Option) *GoBackNConn { ctxc, cancel := context.WithCancel(ctx) - return &GoBackNConn{ - n: n, - s: n + 1, + gbn := &GoBackNConn{ resendTimeout: defaultResendTimeout, recvFromStream: recvFunc, sendToStream: sendFunc, - recvDataChan: make(chan *PacketData, n), sendDataChan: make(chan *PacketData), isServer: isServer, - sendQueue: newQueue(n+1, defaultHandshakeTimeout), handshakeTimeout: defaultHandshakeTimeout, recvTimeout: DefaultRecvTimeout, sendTimeout: DefaultSendTimeout, @@ -138,6 +135,14 @@ func newGoBackNConn(ctx context.Context, sendFunc sendBytesFunc, cancel: cancel, quit: make(chan struct{}), } + + for _, o := range opts { + o(gbn) + } + + gbn.setN(n) + + return gbn } // setN sets the current N to use. This _must_ be set before the handshake is @@ -146,7 +151,13 @@ func (g *GoBackNConn) setN(n uint8) { g.n = n g.s = n + 1 g.recvDataChan = make(chan *PacketData, n) - g.sendQueue = newQueue(n+1, defaultHandshakeTimeout) + g.sendQueue = newQueue(&queueConfig{ + s: g.s, + resendTimeout: g.resendTimeout, + sendPkt: func(packet *PacketData) error { + return g.sendPacket(g.ctx, packet) + }, + }) } // SetSendTimeout sets the timeout used in the Send function. @@ -348,6 +359,8 @@ func (g *GoBackNConn) Close() error { // initialisation. g.cancel() + g.sendQueue.stop() + g.wg.Wait() if g.pingTicker != nil { @@ -387,12 +400,7 @@ func (g *GoBackNConn) sendPacket(ctx context.Context, msg Message) error { func (g *GoBackNConn) sendPacketsForever() error { // resendQueue re-sends the current contents of the queue. resendQueue := func() error { - err := g.sendQueue.resend( - g.resendTimeout, g.quit, - func(packet *PacketData) error { - return g.sendPacket(g.ctx, packet) - }, - ) + err := g.sendQueue.resend() // After resending the queue, we reset the resend ticker. // This is so that we don't immediately resend the queue again, @@ -603,9 +611,7 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo } case *PacketACK: - gotValidACK := g.sendQueue.processACK( - m.Seq, g.resendTimeout, - ) + gotValidACK := g.sendQueue.processACK(m.Seq) if gotValidACK { // Send a signal to indicate that new diff --git a/gbn/gbn_server.go b/gbn/gbn_server.go index 488a45d..6ae03d1 100644 --- a/gbn/gbn_server.go +++ b/gbn/gbn_server.go @@ -14,12 +14,7 @@ import ( func NewServerConn(ctx context.Context, sendFunc sendBytesFunc, recvFunc recvBytesFunc, opts ...Option) (*GoBackNConn, error) { - conn := newGoBackNConn(ctx, sendFunc, recvFunc, true, DefaultN) - - // Apply functional options - for _, o := range opts { - o(conn) - } + conn := newGoBackNConn(ctx, sendFunc, recvFunc, true, DefaultN, opts...) if err := conn.serverHandshake(); err != nil { if err := conn.Close(); err != nil { @@ -39,7 +34,9 @@ func NewServerConn(ctx context.Context, sendFunc sendBytesFunc, // 2. The server then responds with a SYN message. // 3. The server waits for a SYNACK message from the client. // 4a. If the server receives the SYNACK message before a resendTimeout, the hand -// is considered complete. +// +// is considered complete. +// // 4b. If SYNACK is not received before a certain resendTimeout func (g *GoBackNConn) serverHandshake() error { // nolint:gocyclo recvChan := make(chan []byte) diff --git a/gbn/queue.go b/gbn/queue.go index 8e755b5..c676d8a 100644 --- a/gbn/queue.go +++ b/gbn/queue.go @@ -24,13 +24,7 @@ const ( awaitingTimeoutMultiplier = 3 ) -// queue is a fixed size queue with a sliding window that has a base and a top -// modulo s. -type queue struct { - // content is the current content of the queue. This is always a slice - // of length s but can contain nil elements if the queue isn't full. - content []*PacketData - +type queueConfig struct { // s is the maximum sequence number used to label packets. Packets // are labelled with incrementing sequence numbers modulo s. // s must be strictly larger than the window size, n. This @@ -40,6 +34,20 @@ type queue struct { // no way to tell. s uint8 + resendTimeout time.Duration + + sendPkt func(packet *PacketData) error +} + +// queue is a fixed size queue with a sliding window that has a base and a top +// modulo s. +type queue struct { + cfg *queueConfig + + // content is the current content of the queue. This is always a slice + // of length s but can contain nil elements if the queue isn't full. + content []*PacketData + // sequenceBase keeps track of the base of the send window and so // represents the next ack that we expect from the receiver. The // maximum value of sequenceBase is s. @@ -55,6 +63,9 @@ type queue struct { // sequenceTop must be guarded by topMtx. sequenceTop uint8 + // topMtx is used to guard sequenceTop. + topMtx sync.RWMutex + // awaitedACK defines the sequence number for the last packet in the // resend queue. If we receive an ACK for this sequence number during // the resend catch up, we wait for the duration of the resend timeout, @@ -99,24 +110,26 @@ type queue struct { // proceed to send new packets. awaitedNACKSignal chan struct{} - // topMtx is used to guard sequenceTop. - topMtx sync.RWMutex + lastResend time.Time - lastResend time.Time - handshakeTimeout time.Duration + quit chan struct{} } // newQueue creates a new queue. -func newQueue(s uint8, handshakeTimeout time.Duration) *queue { +func newQueue(cfg *queueConfig) *queue { return &queue{ - content: make([]*PacketData, s), - s: s, - handshakeTimeout: handshakeTimeout, + cfg: cfg, + content: make([]*PacketData, cfg.s), awaitedACKSignal: make(chan struct{}, 1), awaitedNACKSignal: make(chan struct{}, 1), + quit: make(chan struct{}), } } +func (q *queue) stop() { + close(q.quit) +} + // size is used to calculate the current sender queueSize. func (q *queue) size() uint8 { q.baseMtx.RLock() @@ -129,7 +142,7 @@ func (q *queue) size() uint8 { return q.sequenceTop - q.sequenceBase } - return q.sequenceTop + (q.s - q.sequenceBase) + return q.sequenceTop + (q.cfg.s - q.sequenceBase) } // addPacket adds a new packet to the queue. @@ -139,7 +152,7 @@ func (q *queue) addPacket(packet *PacketData) { packet.Seq = q.sequenceTop q.content[q.sequenceTop] = packet - q.sequenceTop = (q.sequenceTop + 1) % q.s + q.sequenceTop = (q.sequenceTop + 1) % q.cfg.s } // resend resends the current contents of the queue, by invoking the callback @@ -199,10 +212,8 @@ func (q *queue) addPacket(packet *PacketData) { // // When either of the 2 conditions above are met, we will consider both parties // to be in sync, and we can proceed to send new packets. -func (q *queue) resend(resendTimeout time.Duration, quit chan struct{}, - sendFunc func(packet *PacketData) error, -) error { - if time.Since(q.lastResend) < q.handshakeTimeout { +func (q *queue) resend() error { + if time.Since(q.lastResend) < q.cfg.resendTimeout { log.Tracef("Resent the queue recently.") return nil @@ -231,7 +242,7 @@ func (q *queue) resend(resendTimeout time.Duration, quit chan struct{}, } if q.noPingPackets(base, top) { - q.awaitedACK = (q.s + top - 1) % q.s + q.awaitedACK = (q.cfg.s + top - 1) % q.cfg.s q.awaitedNACK = top log.Tracef("Set awaitedACK to %d & awaitedNACK to %d", @@ -254,12 +265,12 @@ func (q *queue) resend(resendTimeout time.Duration, quit chan struct{}, for base != top { packet := q.content[base] - if err := sendFunc(packet); err != nil { + if err := q.cfg.sendPkt(packet); err != nil { q.awaitingCatchUpMu.Unlock() return err } - base = (base + 1) % q.s + base = (base + 1) % q.cfg.s log.Tracef("Resent %d", packet.Seq) } @@ -280,7 +291,7 @@ func (q *queue) resend(resendTimeout time.Duration, quit chan struct{}, q.awaitingCatchUpMu.Unlock() // Then await until we know that both parties are in sync. - q.awaitCatchUp(resendTimeout, quit) + q.awaitCatchUp() return nil } @@ -290,8 +301,8 @@ func (q *queue) resend(resendTimeout time.Duration, quit chan struct{}, // 3X the resend timeout, the function will also return. // See the docs for the resend function for more details on why we need to await // the awaited ACK or NACK signal. -func (q *queue) awaitCatchUp(resendTimeout time.Duration, quit chan struct{}) { - ticker := time.NewTimer(resendTimeout * awaitingTimeoutMultiplier) +func (q *queue) awaitCatchUp() { + ticker := time.NewTimer(q.cfg.resendTimeout * awaitingTimeoutMultiplier) defer ticker.Stop() log.Tracef("Awaiting catchup after resending the queue") @@ -299,7 +310,7 @@ func (q *queue) awaitCatchUp(resendTimeout time.Duration, quit chan struct{}) { catchupLoop: for { select { - case <-quit: + case <-q.quit: return case <-q.awaitedACKSignal: log.Tracef("Got awaitedACKSignal") @@ -315,15 +326,16 @@ catchupLoop: q.awaitingCatchUpMu.Lock() q.awaitingCatchUp = false - // If we time out, we need to also reset the - // channels, as they could have been sent over - // when we waiting to take the awaitingCatchUpMu - // lock above after the timeout. If we don't - // reset them, this select case would catch - // the sent signal, next time we resend the - // queue. - q.awaitedACKSignal = make(chan struct{}, 1) - q.awaitedNACKSignal = make(chan struct{}, 1) + // Drain both signal channels. + select { + case <-q.awaitedACKSignal: + default: + } + + select { + case <-q.awaitedNACKSignal: + default: + } q.awaitingCatchUpMu.Unlock() @@ -344,14 +356,14 @@ func (q *queue) noPingPackets(base, top uint8) bool { return false } - base = (base + 1) % q.s + base = (base + 1) % q.cfg.s } return true } // processACK processes an incoming ACK of a given sequence number. -func (q *queue) processACK(seq uint8, resendTimeout time.Duration) bool { +func (q *queue) processACK(seq uint8) bool { // If our queue is empty, an ACK should not have any effect. if q.size() == 0 { log.Tracef("Received ack %d, but queue is empty. Ignoring.", seq) @@ -366,7 +378,7 @@ func (q *queue) processACK(seq uint8, resendTimeout time.Duration) bool { if seq == q.awaitedACK && q.awaitingCatchUp { log.Tracef("Got awaited ACK") - q.proceedAfterTime(q.catchUpID, resendTimeout) + q.proceedAfterTime(q.catchUpID) } q.awaitingCatchUpMu.RUnlock() @@ -380,7 +392,7 @@ func (q *queue) processACK(seq uint8, resendTimeout time.Duration) bool { // has decreased. log.Tracef("Received correct ack %d", seq) - q.sequenceBase = (q.sequenceBase + 1) % q.s + q.sequenceBase = (q.sequenceBase + 1) % q.cfg.s // We did receive an ACK. return true @@ -401,7 +413,7 @@ func (q *queue) processACK(seq uint8, resendTimeout time.Duration) bool { if containsSequence(q.sequenceBase, q.sequenceTop, seq) { log.Tracef("Sequence %d is in the queue. Bump the base.", seq) - q.sequenceBase = (seq + 1) % q.s + q.sequenceBase = (seq + 1) % q.cfg.s // We did receive an ACK. return true @@ -474,7 +486,7 @@ func (q *queue) processNACK(seq uint8) (bool, bool) { // proceedAfterTime will wait for the resendTimeout and then send an // awaitedACKSignal, if we're still awaiting the resend catch up. -func (q *queue) proceedAfterTime(catchUpID int64, resendTimeout time.Duration) { +func (q *queue) proceedAfterTime(catchUpID int64) { processAwaitedACK := func() { log.Tracef("Executing proceedAfterTime") q.awaitingCatchUpMu.Lock() @@ -508,7 +520,7 @@ func (q *queue) proceedAfterTime(catchUpID int64, resendTimeout time.Duration) { // proceedAfterTime callback, as that's the time we'd expect it to take // for the other party to respond with a NACK, if the resent last packet // in the queue would lead to a NACK. - time.AfterFunc(resendTimeout, processAwaitedACK) + time.AfterFunc(q.cfg.resendTimeout, processAwaitedACK) } // containsSequence is used to determine if a number, seq, is between two other