diff --git a/batch.go b/batch.go index abe2c784..26b0c3e7 100644 --- a/batch.go +++ b/batch.go @@ -3,9 +3,9 @@ package rosedb import ( "fmt" "sync" + "time" "github.com/bwmarrin/snowflake" - "github.com/rosedblabs/wal" ) @@ -112,9 +112,35 @@ func (b *Batch) Put(key []byte, value []byte) error { b.mu.Lock() // write to pendingWrites b.pendingWrites[string(key)] = &LogRecord{ - Key: key, - Value: value, - Type: LogRecordNormal, + Key: key, + Value: value, + Type: LogRecordNormal, + Expire: 0, + } + b.mu.Unlock() + + return nil +} + +// PutWithTTL adds a key-value pair with ttl to the batch for writing. +func (b *Batch) PutWithTTL(key []byte, value []byte, ttl time.Duration) error { + if len(key) == 0 { + return ErrKeyIsEmpty + } + if b.db.closed { + return ErrDBClosed + } + if b.options.ReadOnly { + return ErrReadOnlyBatch + } + + b.mu.Lock() + // write to pendingWrites + b.pendingWrites[string(key)] = &LogRecord{ + Key: key, + Value: value, + Type: LogRecordNormal, + Expire: time.Now().Add(ttl).UnixNano(), } b.mu.Unlock() @@ -130,11 +156,12 @@ func (b *Batch) Get(key []byte) ([]byte, error) { return nil, ErrDBClosed } + now := time.Now().UnixNano() // get from pendingWrites if b.pendingWrites != nil { b.mu.RLock() if record := b.pendingWrites[string(key)]; record != nil { - if record.Type == LogRecordDeleted { + if record.Type == LogRecordDeleted || (record.Expire > 0 && record.Expire <= now) { b.mu.RUnlock() return nil, ErrKeyNotFound } @@ -154,10 +181,14 @@ func (b *Batch) Get(key []byte) ([]byte, error) { return nil, err } + // check if the record is deleted or expired record := decodeLogRecord(chunk) if record.Type == LogRecordDeleted { panic("Deleted data cannot exist in the index") } + if record.Expire > 0 && record.Expire <= now { + return nil, ErrKeyNotFound + } return record.Value, nil } @@ -207,9 +238,24 @@ func (b *Batch) Exist(key []byte) (bool, error) { b.mu.RUnlock() } - // check if the key exists in data file + // check if the key exists in index position := b.db.index.Get(key) - return position != nil, nil + if position == nil { + return false, nil + } + + // check if the record is deleted or expired + chunk, err := b.db.dataFiles.Read(position) + if err != nil { + return false, err + } + + now := time.Now().UnixNano() + record := decodeLogRecord(chunk) + if record.Type == LogRecordDeleted || (record.Expire > 0 && record.Expire <= now) { + return false, nil + } + return true, nil } // Commit commits the batch, if the batch is readonly or empty, it will return directly. @@ -241,8 +287,14 @@ func (b *Batch) Commit() error { batchId := b.batchId.Generate() positions := make(map[string]*wal.ChunkPosition) + now := time.Now().UnixNano() // write to wal for _, record := range b.pendingWrites { + // skip the expired record + if record.Expire > 0 && record.Expire <= now { + continue + } + record.BatchId = uint64(batchId) encRecord := encodeLogRecord(record) pos, err := b.db.dataFiles.Write(encRecord) @@ -291,7 +343,7 @@ func (b *Batch) Commit() error { return nil } -// Rollback discards a uncommitted batch instance. +// Rollback discards an uncommitted batch instance. // the discard operation will clear the buffered data and release the lock. func (b *Batch) Rollback() error { defer b.unlock() diff --git a/db.go b/db.go index b80e2515..576534c1 100644 --- a/db.go +++ b/db.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "sync" + "time" "github.com/bwmarrin/snowflake" "github.com/gofrs/flock" @@ -233,6 +234,25 @@ func (db *DB) Put(key []byte, value []byte) error { return batch.Commit() } +// PutWithTTL a key-value pair into the database, with a ttl. +// Actually, it will open a new batch and commit it. +// You can think the batch has only one PutWithTTL operation. +func (db *DB) PutWithTTL(key []byte, value []byte, ttl time.Duration) error { + batch := db.batchPool.Get().(*Batch) + defer func() { + batch.reset() + db.batchPool.Put(batch) + }() + // This is a single delete operation, we can set Sync to false. + // Because the data will be written to the WAL, + // and the WAL file will be synced to disk according to the DB options. + batch.init(false, false, db).withPendingWrites() + if err := batch.PutWithTTL(key, value, ttl); err != nil { + return err + } + return batch.Commit() +} + // Get the value of the specified key from the database. // Actually, it will open a new batch and commit it. // You can think the batch has only one Get operation. @@ -294,11 +314,15 @@ func (db *DB) Ascend(handleFn func(k []byte, v []byte) (bool, error)) { defer db.mu.RUnlock() db.index.Ascend(func(key []byte, pos *wal.ChunkPosition) (bool, error) { - val, err := db.dataFiles.Read(pos) + chunk, err := db.dataFiles.Read(pos) if err != nil { - return false, nil + return false, err } - return handleFn(key, val) + value, err := db.checkValue(chunk) + if err != nil { + return false, err + } + return handleFn(key, value) }) } @@ -308,11 +332,15 @@ func (db *DB) AscendRange(startKey, endKey []byte, handleFn func(k []byte, v []b defer db.mu.RUnlock() db.index.AscendRange(startKey, endKey, func(key []byte, pos *wal.ChunkPosition) (bool, error) { - val, err := db.dataFiles.Read(pos) + chunk, err := db.dataFiles.Read(pos) if err != nil { return false, nil } - return handleFn(key, val) + value, err := db.checkValue(chunk) + if err != nil { + return false, err + } + return handleFn(key, value) }) } @@ -322,11 +350,15 @@ func (db *DB) AscendGreaterOrEqual(key []byte, handleFn func(k []byte, v []byte) defer db.mu.RUnlock() db.index.AscendGreaterOrEqual(key, func(key []byte, pos *wal.ChunkPosition) (bool, error) { - val, err := db.dataFiles.Read(pos) + chunk, err := db.dataFiles.Read(pos) if err != nil { return false, nil } - return handleFn(key, val) + value, err := db.checkValue(chunk) + if err != nil { + return false, err + } + return handleFn(key, value) }) } @@ -336,11 +368,15 @@ func (db *DB) Descend(handleFn func(k []byte, v []byte) (bool, error)) { defer db.mu.RUnlock() db.index.Descend(func(key []byte, pos *wal.ChunkPosition) (bool, error) { - val, err := db.dataFiles.Read(pos) + chunk, err := db.dataFiles.Read(pos) if err != nil { return false, nil } - return handleFn(key, val) + value, err := db.checkValue(chunk) + if err != nil { + return false, err + } + return handleFn(key, value) }) } @@ -350,11 +386,15 @@ func (db *DB) DescendRange(startKey, endKey []byte, handleFn func(k []byte, v [] defer db.mu.RUnlock() db.index.DescendRange(startKey, endKey, func(key []byte, pos *wal.ChunkPosition) (bool, error) { - val, err := db.dataFiles.Read(pos) + chunk, err := db.dataFiles.Read(pos) if err != nil { return false, nil } - return handleFn(key, val) + value, err := db.checkValue(chunk) + if err != nil { + return false, err + } + return handleFn(key, value) }) } @@ -364,14 +404,27 @@ func (db *DB) DescendLessOrEqual(key []byte, handleFn func(k []byte, v []byte) ( defer db.mu.RUnlock() db.index.DescendLessOrEqual(key, func(key []byte, pos *wal.ChunkPosition) (bool, error) { - val, err := db.dataFiles.Read(pos) + chunk, err := db.dataFiles.Read(pos) if err != nil { return false, nil } - return handleFn(key, val) + value, err := db.checkValue(chunk) + if err != nil { + return false, err + } + return handleFn(key, value) }) } +func (db *DB) checkValue(chunk []byte) ([]byte, error) { + record := decodeLogRecord(chunk) + now := time.Now().UnixNano() + if record.Type == LogRecordDeleted || (record.Expire > 0 && record.Expire <= now) { + return nil, ErrKeyNotFound + } + return record.Value, nil +} + func checkOptions(options Options) error { if options.DirPath == "" { return errors.New("database dir path is empty") @@ -391,6 +444,7 @@ func (db *DB) loadIndexFromWAL() error { return err } indexRecords := make(map[uint64][]*IndexRecord) + now := time.Now().UnixNano() // get a reader for WAL reader := db.dataFiles.NewReader() for { @@ -435,6 +489,10 @@ func (db *DB) loadIndexFromWAL() error { // so put the record into index directly. db.index.Put(record.Key, position) } else { + // expired records should not be indexed + if record.Expire > 0 && record.Expire <= now { + continue + } // put the record into the temporary indexRecords indexRecords[record.BatchId] = append(indexRecords[record.BatchId], &IndexRecord{ diff --git a/db_test.go b/db_test.go index 65eca34c..775aa94b 100644 --- a/db_test.go +++ b/db_test.go @@ -4,6 +4,7 @@ import ( "math/rand" "sync" "testing" + "time" "github.com/rosedblabs/rosedb/v2/utils" "github.com/stretchr/testify/assert" @@ -356,3 +357,75 @@ func TestDB_DescendLessOrEqual(t *testing.T) { }) assert.Equal(t, []string{"grape", "date", "cherry", "banana", "apple"}, resultDescendLessOrEqual) } + +func TestDB_PutWithTTL(t *testing.T) { + options := DefaultOptions + db, err := Open(options) + assert.Nil(t, err) + defer destroyDB(db) + + err = db.PutWithTTL(utils.GetTestKey(1), utils.RandomValue(128), time.Millisecond*100) + assert.Nil(t, err) + val1, err := db.Get(utils.GetTestKey(1)) + assert.Nil(t, err) + assert.NotNil(t, val1) + time.Sleep(time.Millisecond * 200) + val2, err := db.Get(utils.GetTestKey(1)) + assert.Equal(t, err, ErrKeyNotFound) + assert.Nil(t, val2) + + err = db.PutWithTTL(utils.GetTestKey(2), utils.RandomValue(128), time.Millisecond*200) + // rewrite + err = db.Put(utils.GetTestKey(2), utils.RandomValue(128)) + assert.Nil(t, err) + time.Sleep(time.Millisecond * 200) + val3, err := db.Get(utils.GetTestKey(2)) + assert.Nil(t, err) + assert.NotNil(t, val3) + + err = db.Close() + assert.Nil(t, err) + + db2, err := Open(options) + assert.Nil(t, err) + + val4, err := db2.Get(utils.GetTestKey(1)) + assert.Equal(t, err, ErrKeyNotFound) + assert.Nil(t, val4) + + val5, err := db2.Get(utils.GetTestKey(2)) + assert.Nil(t, err) + assert.NotNil(t, val5) + + _ = db2.Close() +} + +func TestDB_PutWithTTL_Merge(t *testing.T) { + options := DefaultOptions + db, err := Open(options) + assert.Nil(t, err) + defer destroyDB(db) + for i := 0; i < 100; i++ { + err = db.PutWithTTL(utils.GetTestKey(i), utils.RandomValue(10), time.Second*2) + assert.Nil(t, err) + } + for i := 100; i < 150; i++ { + err = db.PutWithTTL(utils.GetTestKey(i), utils.RandomValue(10), time.Second*20) + assert.Nil(t, err) + } + time.Sleep(time.Second * 3) + + err = db.Merge(true) + assert.Nil(t, err) + + for i := 0; i < 100; i++ { + val, err := db.Get(utils.GetTestKey(i)) + assert.Nil(t, val) + assert.Equal(t, err, ErrKeyNotFound) + } + for i := 100; i < 150; i++ { + val, err := db.Get(utils.GetTestKey(i)) + assert.Nil(t, err) + assert.NotNil(t, val) + } +} diff --git a/merge.go b/merge.go index fc94a299..cbc63760 100644 --- a/merge.go +++ b/merge.go @@ -3,13 +3,14 @@ package rosedb import ( "encoding/binary" "fmt" + "github.com/rosedblabs/rosedb/v2/index" + "github.com/rosedblabs/wal" "io" "math" "os" "path/filepath" "sync/atomic" - - "github.com/rosedblabs/wal" + "time" ) const ( @@ -51,6 +52,8 @@ func (db *DB) Merge(reopenAfterDone bool) error { return err } + // discard the old index first. + db.index = index.NewIndexer() // rebuild index if err = db.loadIndex(); err != nil { return err @@ -103,6 +106,7 @@ func (db *DB) doMerge() error { _ = mergeDB.Close() }() + now := time.Now().UnixNano() // iterate all the data files, and write the valid data to the new data file. reader := db.dataFiles.NewReaderWithMax(prevActiveSegId) for { @@ -116,7 +120,7 @@ func (db *DB) doMerge() error { record := decodeLogRecord(chunk) // Only handle the normal log record, LogRecordDeleted and LogRecordBatchFinished // will be ignored, because they are not valid data. - if record.Type == LogRecordNormal { + if record.Type == LogRecordNormal && (record.Expire == 0 || record.Expire > now) { db.mu.RLock() indexPos := db.index.Get(record.Key) db.mu.RUnlock() @@ -222,40 +226,40 @@ func encodeHintRecord(key []byte, pos *wal.ChunkPosition) []byte { // 5 5 10 5 = 25 // see binary.MaxVarintLen64 and binary.MaxVarintLen32 buf := make([]byte, 25) - var index = 0 + var idx = 0 // SegmentId - index += binary.PutUvarint(buf[index:], uint64(pos.SegmentId)) + idx += binary.PutUvarint(buf[idx:], uint64(pos.SegmentId)) // BlockNumber - index += binary.PutUvarint(buf[index:], uint64(pos.BlockNumber)) + idx += binary.PutUvarint(buf[idx:], uint64(pos.BlockNumber)) // ChunkOffset - index += binary.PutUvarint(buf[index:], uint64(pos.ChunkOffset)) + idx += binary.PutUvarint(buf[idx:], uint64(pos.ChunkOffset)) // ChunkSize - index += binary.PutUvarint(buf[index:], uint64(pos.ChunkSize)) + idx += binary.PutUvarint(buf[idx:], uint64(pos.ChunkSize)) // key - result := make([]byte, index+len(key)) - copy(result, buf[:index]) - copy(result[index:], key) + result := make([]byte, idx+len(key)) + copy(result, buf[:idx]) + copy(result[idx:], key) return result } func decodeHintRecord(buf []byte) ([]byte, *wal.ChunkPosition) { - var index = 0 + var idx = 0 // SegmentId - segmentId, n := binary.Uvarint(buf[index:]) - index += n + segmentId, n := binary.Uvarint(buf[idx:]) + idx += n // BlockNumber - blockNumber, n := binary.Uvarint(buf[index:]) - index += n + blockNumber, n := binary.Uvarint(buf[idx:]) + idx += n // ChunkOffset - chunkOffset, n := binary.Uvarint(buf[index:]) - index += n + chunkOffset, n := binary.Uvarint(buf[idx:]) + idx += n // ChunkSize - chunkSize, n := binary.Uvarint(buf[index:]) - index += n + chunkSize, n := binary.Uvarint(buf[idx:]) + idx += n // Key - key := buf[index:] + key := buf[idx:] return key, &wal.ChunkPosition{ SegmentId: wal.SegmentID(segmentId), diff --git a/record.go b/record.go index d6d4410f..91537cf6 100644 --- a/record.go +++ b/record.go @@ -18,10 +18,10 @@ const ( LogRecordBatchFinished ) -// type batchId keySize valueSize +// type batchId keySize valueSize expire // -// 1 + 10 + 5 + 5 = 21 -const maxLogRecordHeaderSize = binary.MaxVarintLen32*2 + binary.MaxVarintLen64 + 1 +// 1 + 10 + 5 + 5 + 10 = 31 +const maxLogRecordHeaderSize = binary.MaxVarintLen32*2 + binary.MaxVarintLen64*2 + 1 // LogRecord is the log record of the key/value pair. // It contains the key, the value, the record type and the batch id @@ -31,6 +31,7 @@ type LogRecord struct { Value []byte Type LogRecordType BatchId uint64 + Expire int64 } // IndexRecord is the index record of the key. @@ -42,11 +43,11 @@ type IndexRecord struct { position *wal.ChunkPosition } -// +-------------+-------------+-------------+--------------+-------------+--------------+ -// | type | batch id | key size | value size | key | value | -// +-------------+-------------+-------------+--------------+-------------+--------------+ +// +-------------+-------------+-------------+--------------+---------------+---------+--------------+ +// | type | batch id | key size | value size | expire | key | value | +// +-------------+-------------+-------------+--------------+---------------+--------+--------------+ // -// 1 byte varint(max 10) varint(max 5) varint(max 5) varint varint +// 1 byte varint(max 10) varint(max 5) varint(max 5) varint(max 10) varint varint func encodeLogRecord(logRecord *LogRecord) []byte { header := make([]byte, maxLogRecordHeaderSize) @@ -59,6 +60,8 @@ func encodeLogRecord(logRecord *LogRecord) []byte { index += binary.PutVarint(header[index:], int64(len(logRecord.Key))) // value size index += binary.PutVarint(header[index:], int64(len(logRecord.Value))) + // expire + index += binary.PutVarint(header[index:], logRecord.Expire) var size = index + len(logRecord.Key) + len(logRecord.Value) encBytes := make([]byte, size) @@ -90,6 +93,10 @@ func decodeLogRecord(buf []byte) *LogRecord { valueSize, n := binary.Varint(buf[index:]) index += uint32(n) + // expire + expire, n := binary.Varint(buf[index:]) + index += uint32(n) + // copy key key := make([]byte, keySize) copy(key[:], buf[index:index+uint32(keySize)]) @@ -99,6 +106,6 @@ func decodeLogRecord(buf []byte) *LogRecord { value := make([]byte, valueSize) copy(value[:], buf[index:index+uint32(valueSize)]) - return &LogRecord{Key: key, Value: value, + return &LogRecord{Key: key, Value: value, Expire: expire, BatchId: batchId, Type: recordType} }