diff --git a/amqptest/server/exchange.go b/amqptest/server/exchange.go index b7ca940..2683549 100644 --- a/amqptest/server/exchange.go +++ b/amqptest/server/exchange.go @@ -1,6 +1,9 @@ package server -import "fmt" +import ( + "fmt" + "sync" +) type Exchange interface { route(route string, d *Delivery) error @@ -11,24 +14,32 @@ type Exchange interface { type TopicExchange struct { name string bindings map[string]*Queue + mu *sync.RWMutex } func NewTopicExchange(name string) *TopicExchange { return &TopicExchange{ name: name, bindings: make(map[string]*Queue), + mu: &sync.RWMutex{}, } } func (t *TopicExchange) addBinding(route string, q *Queue) { + t.mu.Lock() + defer t.mu.Unlock() t.bindings[route] = q } func (t *TopicExchange) delBinding(route string) { + t.mu.Lock() + defer t.mu.Unlock() delete(t.bindings, route) } func (t *TopicExchange) route(route string, d *Delivery) error { + t.mu.RLock() + defer t.mu.RUnlock() for bname, q := range t.bindings { if topicMatch(bname, route) { q.data <- d @@ -43,16 +54,20 @@ func (t *TopicExchange) route(route string, d *Delivery) error { type DirectExchange struct { name string bindings map[string]*Queue + mu *sync.RWMutex } func NewDirectExchange(name string) *DirectExchange { return &DirectExchange{ name: name, bindings: make(map[string]*Queue), + mu: &sync.RWMutex{}, } } func (d *DirectExchange) addBinding(route string, q *Queue) { + d.mu.Lock() + defer d.mu.Unlock() if d.bindings == nil { d.bindings = make(map[string]*Queue) } @@ -61,10 +76,14 @@ func (d *DirectExchange) addBinding(route string, q *Queue) { } func (d *DirectExchange) delBinding(route string) { + d.mu.Lock() + defer d.mu.Unlock() delete(d.bindings, route) } func (d *DirectExchange) route(route string, delivery *Delivery) error { + d.mu.RLock() + defer d.mu.RUnlock() if q, ok := d.bindings[route]; ok { q.data <- delivery return nil diff --git a/amqptest/server/server.go b/amqptest/server/server.go index 0680434..fe74abb 100644 --- a/amqptest/server/server.go +++ b/amqptest/server/server.go @@ -32,6 +32,7 @@ type AMQPServer struct { notifyChans map[string]*utils.ErrBroadcast channels map[string][]*Channel vhost *VHost + muChannels *sync.RWMutex } // NewServer returns a new fake amqp server @@ -41,13 +42,14 @@ func newServer(amqpuri string) *AMQPServer { notifyChans: make(map[string]*utils.ErrBroadcast), channels: make(map[string][]*Channel), vhost: NewVHost("/"), + muChannels: &sync.RWMutex{}, } } // CreateChannel returns a new fresh channel func (s *AMQPServer) CreateChannel(connID string, conn wabbit.Conn) (wabbit.Channel, error) { - mu.Lock() - defer mu.Unlock() + s.muChannels.Lock() + defer s.muChannels.Unlock() if _, ok := s.channels[connID]; !ok { s.channels[connID] = make([]*Channel, 0, MaxChannels) diff --git a/amqptest/server/vhost.go b/amqptest/server/vhost.go index 320b1e2..b0aa59d 100644 --- a/amqptest/server/vhost.go +++ b/amqptest/server/vhost.go @@ -27,13 +27,11 @@ func NewVHost(name string) *VHost { func (v *VHost) createDefaultExchanges() { exchs := make(map[string]Exchange) - exchs["amq.topic"] = &TopicExchange{} - exchs["amq.direct"] = &DirectExchange{} - exchs["topic"] = &TopicExchange{} - exchs["direct"] = &DirectExchange{} - exchs[""] = &DirectExchange{ - name: "amq.direct", - } + exchs["amq.topic"] = NewTopicExchange("amq.topic") + exchs["amq.direct"] = NewDirectExchange("amq.direct") + exchs["topic"] = NewTopicExchange("topic") + exchs["direct"] = NewDirectExchange("direct") + exchs[""] = NewDirectExchange("amq.direct") v.exchanges = exchs }