Skip to content

Commit

Permalink
feat: remove unnecessary use of WaitGroup for protocol done signal (#608
Browse files Browse the repository at this point in the history
)
  • Loading branch information
rakshasa authored May 5, 2024
1 parent 2e915ed commit c174202
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 75 deletions.
54 changes: 21 additions & 33 deletions protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"github.com/blinklabs-io/gouroboros/cbor"
"github.com/blinklabs-io/gouroboros/connection"
"github.com/blinklabs-io/gouroboros/muxer"
"github.com/blinklabs-io/gouroboros/utils"
)

// This is completely arbitrary, but the line had to be drawn somewhere
Expand All @@ -34,15 +33,16 @@ const maxMessagesPerSegment = 20
// Protocol implements the base functionality of an Ouroboros mini-protocol
type Protocol struct {
config ProtocolConfig
doneChan chan struct{}
muxerSendChan chan *muxer.Segment
muxerRecvChan chan *muxer.Segment
muxerDoneChan chan bool
sendQueueChan chan Message
recvDoneChan chan struct{}
recvReadyChan chan bool
sendDoneChan chan struct{}
sendReadyChan chan bool
stateTransitionChan chan<- protocolStateTransition
doneSignal *utils.DoneSignal
waitGroup sync.WaitGroup
onceStart sync.Once
}

Expand Down Expand Up @@ -105,8 +105,10 @@ type MessageFromCborFunc func(uint, []byte) (Message, error)
// New returns a new Protocol object
func New(config ProtocolConfig) *Protocol {
p := &Protocol{
config: config,
doneSignal: utils.NewDoneSignal(),
config: config,
doneChan: make(chan struct{}),
recvDoneChan: make(chan struct{}),
sendDoneChan: make(chan struct{}),
}
return p
}
Expand All @@ -133,7 +135,11 @@ func (p *Protocol) Start() {
p.stateTransitionChan = stateTransitionChan

// Start our send and receive Goroutines
p.waitGroup.Add(2)
go func() {
<-p.recvDoneChan
<-p.sendDoneChan
close(p.doneChan)
}()

go p.stateLoop(stateTransitionChan)
go p.recvLoop()
Expand All @@ -153,7 +159,7 @@ func (p *Protocol) Role() ProtocolRole {

// DoneChan returns the channel used to signal protocol shutdown
func (p *Protocol) DoneChan() <-chan struct{} {
return p.doneSignal.GetCh()
return p.doneChan
}

// SendMessage appends a message to the send queue
Expand All @@ -176,17 +182,16 @@ func (p *Protocol) SendError(err error) {

func (p *Protocol) sendLoop() {
defer func() {
p.waitGroup.Done()
// Close muxer send channel
// We are responsible for closing this channel as the sender, even through it
// was created by the muxer
close(p.muxerSendChan)
p.doneSignal.Close()
close(p.sendDoneChan)
}()

for {
select {
case <-p.doneSignal.GetCh():
case <-p.recvDoneChan:
// Break out of send loop if we're shutting down
return
case <-p.sendReadyChan:
Expand All @@ -200,7 +205,7 @@ func (p *Protocol) sendLoop() {
for {
// Get next message from send queue
select {
case <-p.doneSignal.GetCh():
case <-p.recvDoneChan:
// Break out of send loop if we're shutting down
return
case msg, ok := <-p.sendQueueChan:
Expand Down Expand Up @@ -285,8 +290,7 @@ func (p *Protocol) sendLoop() {

func (p *Protocol) recvLoop() {
defer func() {
p.waitGroup.Done()
p.doneSignal.Close()
close(p.recvDoneChan)
}()

leftoverData := false
Expand All @@ -298,7 +302,7 @@ func (p *Protocol) recvLoop() {
if !leftoverData {
// Wait for segment
select {
case <-p.doneSignal.GetCh():
case <-p.sendDoneChan:
// Break out of receive loop if we're shutting down
return
case <-p.muxerDoneChan:
Expand All @@ -314,7 +318,7 @@ func (p *Protocol) recvLoop() {
leftoverData = false
// Wait until ready to receive based on state map
select {
case <-p.doneSignal.GetCh():
case <-p.sendDoneChan:
// Break out of receive loop if we're shutting down
return
case <-p.muxerDoneChan:
Expand Down Expand Up @@ -431,9 +435,6 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) {
return transitionTimer.C
}

protocolDoneChan := p.doneSignal.GetCh()
stateDoneChan := make(chan struct{})

setState(p.config.InitialState)

for {
Expand Down Expand Up @@ -467,24 +468,11 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) {
),
)

case <-protocolDoneChan:
// Disable this case so it doesn't block
protocolDoneChan = nil

// Wait for all other goroutines to finish before shutting down the state handler
go func() {
p.waitGroup.Wait()

close(stateDoneChan)
}()

case <-stateDoneChan:
// All other goroutines have finished, so we can stop the timer and return
case <-p.doneChan:
// Disable any previous state transition timer, as they are no longer needed
if transitionTimer != nil && !transitionTimer.Stop() {
<-transitionTimer.C
}
transitionTimer = nil

return
}
}
Expand Down
42 changes: 0 additions & 42 deletions utils/utils.go

This file was deleted.

0 comments on commit c174202

Please sign in to comment.