diff --git a/aof.go b/aof.go index 215f4d7..57188c5 100644 --- a/aof.go +++ b/aof.go @@ -62,6 +62,5 @@ func (a *Aof) Read(fn func(args []resp.RESP)) error { } fn(args) } - return nil } diff --git a/command.go b/command.go index 1d40df9..28625a1 100644 --- a/command.go +++ b/command.go @@ -3,6 +3,7 @@ package main import ( "fmt" "github.com/xgzlucario/rotom/internal/resp" + "github.com/xgzlucario/rotom/internal/timer" "strconv" "strings" "time" @@ -109,7 +110,7 @@ func setCommand(writer *resp.Writer, args []resp.RESP) { writer.WriteError(errParseInteger) return } - ttl = time.Now().Add(n * time.Second).UnixNano() + ttl = timer.GetNanoTime() + int64(n*time.Second) extra = extra[2:] // PX @@ -119,7 +120,7 @@ func setCommand(writer *resp.Writer, args []resp.RESP) { writer.WriteError(errParseInteger) return } - ttl = time.Now().Add(n * time.Millisecond).UnixNano() + ttl = timer.GetNanoTime() + int64(n*time.Millisecond) extra = extra[2:] // KEEPTTL diff --git a/dict.go b/dict.go index 7f3a44d..ab41990 100644 --- a/dict.go +++ b/dict.go @@ -1,39 +1,44 @@ package main import ( + "github.com/cockroachdb/swiss" + "github.com/xgzlucario/rotom/internal/timer" "time" ) // Dict is the hashmap for rotom. type Dict struct { - data map[string]any - expire map[string]int64 + data *swiss.Map[string, any] + expire *swiss.Map[string, int64] +} + +func init() { + timer.Init() } func New() *Dict { return &Dict{ - data: make(map[string]any, 64), - expire: make(map[string]int64, 64), + data: swiss.New[string, any](64), + expire: swiss.New[string, int64](64), } } func (dict *Dict) Get(key string) (any, int) { - data, ok := dict.data[key] + data, ok := dict.data.Get(key) if !ok { // key not exist return nil, KEY_NOT_EXIST } - ts, ok := dict.expire[key] + ts, ok := dict.expire.Get(key) if !ok { return data, TTL_FOREVER } // key expired - now := time.Now().UnixNano() + now := timer.GetNanoTime() if ts < now { - delete(dict.data, key) - delete(dict.expire, key) + dict.delete(key) return nil, KEY_NOT_EXIST } @@ -41,23 +46,27 @@ func (dict *Dict) Get(key string) (any, int) { } func (dict *Dict) Set(key string, data any) { - dict.data[key] = data + dict.data.Put(key, data) } func (dict *Dict) SetWithTTL(key string, data any, ttl int64) { if ttl > 0 { - dict.expire[key] = ttl + dict.expire.Put(key, ttl) } - dict.data[key] = data + dict.data.Put(key, data) +} + +func (dict *Dict) delete(key string) { + dict.data.Delete(key) + dict.expire.Delete(key) } func (dict *Dict) Delete(key string) bool { - _, ok := dict.data[key] + _, ok := dict.data.Get(key) if !ok { return false } - delete(dict.data, key) - delete(dict.expire, key) + dict.delete(key) return true } @@ -65,36 +74,31 @@ func (dict *Dict) Delete(key string) bool { // return `0` if key not exist or expired. // return `1` if set success. func (dict *Dict) SetTTL(key string, ttl int64) int { - _, ok := dict.data[key] + _, ok := dict.data.Get(key) if !ok { // key not exist return 0 } // check key if already expired - ts, ok := dict.expire[key] - if ok && ts < time.Now().UnixNano() { - delete(dict.data, key) - delete(dict.expire, key) + ts, ok := dict.expire.Get(key) + if ok && ts < timer.GetNanoTime() { + dict.delete(key) return 0 } // set ttl - dict.expire[key] = ttl + dict.expire.Put(key, ttl) return 1 } func (dict *Dict) EvictExpired() { var count int - now := time.Now().UnixNano() - for key, ts := range dict.expire { - if now > ts { - delete(dict.expire, key) - delete(dict.data, key) + dict.expire.All(func(key string, ts int64) bool { + if timer.GetNanoTime() > ts { + dict.Delete(key) } count++ - if count > 20 { - return - } - } + return count <= 20 + }) } diff --git a/go.mod b/go.mod index f39bc91..db4133e 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/xgzlucario/rotom go 1.23 require ( + github.com/alicebob/miniredis/v2 v2.33.0 github.com/bytedance/sonic v1.12.4 github.com/chen3feng/stl4go v0.1.1 github.com/cockroachdb/swiss v0.0.0-20240612210725-f4de07ae6964 @@ -19,7 +20,6 @@ require ( require ( github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect - github.com/alicebob/miniredis/v2 v2.33.0 // indirect github.com/bytedance/sonic/loader v0.2.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect diff --git a/internal/resp/resp_test.go b/internal/resp/resp_test.go index 8c39736..df7763e 100644 --- a/internal/resp/resp_test.go +++ b/internal/resp/resp_test.go @@ -130,7 +130,10 @@ func TestReader(t *testing.T) { } func FuzzRESPReader(f *testing.F) { - f.Fuzz(func(t *testing.T, b []byte) { - NewReader(b).ReadNextCommand(nil) + f.Fuzz(func(t *testing.T, i uint8, buf string) { + reader := NewReader([]byte(buf)) + for range i { + reader.ReadNextCommand(nil) + } }) } diff --git a/internal/timer/timer.go b/internal/timer/timer.go new file mode 100644 index 0000000..91fad75 --- /dev/null +++ b/internal/timer/timer.go @@ -0,0 +1,23 @@ +package timer + +import ( + "sync/atomic" + "time" +) + +var ( + nanotime atomic.Int64 +) + +func Init() { + go func() { + tk := time.NewTicker(time.Millisecond / 10) + for t := range tk.C { + nanotime.Store(t.UnixNano()) + } + }() +} + +func GetNanoTime() int64 { + return nanotime.Load() +} diff --git a/rdb.go b/rdb.go index 6eee2a0..5a9ddb2 100644 --- a/rdb.go +++ b/rdb.go @@ -14,12 +14,10 @@ type Rdb struct { } func NewRdb(path string) *Rdb { - return &Rdb{ - path: path, - } + return &Rdb{path: path} } -func (r *Rdb) SaveDB() error { +func (r *Rdb) SaveDB() (err error) { // create tmp file fname := fmt.Sprintf("%s.rdb", time.Now().Format(time.RFC3339)) fs, err := os.Create(fname) @@ -28,13 +26,14 @@ func (r *Rdb) SaveDB() error { } writer := resp.NewWriter(MB) - writer.WriteArrayHead(len(db.dict.data)) + writer.WriteArrayHead(db.dict.data.Len()) - for k, v := range db.dict.data { + db.dict.data.All(func(k string, v any) bool { // format: {objectType,ttl,key,value} objectType := getObjectType(v) writer.WriteInteger(int(objectType)) - writer.WriteInteger(int(db.dict.expire[k])) + ttl, _ := db.dict.expire.Get(k) + writer.WriteInteger(int(ttl)) writer.WriteBulkString(k) switch objectType { @@ -44,10 +43,12 @@ func (r *Rdb) SaveDB() error { writer.WriteInteger(v.(int)) default: if err = v.(iface.Encoder).Encode(writer); err != nil { - return err + log.Error().Msgf("[rdb] encode error: %v, %v", objectType, err) + return false } } - } + return true + }) // flush _, err = writer.FlushTo(fs)