Skip to content

Commit

Permalink
Merge pull request #677 from Mirantis/ivan4th/fix-storage-race
Browse files Browse the repository at this point in the history
Fix race condition in libvirt storage
  • Loading branch information
pigmej authored May 28, 2018
2 parents e2a703f + 0633d86 commit 03697fe
Showing 1 changed file with 28 additions and 6 deletions.
34 changes: 28 additions & 6 deletions pkg/libvirttools/libvirt_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package libvirttools

import (
"fmt"
"sync"

"github.com/golang/glog"
libvirt "github.com/libvirt/libvirt-go"
Expand All @@ -28,6 +29,11 @@ import (
)

type libvirtStorageConnection struct {
// Trying to do several storage-related operations at the same time
// may cause a race condition.
// As right now Virtlet uses just one storage pool, a single sync.Mutex
// is enough for handling the storage.
sync.Mutex
conn libvirtConnection
}

Expand All @@ -49,7 +55,7 @@ func (sc *libvirtStorageConnection) CreateStoragePool(def *libvirtxml.StoragePoo
if err != nil {
return nil, err
}
return &libvirtStoragePool{conn: sc.conn, p: p.(*libvirt.StoragePool)}, nil
return &libvirtStoragePool{Mutex: &sc.Mutex, conn: sc.conn, p: p.(*libvirt.StoragePool)}, nil
}

func (sc *libvirtStorageConnection) LookupStoragePoolByName(name string) (virt.StoragePool, error) {
Expand All @@ -63,17 +69,20 @@ func (sc *libvirtStorageConnection) LookupStoragePoolByName(name string) (virt.S
}
return nil, err
}
return &libvirtStoragePool{conn: sc.conn, p: p.(*libvirt.StoragePool)}, nil
return &libvirtStoragePool{Mutex: &sc.Mutex, conn: sc.conn, p: p.(*libvirt.StoragePool)}, nil
}

type libvirtStoragePool struct {
*sync.Mutex
conn libvirtConnection
p *libvirt.StoragePool
}

var _ virt.StoragePool = &libvirtStoragePool{}

func (pool *libvirtStoragePool) CreateStorageVol(def *libvirtxml.StorageVolume) (virt.StorageVolume, error) {
pool.Lock()
defer pool.Unlock()
xml, err := def.Marshal()
if err != nil {
return nil, err
Expand All @@ -90,10 +99,12 @@ func (pool *libvirtStoragePool) CreateStorageVol(def *libvirtxml.StorageVolume)
if err := pool.p.Refresh(0); err != nil {
return nil, fmt.Errorf("failed to refresh the storage pool: %v", err)
}
return &libvirtStorageVolume{name: def.Name, v: v}, nil
return &libvirtStorageVolume{Mutex: pool.Mutex, name: def.Name, v: v}, nil
}

func (pool *libvirtStoragePool) ListAllVolumes() ([]virt.StorageVolume, error) {
pool.Lock()
defer pool.Unlock()
volumes, err := pool.p.ListAllStorageVolumes(0)
if err != nil {
return nil, err
Expand All @@ -106,12 +117,14 @@ func (pool *libvirtStoragePool) ListAllVolumes() ([]virt.StorageVolume, error) {
}
// need to make a copy here
curVolume := v
r[n] = &libvirtStorageVolume{name: name, v: &curVolume}
r[n] = &libvirtStorageVolume{Mutex: pool.Mutex, name: name, v: &curVolume}
}
return r, nil
}

func (pool *libvirtStoragePool) LookupVolumeByName(name string) (virt.StorageVolume, error) {
pool.Lock()
defer pool.Unlock()
v, err := pool.p.LookupStorageVolByName(name)
if err != nil {
libvirtErr, ok := err.(libvirt.Error)
Expand All @@ -120,7 +133,7 @@ func (pool *libvirtStoragePool) LookupVolumeByName(name string) (virt.StorageVol
}
return nil, err
}
return &libvirtStorageVolume{name: name, v: v}, nil
return &libvirtStorageVolume{Mutex: pool.Mutex, name: name, v: v}, nil
}

func (pool *libvirtStoragePool) RemoveVolumeByName(name string) error {
Expand All @@ -136,6 +149,7 @@ func (pool *libvirtStoragePool) RemoveVolumeByName(name string) error {
}

type libvirtStorageVolume struct {
*sync.Mutex
name string
v *libvirt.StorageVol
}
Expand All @@ -147,6 +161,8 @@ func (volume *libvirtStorageVolume) Name() string {
}

func (volume *libvirtStorageVolume) Size() (uint64, error) {
volume.Lock()
defer volume.Unlock()
info, err := volume.v.GetInfo()
if err != nil {
return 0, err
Expand All @@ -155,15 +171,21 @@ func (volume *libvirtStorageVolume) Size() (uint64, error) {
}

func (volume *libvirtStorageVolume) Path() (string, error) {
volume.Lock()
defer volume.Unlock()
return volume.v.GetPath()
}

func (volume *libvirtStorageVolume) Remove() error {
volume.Lock()
defer volume.Unlock()
return volume.v.Delete(0)
}

func (volume *libvirtStorageVolume) Format() error {
volPath, err := volume.Path()
volume.Lock()
defer volume.Unlock()
volPath, err := volume.v.GetPath()
if err != nil {
return fmt.Errorf("can't get volume path: %v", err)
}
Expand Down

0 comments on commit 03697fe

Please sign in to comment.