diff --git a/auctioneer_test.go b/auctioneer_test.go index 51a88de98..91edbc919 100644 --- a/auctioneer_test.go +++ b/auctioneer_test.go @@ -54,13 +54,22 @@ func randomPubKey(t *testing.T) *btcec.PublicKey { return pub } +type fetchStateReq struct { + resp chan AuctionState +} + type mockAuctioneerState struct { sync.RWMutex - state AuctionState - + state AuctionState stateTransitions chan AuctionState + // We use a channel to synchronize acccess to the state, such that we + // can fetch the current state even though we are waiting for a state + // transition to happen. + stateSem chan struct{} + fetchStateChan chan *fetchStateReq + acct *account.Auctioneer orders map[orderT.Nonce]order.ServerOrder @@ -84,10 +93,16 @@ func newMockAuctioneerState(batchKey *btcec.PublicKey, stateTransitionsBuffer = 100 } + // Initialize semaphore with one element. + stateSem := make(chan struct{}, 1) + stateSem <- struct{}{} + return &mockAuctioneerState{ state: DefaultState{}, batchKey: batchKey, stateTransitions: make(chan AuctionState, stateTransitionsBuffer), + fetchStateChan: make(chan *fetchStateReq), + stateSem: stateSem, orders: make(map[orderT.Nonce]order.ServerOrder), batchStates: make(map[orderT.BatchID]bool), snapshots: make(map[orderT.BatchID]*subastadb.BatchSnapshot), @@ -127,20 +142,51 @@ func (m *mockAuctioneerState) BatchKey(context.Context) (*btcec.PublicKey, error } func (m *mockAuctioneerState) UpdateAuctionState(state AuctionState) error { - m.stateTransitions <- state + // To update the auction state, we must obtain the exclusive state + // access semaphore. + <-m.stateSem + defer func() { + m.stateSem <- struct{}{} + }() + + for { + select { + // When the state transition is read, we can update + // our state variable and return. + case m.stateTransitions <- state: + m.state = state + return nil - m.Lock() - m.state = state - m.Unlock() + // If a request to read the current state comes in, we return + // the old state, as the state transition hasn't been triggered + // yet. + case req := <-m.fetchStateChan: + req.resp <- m.state + } + } - return nil } func (m *mockAuctioneerState) AuctionState() (AuctionState, error) { - m.RLock() - defer m.RUnlock() + req := &fetchStateReq{ + resp: make(chan AuctionState, 1), + } - return m.state, nil + select { + // If we get the semaphore we have exclusive access to the state and + // can return it directly. + case <-m.stateSem: + defer func() { + m.stateSem <- struct{}{} + }() + + return m.state, nil + + // Otherwise some other goroutine has the semaphore, so we ask it to + // return the state to us. + case m.fetchStateChan <- req: + return <-req.resp, nil + } } func (m *mockAuctioneerState) ConfirmBatch(ctx context.Context,