diff --git a/concurrent_test.go b/concurrent_test.go index c28ef47f2..e1a5b5883 100644 --- a/concurrent_test.go +++ b/concurrent_test.go @@ -3,12 +3,15 @@ package bbolt_test import ( crand "crypto/rand" "encoding/hex" + "encoding/json" "fmt" mrand "math/rand" "os" "path/filepath" "reflect" + "sort" "strings" + "sync" "testing" "time" "unicode/utf8" @@ -21,6 +24,13 @@ import ( "go.etcd.io/bbolt/internal/common" ) +/* +TestConcurrentReadAndWrite verifies: + 1. Repeatable read: a read transaction should always see the same data + view during its lifecycle; + 2. Any data written by a writing transaction should be visible to any + following reading transactions (with txid >= previous writing txid). +*/ func TestConcurrentReadAndWrite(t *testing.T) { bucket := []byte("data") keys := []string{"key0", "key1", "key2", "key3", "key4", "key5", "key6", "key7", "key8", "key9"} @@ -106,7 +116,7 @@ func concurrentReadAndWrite(t *testing.T, minWriteBytes, maxWriteBytes int, testDuration time.Duration) { - // prepare the db + t.Log("Preparing db.") db := btesting.MustCreateDB(t) err := db.Update(func(tx *bolt.Tx) error { _, err := tx.CreateBucket(bucket) @@ -114,26 +124,48 @@ func concurrentReadAndWrite(t *testing.T, }) require.NoError(t, err) + t.Log("Starting workers.") + records := runWorkers(t, + db, bucket, keys, + readerCount, minReadInterval, maxReadInterval, + minWriteInterval, maxWriteInterval, minWriteBytes, maxWriteBytes, + testDuration) + + t.Log("Analyzing the history records.") + if err := validateSerializable(records); err != nil { + t.Errorf("The history records are not serializable:\n %v", err) + } + + saveDataIfFailed(t, db, records) + + // TODO (ahrtr): + // 1. intentionally inject a random failpoint. + // 2. check db consistency at the end. +} + +/* +********************************************************* +Data structures and functions/methods for running +concurrent workers, including reading and writing workers +********************************************************* +*/ +func runWorkers(t *testing.T, + db *btesting.DB, + bucket []byte, + keys []string, + readerCount int, + minReadInterval, maxReadInterval time.Duration, + minWriteInterval, maxWriteInterval time.Duration, + minWriteBytes, maxWriteBytes int, + testDuration time.Duration) historyRecords { stopCh := make(chan struct{}, 1) errCh := make(chan error, readerCount+1) - // start readonly transactions - g := new(errgroup.Group) - for i := 0; i < readerCount; i++ { - reader := &readWorker{ - db: db, - bucket: bucket, - keys: keys, - minReadInterval: minReadInterval, - maxReadInterval: maxReadInterval, - errCh: errCh, - stopCh: stopCh, - t: t, - } - g.Go(reader.run) - } + var mu sync.Mutex + var rs historyRecords // start write transaction + g := new(errgroup.Group) writer := writeWorker{ db: db, bucket: bucket, @@ -147,7 +179,35 @@ func concurrentReadAndWrite(t *testing.T, stopCh: stopCh, t: t, } - g.Go(writer.run) + g.Go(func() error { + wrs, err := writer.run() + mu.Lock() + rs = append(rs, wrs...) + mu.Unlock() + return err + }) + + // start readonly transactions + for i := 0; i < readerCount; i++ { + reader := &readWorker{ + db: db, + bucket: bucket, + keys: keys, + minReadInterval: minReadInterval, + maxReadInterval: maxReadInterval, + + errCh: errCh, + stopCh: stopCh, + t: t, + } + g.Go(func() error { + rrs, err := reader.run() + mu.Lock() + rs = append(rs, rrs...) + mu.Unlock() + return err + }) + } t.Logf("Keep reading and writing transactions running for about %s.", testDuration) select { @@ -156,17 +216,12 @@ func concurrentReadAndWrite(t *testing.T, } close(stopCh) - t.Log("Wait for all transactions to finish.") + t.Log("Waiting for all transactions to finish.") if err := g.Wait(); err != nil { t.Errorf("Received error: %v", err) } - saveDataIfFailed(t, db) - - // TODO (ahrtr): - // 1. intentionally inject a random failpoint. - // 2. validate the linearizablity: each reading transaction - // should read the value written by previous writing transaction. + return rs } type readWorker struct { @@ -177,27 +232,29 @@ type readWorker struct { minReadInterval time.Duration maxReadInterval time.Duration - errCh chan error - stopCh chan struct{} + + errCh chan error + stopCh chan struct{} t *testing.T } -func (reader *readWorker) run() error { +func (r *readWorker) run() (historyRecords, error) { + var rs historyRecords for { select { - case <-reader.stopCh: - reader.t.Log("Reading transaction finished.") - return nil + case <-r.stopCh: + r.t.Log("Reading transaction finished.") + return rs, nil default: } - err := reader.db.View(func(tx *bolt.Tx) error { - b := tx.Bucket(reader.bucket) + err := r.db.View(func(tx *bolt.Tx) error { + b := tx.Bucket(r.bucket) - selectedKey := reader.keys[mrand.Intn(len(reader.keys))] + selectedKey := r.keys[mrand.Intn(len(r.keys))] initialVal := b.Get([]byte(selectedKey)) - time.Sleep(randomDurationInRange(reader.minReadInterval, reader.maxReadInterval)) + time.Sleep(randomDurationInRange(r.minReadInterval, r.maxReadInterval)) val := b.Get([]byte(selectedKey)) if !reflect.DeepEqual(initialVal, val) { @@ -205,14 +262,24 @@ func (reader *readWorker) run() error { selectedKey, formatBytes(initialVal), formatBytes(val)) } + clonedVal := make([]byte, len(val)) + copy(clonedVal, val) + + rs = append(rs, historyRecord{ + OperationType: Read, + Key: selectedKey, + Value: clonedVal, + Txid: tx.ID(), + }) + return nil }) if err != nil { readErr := fmt.Errorf("[reader error]: %w", err) - reader.t.Log(readErr) - reader.errCh <- readErr - return readErr + r.t.Log(readErr) + r.errCh <- readErr + return rs, readErr } } } @@ -227,43 +294,55 @@ type writeWorker struct { maxWriteBytes int minWriteInterval time.Duration maxWriteInterval time.Duration - errCh chan error - stopCh chan struct{} + + errCh chan error + stopCh chan struct{} t *testing.T } -func (writer *writeWorker) run() error { +func (w *writeWorker) run() (historyRecords, error) { + var rs historyRecords for { select { - case <-writer.stopCh: - writer.t.Log("Writing transaction finished.") - return nil + case <-w.stopCh: + w.t.Log("Writing transaction finished.") + return rs, nil default: } - err := writer.db.Update(func(tx *bolt.Tx) error { - b := tx.Bucket(writer.bucket) + err := w.db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket(w.bucket) - selectedKey := writer.keys[mrand.Intn(len(writer.keys))] + selectedKey := w.keys[mrand.Intn(len(w.keys))] - valueBytes := randomIntInRange(writer.minWriteBytes, writer.maxWriteBytes) + valueBytes := randomIntInRange(w.minWriteBytes, w.maxWriteBytes) v := make([]byte, valueBytes) if _, cErr := crand.Read(v); cErr != nil { return cErr } - return b.Put([]byte(selectedKey), v) + putErr := b.Put([]byte(selectedKey), v) + if putErr == nil { + rs = append(rs, historyRecord{ + OperationType: Write, + Key: selectedKey, + Value: v, + Txid: tx.ID(), + }) + } + + return putErr }) if err != nil { writeErr := fmt.Errorf("[writer error]: %w", err) - writer.t.Log(writeErr) - writer.errCh <- writeErr - return writeErr + w.t.Log(writeErr) + w.errCh <- writeErr + return rs, writeErr } - time.Sleep(randomDurationInRange(writer.minWriteInterval, writer.maxWriteInterval)) + time.Sleep(randomDurationInRange(w.minWriteInterval, w.maxWriteInterval)) } } @@ -285,18 +364,41 @@ func formatBytes(val []byte) string { return hex.EncodeToString(val) } -func saveDataIfFailed(t *testing.T, db *btesting.DB) { +/* +********************************************************* +Functions for persisting test data, including db file +and operation history +********************************************************* +*/ +func saveDataIfFailed(t *testing.T, db *btesting.DB, rs historyRecords) { if t.Failed() { if err := db.Close(); err != nil { t.Errorf("Failed to close db: %v", err) } backupPath := testResultsDirectory(t) - targetFile := filepath.Join(backupPath, "db.bak") + backupDB(t, db, backupPath) + persistHistoryRecords(t, rs, backupPath) + } +} - t.Logf("Saving the DB file to %s", targetFile) - err := common.CopyFile(db.Path(), targetFile) +func backupDB(t *testing.T, db *btesting.DB, path string) { + targetFile := filepath.Join(path, "db.bak") + t.Logf("Saving the DB file to %s", targetFile) + err := common.CopyFile(db.Path(), targetFile) + require.NoError(t, err) + t.Logf("DB file saved to %s", targetFile) +} + +func persistHistoryRecords(t *testing.T, rs historyRecords, path string) { + recordFilePath := filepath.Join(path, "history_records.json") + t.Logf("Saving history records to %s", recordFilePath) + recordFile, err := os.OpenFile(recordFilePath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0755) + require.NoError(t, err) + defer recordFile.Close() + encoder := json.NewEncoder(recordFile) + for _, rec := range rs { + err := encoder.Encode(rec) require.NoError(t, err) - t.Logf("DB file saved to %s", targetFile) } } @@ -321,3 +423,89 @@ func testResultsDirectory(t *testing.T) string { return path } + +/* +********************************************************* +Data structures and functions for analyzing history records +********************************************************* +*/ +type OperationType string + +const ( + Read OperationType = "read" + Write OperationType = "write" +) + +type historyRecord struct { + OperationType OperationType `json:"operationType,omitempty"` + Txid int `json:"txid,omitempty"` + Key string `json:"key,omitempty"` + Value []byte `json:"value,omitempty"` +} + +type historyRecords []historyRecord + +func (rs historyRecords) Len() int { + return len(rs) +} + +func (rs historyRecords) Less(i, j int) bool { + // Sorted by key firstly: all records with the same key are grouped together. + keyCmp := strings.Compare(rs[i].Key, rs[j].Key) + if keyCmp != 0 { + return keyCmp < 0 + } + + // Sorted by txid + if rs[i].Txid != rs[j].Txid { + return rs[i].Txid < rs[j].Txid + } + + // Sorted by workerType: put writer before reader if they have the same txid. + if rs[i].OperationType == Write { + return true + } + + return false +} + +func (rs historyRecords) Swap(i, j int) { + rs[i], rs[j] = rs[j], rs[i] +} + +func validateSerializable(rs historyRecords) error { + sort.Sort(rs) + + lastWriteKeyValueMap := make(map[string]*historyRecord) + + for _, rec := range rs { + if v, ok := lastWriteKeyValueMap[rec.Key]; ok { + if rec.OperationType == Write { + v.Value = rec.Value + v.Txid = rec.Txid + } else { + if !reflect.DeepEqual(v.Value, rec.Value) { + return fmt.Errorf("reader[txid: %d, key: %s] read %x, \nbut writer[txid: %d, key: %s] wrote %x", + rec.Txid, rec.Key, rec.Value, + v.Txid, v.Key, v.Value) + } + } + } else { + if rec.OperationType == Write { + lastWriteKeyValueMap[rec.Key] = &historyRecord{ + OperationType: Write, + Key: rec.Key, + Value: rec.Value, + Txid: rec.Txid, + } + } else { + if len(rec.Value) != 0 { + return fmt.Errorf("expected the first reader[txid: %d, key: %s] read nil, \nbut got %x", + rec.Txid, rec.Key, rec.Value) + } + } + } + } + + return nil +}