From cc4aa3c80e413f864f8a1979ecb197460bc207cb Mon Sep 17 00:00:00 2001 From: Peter Fern Date: Fri, 25 Mar 2016 23:53:54 +1100 Subject: [PATCH] Fix a number of potential and actual race conditions Fixes #4 --- client.go | 9 +- protocol/v2.go | 10 +- protocol/v2/device/device.go | 226 ++++++++++++++++++++++++----------- protocol/v2/device/light.go | 2 +- protocol/v2/packet/packet.go | 4 +- 5 files changed, 168 insertions(+), 83 deletions(-) diff --git a/client.go b/client.go index 47a3f85..fc54cd9 100644 --- a/client.go +++ b/client.go @@ -165,10 +165,11 @@ func (c *Client) removeDeviceByID(id uint64) error { // GetLocations returns a slice of all locations known to the client, or // common.ErrNotFound if no locations are currently known. func (c *Client) GetLocations() (locations []common.Location, err error) { + c.RLock() if len(c.locations) == 0 { + c.RUnlock() return locations, common.ErrNotFound } - c.RLock() for _, location := range c.locations { locations = append(locations, location) } @@ -258,10 +259,11 @@ func (c *Client) GetLocationByLabel(label string) (common.Location, error) { // GetGroups returns a slice of all groups known to the client, or // common.ErrNotFound if no groups are currently known. func (c *Client) GetGroups() (groups []common.Group, err error) { + c.RLock() if len(c.groups) == 0 { + c.RUnlock() return groups, common.ErrNotFound } - c.RLock() for _, group := range c.groups { groups = append(groups, group) } @@ -351,10 +353,11 @@ func (c *Client) GetGroupByLabel(label string) (common.Group, error) { // GetDevices returns a slice of all devices known to the client, or // common.ErrNotFound if no devices are currently known. func (c *Client) GetDevices() (devices []common.Device, err error) { + c.RLock() if len(c.devices) == 0 { + c.RUnlock() return devices, common.ErrNotFound } - c.RLock() for _, device := range c.devices { devices = append(devices, device) } diff --git a/protocol/v2.go b/protocol/v2.go index 2ba3507..e79cdc7 100644 --- a/protocol/v2.go +++ b/protocol/v2.go @@ -93,17 +93,17 @@ func (p *V2) init() error { if err != nil { return err } - p.broadcast = &device.Light{Device: *broadcastDev} + p.broadcast = &device.Light{Device: broadcastDev} broadcastSub, err := p.broadcast.NewSubscription() if err != nil { return err } - go p.broadcastLimiter(broadcastSub.Events()) p.devices = make(map[uint64]device.GenericDevice) p.locations = make(map[string]*device.Location) p.groups = make(map[string]*device.Group) p.subscriptions = make(map[string]*common.Subscription) p.quitChan = make(chan struct{}) + go p.broadcastLimiter(broadcastSub.Events()) go p.dispatcher() p.initialized = true @@ -570,13 +570,13 @@ func (p *V2) lightOrDev(dev device.GenericDevice) device.GenericDevice { switch product { case device.ProductLifxOriginal1000, device.ProductLifxColor650, device.ProductLifxWhite800LowVoltage, device.ProductLifxWhite800HighVoltage, device.ProductLifxWhite900BR30, device.ProductLifxColor1000BR30, device.ProductLifxColor1000: p.Lock() - // Need to figure if there's a way to do this without being racey on the - // lock inside the dev d := dev.(*device.Device) - l := &device.Light{Device: *d} + d.Lock() + l := &device.Light{Device: d} common.Log.Debugf("Device is a light: %v\n", l.ID()) // Replace the known dev with our constructed light p.devices[l.ID()] = l + d.Unlock() p.Unlock() if err := l.Get(); err != nil { common.Log.Debugf("Failed getting light state: %v\n", err) diff --git a/protocol/v2/device/device.go b/protocol/v2/device/device.go index 61c8729..273ec53 100644 --- a/protocol/v2/device/device.go +++ b/protocol/v2/device/device.go @@ -55,8 +55,9 @@ const ( ProductLifxColor1000 uint32 = 22 ) +type doneChan chan struct{} type responseMap map[uint8]packet.Chan -type doneMap map[uint8]chan struct{} +type doneMap map[uint8]doneChan type Device struct { id uint64 @@ -111,6 +112,7 @@ type payloadLabel struct { } func (d *Device) init(addr *net.UDPAddr, requestSocket *net.UDPConn, timeout *time.Duration, retryInterval *time.Duration, reliable bool) { + d.Lock() d.address = addr d.requestSocket = requestSocket d.timeout = timeout @@ -122,6 +124,7 @@ func (d *Device) init(addr *net.UDPAddr, requestSocket *net.UDPConn, timeout *ti d.responseInput = make(packet.Chan, 32) d.subscriptions = make(map[string]*common.Subscription) d.quitChan = make(chan struct{}) + d.Unlock() } func (d *Device) ID() uint64 { @@ -171,9 +174,11 @@ func (d *Device) SetStateLabel(pkt *packet.Packet) error { } common.Log.Debugf("Got label (%v): %+v\n", d.id, l.Label) newLabel := stripNull(string(l.Label[:])) - if newLabel != d.label { + if newLabel != d.CachedLabel() { + d.Lock() d.label = newLabel - if err := d.publish(common.EventUpdateLabel{Label: d.label}); err != nil { + d.Unlock() + if err := d.publish(common.EventUpdateLabel{Label: newLabel}); err != nil { return err } } @@ -182,14 +187,16 @@ func (d *Device) SetStateLabel(pkt *packet.Packet) error { } func (d *Device) SetStateLocation(pkt *packet.Packet) error { - l := new(Location) + l := &Location{} if err := l.Parse(pkt); err != nil { return err } common.Log.Debugf("Got location (%v): %+v (%+v)\n", d.id, l.ID(), l.GetLabel()) newLocation := l.ID() - if newLocation != d.locationID { + if newLocation != d.CachedLocation() { + d.Lock() d.locationID = newLocation + d.Unlock() // TODO: Work out what to notify on without causing protocol version // dependency } @@ -198,14 +205,16 @@ func (d *Device) SetStateLocation(pkt *packet.Packet) error { } func (d *Device) SetStateGroup(pkt *packet.Packet) error { - l := new(Group) + l := &Group{} if err := l.Parse(pkt); err != nil { return err } common.Log.Debugf("Got group (%v): %+v (%+v)\n", d.id, l.ID(), l.GetLabel()) newGroup := l.ID() - if newGroup != d.groupID { + if newGroup != d.CachedGroup() { + d.Lock() d.groupID = newGroup + d.Unlock() // TODO: Work out what to notify on without causing protocol version // dependency } @@ -214,8 +223,9 @@ func (d *Device) SetStateGroup(pkt *packet.Packet) error { } func (d *Device) GetLabel() (string, error) { - if len(d.label) != 0 { - return d.label, nil + label := d.CachedLabel() + if len(label) != 0 { + return label, nil } pkt := packet.New(d.address, d.requestSocket) @@ -236,15 +246,15 @@ func (d *Device) GetLabel() (string, error) { return ``, err } - return d.label, nil + return d.CachedLabel(), nil } func (d *Device) SetLabel(label string) error { - if d.label == label { + if d.CachedLabel() == label { return nil } - p := new(payloadLabel) + p := &payloadLabel{} copy(p.Label[:], label) pkt := packet.New(d.address, d.requestSocket) @@ -264,13 +274,22 @@ func (d *Device) SetLabel(label string) error { common.Log.Debugf("Setting label on %v acknowledged\n", d.id) } + d.Lock() d.label = label - if err := d.publish(common.EventUpdateLabel{Label: d.label}); err != nil { + d.Unlock() + if err := d.publish(common.EventUpdateLabel{Label: label}); err != nil { return err } return nil } +func (d *Device) CachedLabel() string { + d.RLock() + label := d.label + d.RUnlock() + return label +} + func (d *Device) SetStatePower(pkt *packet.Packet) error { p := statePower{} if err := pkt.DecodePayload(&p); err != nil { @@ -278,11 +297,12 @@ func (d *Device) SetStatePower(pkt *packet.Packet) error { } common.Log.Debugf("Got power (%v): %+v\n", d.id, d.power) - if d.power != p.Level { + state := p.Level > 0 + if d.CachedPower() != state { d.Lock() d.power = p.Level d.Unlock() - if err := d.publish(common.EventUpdatePower{Power: d.power > 0}); err != nil { + if err := d.publish(common.EventUpdatePower{Power: state}); err != nil { return err } } @@ -309,22 +329,22 @@ func (d *Device) GetPower() (bool, error) { return false, err } - return d.power > 0, nil + return d.CachedPower(), nil } func (d *Device) CachedPower() bool { d.RLock() - state := d.power + state := d.power != 0 d.RUnlock() - return state != 0 + return state } func (d *Device) SetPower(state bool) error { - if state && d.power > 0 { + if state && d.CachedPower() { return nil } - p := new(payloadPower) + p := &payloadPower{} if state { p.Level = math.MaxUint16 } @@ -346,15 +366,20 @@ func (d *Device) SetPower(state bool) error { common.Log.Debugf("Setting power state on %v acknowledged\n", d.id) } + d.Lock() d.power = p.Level - if err := d.publish(common.EventUpdatePower{Power: d.power > 0}); err != nil { + d.Unlock() + if err := d.publish(common.EventUpdatePower{Power: p.Level > 0}); err != nil { return err } return nil } func (d *Device) CachedLocation() string { - return d.locationID + d.RLock() + location := d.locationID + d.RUnlock() + return location } func (d *Device) GetLocation() (ret string, err error) { @@ -376,11 +401,14 @@ func (d *Device) GetLocation() (ret string, err error) { return ret, err } - return d.locationID, nil + return d.CachedLocation(), nil } func (d *Device) CachedGroup() string { - return d.groupID + d.RLock() + group := d.groupID + d.RUnlock() + return group } func (d *Device) GetGroup() (ret string, err error) { @@ -402,36 +430,38 @@ func (d *Device) GetGroup() (ret string, err error) { return ret, err } - return d.groupID, nil + return d.CachedGroup(), nil } func (d *Device) GetHardwareVendor() (uint32, error) { - if d.hardwareVersion.Product != 0 { - return d.hardwareVersion.Vendor, nil + if d.CachedHardwareProduct() != 0 { + return d.CachedHardwareVendor(), nil } + _, err := d.GetHardwareVersion() if err != nil { return 0, err } - return d.hardwareVersion.Vendor, nil + return d.CachedHardwareVendor(), nil } func (d *Device) GetHardwareProduct() (uint32, error) { - if d.hardwareVersion.Product != 0 { - return d.hardwareVersion.Product, nil + if d.CachedHardwareProduct() != 0 { + return d.CachedHardwareProduct(), nil } + _, err := d.GetHardwareVersion() if err != nil { return 0, err } - return d.hardwareVersion.Product, nil + return d.CachedHardwareProduct(), nil } func (d *Device) GetHardwareVersion() (uint32, error) { - if d.hardwareVersion.Product != 0 { - return d.hardwareVersion.Version, nil + if d.CachedHardwareProduct() != 0 { + return d.CachedHardwareVersion(), nil } pkt := packet.New(d.address, d.requestSocket) @@ -447,12 +477,38 @@ func (d *Device) GetHardwareVersion() (uint32, error) { return 0, err } - if err = pktResponse.Result.DecodePayload(&d.hardwareVersion); err != nil { + v := stateVersion{} + if err = pktResponse.Result.DecodePayload(&v); err != nil { return 0, err } - common.Log.Debugf("Got hardware version (%v): %+v\n", d.id, d.hardwareVersion) + common.Log.Debugf("Got hardware version (%v): %+v\n", d.id, v) + + d.Lock() + d.hardwareVersion = v + d.Unlock() + + return d.CachedHardwareVersion(), nil +} + +func (d *Device) CachedHardwareVersion() uint32 { + d.RLock() + version := d.hardwareVersion.Version + d.RUnlock() + return version +} + +func (d *Device) CachedHardwareVendor() uint32 { + d.RLock() + vendor := d.hardwareVersion.Vendor + d.RUnlock() + return vendor +} - return d.hardwareVersion.Version, nil +func (d *Device) CachedHardwareProduct() uint32 { + d.RLock() + product := d.hardwareVersion.Product + d.RUnlock() + return product } func (d *Device) Handle(pkt *packet.Packet) { @@ -464,7 +520,9 @@ func (d *Device) GetAddress() *net.UDPAddr { } func (d *Device) ResetLimiter() { + d.Lock() d.limiter.Reset(shared.RateLimit) + d.Unlock() } func (d *Device) resetLimiter(broadcast bool) { @@ -501,30 +559,17 @@ func (d *Device) Send(pkt *packet.Packet, ackRequired, responseRequired bool) (p pkt.SetResRequired(true) } if ackRequired || responseRequired { - inputChan := make(packet.Chan) - doneChan := make(chan struct{}) - - d.Lock() - d.sequence++ - if d.sequence == 0 { - d.sequence++ - } - seq := d.sequence - d.responseMap[seq] = inputChan - d.doneMap[seq] = doneChan + seq, input, done := d.addSeq() pkt.SetSequence(seq) - d.Unlock() go func() { defer func() { - close(doneChan) + close(done) }() var ( - ok bool - timeout <-chan time.Time - pktResponse = packet.Response{} - ticker = time.NewTicker(*d.retryInterval) + timeout <-chan time.Time + ticker = time.NewTicker(*d.retryInterval) ) if d.timeout == nil || *d.timeout == 0 { @@ -535,7 +580,7 @@ func (d *Device) Send(pkt *packet.Packet, ackRequired, responseRequired bool) (p for { select { - case pktResponse, ok = <-inputChan: + case pktResponse, ok := <-input: if !ok { close(proxyChan) return @@ -551,12 +596,14 @@ func (d *Device) Send(pkt *packet.Packet, ackRequired, responseRequired bool) (p return case <-ticker.C: common.Log.Debugf("Retrying send after %d milliseconds: %+v\n", *d.retryInterval/time.Millisecond, *pkt) + pktResponse := packet.Response{} if err := pkt.Write(); err != nil { pktResponse.Error = err proxyChan <- pktResponse return } case <-timeout: + pktResponse := packet.Response{} pktResponse.Error = common.ErrTimeout proxyChan <- pktResponse return @@ -609,13 +656,12 @@ func (d *Device) Close() error { func (d *Device) handler() { var ( - pktResponse packet.Response - ok bool - ch packet.Chan - done chan struct{} + ok bool + ch packet.Chan + done chan struct{} ) - for { + for { select { case <-d.quitChan: d.Lock() @@ -627,15 +673,10 @@ func (d *Device) handler() { } d.Unlock() return - case pktResponse = <-d.responseInput: + case pktResponse := <-d.responseInput: common.Log.Debugf("Handling packet on device %v: %+v\n", d.id, pktResponse) seq := pktResponse.Result.GetSequence() - d.RLock() - ch, ok = d.responseMap[seq] - if ok { - done, ok = d.doneMap[seq] - } - d.RUnlock() + ch, done, ok = d.getSeq(seq) if !ok { common.Log.Warnf("Couldn't find requestor for seq %v on device %v: %+v\n", seq, d.id, pktResponse) continue @@ -644,16 +685,54 @@ func (d *Device) handler() { select { case ch <- pktResponse: case <-done: - d.Lock() - close(ch) - delete(d.responseMap, seq) - delete(d.doneMap, seq) - d.Unlock() + d.delSeq(seq) } } } } +func (d *Device) addSeq() (seq uint8, input packet.Chan, done doneChan) { + input = make(packet.Chan) + done = make(doneChan) + + d.Lock() + d.sequence++ + if d.sequence == 0 { + d.sequence++ + } + seq = d.sequence + d.responseMap[seq] = input + d.doneMap[seq] = done + d.Unlock() + + return seq, input, done +} + +func (d *Device) getSeq(seq uint8) (ch packet.Chan, done doneChan, ok bool) { + d.RLock() + ch, ok = d.responseMap[seq] + if !ok { + d.RUnlock() + return nil, nil, false + } + done, ok = d.doneMap[seq] + d.RUnlock() + + return ch, done, ok +} + +func (d *Device) delSeq(seq uint8) { + ch, _, ok := d.getSeq(seq) + if !ok { + return + } + d.Lock() + close(ch) + delete(d.responseMap, seq) + delete(d.doneMap, seq) + d.Unlock() +} + // Pushes an event to subscribers func (d *Device) publish(event interface{}) error { d.RLock() @@ -673,12 +752,12 @@ func (d *Device) publish(event interface{}) error { } func New(addr *net.UDPAddr, requestSocket *net.UDPConn, timeout *time.Duration, retryInterval *time.Duration, reliable bool, pkt *packet.Packet) (*Device, error) { - d := new(Device) + d := &Device{} d.init(addr, requestSocket, timeout, retryInterval, reliable) if pkt != nil { d.id = pkt.Target - service := new(stateService) + service := &stateService{} if err := pkt.DecodePayload(service); err != nil { return nil, err @@ -686,6 +765,7 @@ func New(addr *net.UDPAddr, requestSocket *net.UDPConn, timeout *time.Duration, d.address.Port = int(service.Port) } + go d.handler() return d, nil diff --git a/protocol/v2/device/light.go b/protocol/v2/device/light.go index e11be0b..9985a4a 100644 --- a/protocol/v2/device/light.go +++ b/protocol/v2/device/light.go @@ -19,7 +19,7 @@ const ( ) type Light struct { - Device + *Device color common.Color } diff --git a/protocol/v2/packet/packet.go b/protocol/v2/packet/packet.go index 2cc3a96..b4753ed 100644 --- a/protocol/v2/packet/packet.go +++ b/protocol/v2/packet/packet.go @@ -214,8 +214,10 @@ func (p *Packet) DecodePayload(dest interface{}) error { p.mutex.RUnlock() return common.ErrProtocol } - err := binary.Read(bytes.NewBuffer(p.payload), binary.LittleEndian, dest) p.mutex.RUnlock() + p.mutex.Lock() + err := binary.Read(bytes.NewBuffer(p.payload), binary.LittleEndian, dest) + p.mutex.Unlock() return err }