diff --git a/.gitignore b/.gitignore index ec5af87..4b38dd8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ coverage.* *.aof - +.idea/ redis rotom \ No newline at end of file diff --git a/Makefile b/Makefile index dc6523c..486d7e9 100644 --- a/Makefile +++ b/Makefile @@ -5,16 +5,17 @@ run-gc: GODEBUG=gctrace=1 go run . test-cover: + rm -f *.aof go test ./... -race -coverprofile=coverage.txt -covermode=atomic go tool cover -html=coverage.txt -o coverage.html rm coverage.txt - rm *.aof + rm -f *.aof fuzz-test: go test -fuzz=FuzzRESPReader pprof: - go tool pprof -http=:18081 "http://192.168.1.6:6060/debug/pprof/profile?seconds=30" + go tool pprof -http=:18081 "http://192.168.10.139:6060/debug/pprof/profile?seconds=30" heap: go tool pprof http://192.168.1.6:6060/debug/pprof/heap @@ -26,8 +27,5 @@ build: CGO_ENABLED=0 \ go build -o rotom -ldflags "-s -w -X main.buildTime=$(shell date +%y%m%d_%H%M%S%z)" -upx: - upx -9 rotom - build-docker: docker build --build-arg BUILD_TIME=$(shell date +%y%m%d_%H%M%S%z) -t rotom . \ No newline at end of file diff --git a/ae.go b/ae.go index 8c16f88..d528de2 100644 --- a/ae.go +++ b/ae.go @@ -1,6 +1,7 @@ package main import ( + "errors" "github.com/xgzlucario/rotom/internal/dict" "golang.org/x/sys/unix" ) @@ -38,7 +39,7 @@ type AeLoop struct { timeEventNextId int stop bool - _fevents []*AeFileEvent // fes cache + events []*AeFileEvent // file events cache } func (loop *AeLoop) AddRead(fd int, proc FileProc, extra interface{}) { @@ -141,12 +142,12 @@ func AeLoopCreate() (*AeLoop, error) { fileEventFd: epollFd, timeEventNextId: 1, stop: false, - _fevents: make([]*AeFileEvent, 128), // pre alloc + events: make([]*AeFileEvent, 128), // pre alloc }, nil } func (loop *AeLoop) nearestTime() int64 { - var nearest int64 = GetMsTime() + 1000 + var nearest = GetMsTime() + 1000 p := loop.TimeEvents for p != nil { nearest = min(nearest, p.when) @@ -166,7 +167,7 @@ retry: n, err := unix.EpollWait(loop.fileEventFd, events[:], int(timeout)) if err != nil { // interrupted system call - if err == unix.EINTR { + if errors.Is(err, unix.EINTR) { goto retry } log.Error().Msgf("epoll wait error: %v", err) @@ -174,7 +175,7 @@ retry: } // collect file events - fes = loop._fevents[:0] + fes = loop.events[:0] for _, ev := range events[:n] { if ev.Events&unix.EPOLLIN != 0 { fe := loop.FileEvents[int(ev.Fd)] diff --git a/command.go b/command.go index 50a245a..f0b9eb9 100644 --- a/command.go +++ b/command.go @@ -1,625 +1,617 @@ -package main - -import ( - "fmt" - "strconv" - "strings" - "time" - - "github.com/xgzlucario/rotom/internal/dict" - "github.com/xgzlucario/rotom/internal/hash" - "github.com/xgzlucario/rotom/internal/list" - "github.com/xgzlucario/rotom/internal/zset" - lua "github.com/yuin/gopher-lua" -) - -var ( - WITH_SCORES = "WITHSCORES" - KEEP_TTL = "KEEPTTL" - NX = "NX" - EX = "EX" - PX = "PX" -) - -type Command struct { - // name is lowercase letters command name. - name string - - // handler is this command real database handler function. - handler func(writer *RESPWriter, args []RESP) - - // minArgsNum represents the minimal number of arguments that command accepts. - minArgsNum int - - // persist indicates whether this command needs to be persisted. - // effective when `appendonly` is true. - persist bool -} - -// cmdTable is the list of all available commands. -var cmdTable []*Command = []*Command{ - {"set", setCommand, 2, true}, - {"get", getCommand, 1, false}, - {"del", delCommand, 1, true}, - {"incr", incrCommand, 1, true}, - {"hset", hsetCommand, 3, true}, - {"hget", hgetCommand, 2, false}, - {"hdel", hdelCommand, 2, true}, - {"hgetall", hgetallCommand, 1, false}, - {"rpush", rpushCommand, 2, true}, - {"lpush", lpushCommand, 2, true}, - {"rpop", rpopCommand, 1, true}, - {"lpop", lpopCommand, 1, true}, - {"lrange", lrangeCommand, 3, false}, - {"sadd", saddCommand, 2, true}, - {"srem", sremCommand, 2, true}, - {"spop", spopCommand, 1, true}, - {"zadd", zaddCommand, 3, true}, - {"zrem", zremCommand, 2, true}, - {"zrank", zrankCommand, 2, false}, - {"zpopmin", zpopminCommand, 1, true}, - {"zrange", zrangeCommand, 3, false}, - {"eval", evalCommand, 2, true}, - {"ping", pingCommand, 0, false}, - {"flushdb", flushdbCommand, 0, true}, - // TODO - {"mset", todoCommand, 0, false}, - {"xadd", todoCommand, 0, false}, - {"client", todoCommand, 0, false}, -} - -func equalFold(a, b string) bool { - return len(a) == len(b) && strings.EqualFold(a, b) -} - -func lookupCommand(name string) (*Command, error) { - for _, c := range cmdTable { - if equalFold(name, c.name) { - return c, nil - } - } - return nil, fmt.Errorf("%w '%s'", errUnknownCommand, name) -} - -func (cmd *Command) processCommand(writer *RESPWriter, args []RESP) { - if len(args) < cmd.minArgsNum { - writer.WriteError(errWrongArguments) - return - } - cmd.handler(writer, args) -} - -func pingCommand(writer *RESPWriter, _ []RESP) { - writer.WriteString("PONG") -} - -func setCommand(writer *RESPWriter, args []RESP) { - key := args[0].ToString() - value := args[1].Clone() - extra := args[2:] - var ttl int64 - - for len(extra) > 0 { - arg := extra[0].ToStringUnsafe() - - // EX - if equalFold(arg, EX) && len(extra) >= 2 { - n, err := extra[1].ToInt() - if err != nil { - writer.WriteError(errParseInteger) - return - } - ttl = dict.GetNanoTime() + int64(time.Second)*int64(n) - extra = extra[2:] - - // PX - } else if equalFold(arg, PX) && len(extra) >= 2 { - n, err := extra[1].ToInt() - if err != nil { - writer.WriteError(errParseInteger) - return - } - ttl = dict.GetNanoTime() + int64(time.Millisecond)*int64(n) - extra = extra[2:] - - // KEEPTTL - } else if equalFold(arg, KEEP_TTL) { - extra = extra[1:] - ttl = -1 - - // NX - } else if equalFold(arg, NX) { - if _, ttl := db.dict.Get(key); ttl != dict.KEY_NOT_EXIST { - writer.WriteNull() - return - } - extra = extra[1:] - - } else { - writer.WriteError(errSyntax) - return - } - } - - db.dict.SetWithTTL(key, value, ttl) - writer.WriteString("OK") -} - -func incrCommand(writer *RESPWriter, args []RESP) { - key := args[0].ToStringUnsafe() - - object, ttl := db.dict.Get(key) - if ttl == dict.KEY_NOT_EXIST { - db.dict.Set(strings.Clone(key), 1) - writer.WriteInteger(1) - return - } - - switch v := object.(type) { - case int: - num := v + 1 - writer.WriteInteger(num) - db.dict.Set(strings.Clone(key), num) - - case []byte: - // conv to integer - num, err := RESP(v).ToInt() - if err != nil { - writer.WriteError(errParseInteger) - return - } - num++ - strconv.AppendInt(v[:0], int64(num), 10) - writer.WriteInteger(num) - - default: - writer.WriteError(errWrongType) - } -} - -func getCommand(writer *RESPWriter, args []RESP) { - key := args[0].ToStringUnsafe() - object, ttl := db.dict.Get(key) - if ttl == dict.KEY_NOT_EXIST { - writer.WriteNull() - return - } - switch v := object.(type) { - case int: - writer.WriteBulkString(strconv.Itoa(v)) - case []byte: - writer.WriteBulk(v) - default: - writer.WriteError(errWrongType) - } -} - -func delCommand(writer *RESPWriter, args []RESP) { - var count int - for _, arg := range args { - if db.dict.Delete(arg.ToStringUnsafe()) { - count++ - } - } - writer.WriteInteger(count) -} - -func hsetCommand(writer *RESPWriter, args []RESP) { - hash := args[0] - args = args[1:] - - if len(args)%2 == 1 { - writer.WriteError(errWrongArguments) - return - } - - hmap, err := fetchMap(hash, true) - if err != nil { - writer.WriteError(err) - return - } - - var count int - for i := 0; i < len(args); i += 2 { - key := args[i].ToString() - value := args[i+1].Clone() - if hmap.Set(key, value) { - count++ - } - } - writer.WriteInteger(count) -} - -func hgetCommand(writer *RESPWriter, args []RESP) { - hash := args[0] - key := args[1].ToStringUnsafe() - hmap, err := fetchMap(hash) - if err != nil { - writer.WriteError(errWrongType) - return - } - value, ok := hmap.Get(key) - if ok { - writer.WriteBulk(value) - } else { - writer.WriteNull() - } -} - -func hdelCommand(writer *RESPWriter, args []RESP) { - hash := args[0] - keys := args[1:] - hmap, err := fetchMap(hash) - if err != nil { - writer.WriteError(err) - return - } - var count int - for _, v := range keys { - if hmap.Remove(v.ToStringUnsafe()) { - count++ - } - } - writer.WriteInteger(count) -} - -func hgetallCommand(writer *RESPWriter, args []RESP) { - hash := args[0] - hmap, err := fetchMap(hash) - if err != nil { - writer.WriteError(err) - return - } - writer.WriteArrayHead(hmap.Len() * 2) - hmap.Scan(func(key string, value []byte) { - writer.WriteBulkString(key) - writer.WriteBulk(value) - }) -} - -func lpushCommand(writer *RESPWriter, args []RESP) { - key := args[0] - ls, err := fetchList(key, true) - if err != nil { - writer.WriteError(err) - return - } - for _, arg := range args[1:] { - ls.LPush(arg.ToStringUnsafe()) - } - writer.WriteInteger(ls.Size()) -} - -func rpushCommand(writer *RESPWriter, args []RESP) { - key := args[0] - ls, err := fetchList(key, true) - if err != nil { - writer.WriteError(err) - return - } - for _, arg := range args[1:] { - ls.RPush(arg.ToStringUnsafe()) - } - writer.WriteInteger(ls.Size()) -} - -func lpopCommand(writer *RESPWriter, args []RESP) { - key := args[0] - ls, err := fetchList(key) - if err != nil { - writer.WriteError(err) - return - } - val, ok := ls.LPop() - if ok { - writer.WriteBulkString(val) - } else { - writer.WriteNull() - } -} - -func rpopCommand(writer *RESPWriter, args []RESP) { - key := args[0] - ls, err := fetchList(key) - if err != nil { - writer.WriteError(err) - return - } - val, ok := ls.RPop() - if ok { - writer.WriteBulkString(val) - } else { - writer.WriteNull() - } -} - -func lrangeCommand(writer *RESPWriter, args []RESP) { - key := args[0] - start, err := args[1].ToInt() - if err != nil { - writer.WriteError(err) - return - } - stop, err := args[2].ToInt() - if err != nil { - writer.WriteError(err) - return - } - ls, err := fetchList(key) - if err != nil { - writer.WriteError(err) - return - } - writer.WriteArrayHead(ls.RangeCount(start, stop)) - ls.Range(start, stop, func(data []byte) { - writer.WriteBulk(data) - }) -} - -func saddCommand(writer *RESPWriter, args []RESP) { - key := args[0] - set, err := fetchSet(key, true) - if err != nil { - writer.WriteError(err) - return - } - var count int - for _, arg := range args[1:] { - if set.Add(arg.ToString()) { - count++ - } - } - writer.WriteInteger(count) -} - -func sremCommand(writer *RESPWriter, args []RESP) { - key := args[0] - set, err := fetchSet(key) - if err != nil { - writer.WriteError(err) - return - } - var count int - for _, arg := range args[1:] { - if set.Remove(arg.ToStringUnsafe()) { - count++ - } - } - writer.WriteInteger(count) -} - -func spopCommand(writer *RESPWriter, args []RESP) { - key := args[0] - set, err := fetchSet(key) - if err != nil { - writer.WriteError(err) - return - } - member, ok := set.Pop() - if ok { - writer.WriteBulkString(member) - } else { - writer.WriteNull() - } -} - -func zaddCommand(writer *RESPWriter, args []RESP) { - key := args[0] - args = args[1:] - - zset, err := fetchZSet(key, true) - if err != nil { - writer.WriteError(err) - return - } - - var count int - for i := 0; i < len(args); i += 2 { - score, err := args[i].ToFloat() - if err != nil { - writer.WriteError(err) - return - } - key := args[i+1].ToString() - if zset.Set(key, score) { - count++ - } - } - writer.WriteInteger(count) -} - -func zrankCommand(writer *RESPWriter, args []RESP) { - key := args[0] - member := args[1].ToStringUnsafe() - - zset, err := fetchZSet(key) - if err != nil { - writer.WriteError(err) - return - } - - rank, _ := zset.Rank(member) - if rank < 0 { - writer.WriteNull() - } else { - writer.WriteInteger(rank) - } -} - -func zremCommand(writer *RESPWriter, args []RESP) { - key := args[0] - zset, err := fetchZSet(key) - if err != nil { - writer.WriteError(err) - return - } - var count int - for _, arg := range args[1:] { - if zset.Remove(arg.ToStringUnsafe()) { - count++ - } - } - writer.WriteInteger(count) -} - -func zrangeCommand(writer *RESPWriter, args []RESP) { - key := args[0] - start, err := args[1].ToInt() - if err != nil { - writer.WriteError(err) - return - } - stop, err := args[2].ToInt() - if err != nil { - writer.WriteError(err) - return - } - zset, err := fetchZSet(key) - if err != nil { - writer.WriteError(err) - return - } - - if stop == -1 { - stop = zset.Len() - } - start = min(start, stop) - - withScores := len(args) == 4 && equalFold(args[3].ToStringUnsafe(), WITH_SCORES) - if withScores { - writer.WriteArrayHead((stop - start) * 2) - zset.Range(start, stop, func(key string, score float64) { - writer.WriteBulkString(key) - writer.WriteFloat(score) - }) - - } else { - writer.WriteArrayHead(stop - start) - zset.Range(start, stop, func(key string, _ float64) { - writer.WriteBulkString(key) - }) - } -} - -func zpopminCommand(writer *RESPWriter, args []RESP) { - key := args[0] - count := 1 - var err error - if len(args) > 1 { - count, err = args[1].ToInt() - if err != nil { - writer.WriteError(err) - return - } - } - - zset, err := fetchZSet(key) - if err != nil { - writer.WriteError(err) - return - } - - size := min(zset.Len(), count) - writer.WriteArrayHead(size * 2) - for range size { - key, score := zset.PopMin() - writer.WriteBulkString(key) - writer.WriteFloat(score) - } -} - -func flushdbCommand(writer *RESPWriter, _ []RESP) { - db.dict = dict.New() - writer.WriteString("OK") -} - -func evalCommand(writer *RESPWriter, args []RESP) { - L := server.lua - script := args[0].ToString() - - if err := L.DoString(script); err != nil { - writer.WriteError(err) - return - } - - var serialize func(isRoot bool, ret lua.LValue) - serialize = func(isRoot bool, ret lua.LValue) { - switch res := ret.(type) { - case lua.LString: - writer.WriteBulkString(res.String()) - - case lua.LNumber: - writer.WriteInteger(int(res)) // convert to integer - - case *lua.LTable: - writer.WriteArrayHead(res.Len()) - res.ForEach(func(index, value lua.LValue) { - serialize(false, value) - }) - - default: - writer.WriteNull() - } - - if isRoot && ret.Type() != lua.LTNil { - L.Pop(1) - } - } - serialize(true, L.Get(-1)) -} - -func todoCommand(writer *RESPWriter, _ []RESP) { - writer.WriteString("OK") -} - -func fetchMap(key []byte, setnx ...bool) (Map, error) { - return fetch(key, func() Map { return hash.NewZipMap() }, setnx...) -} - -func fetchList(key []byte, setnx ...bool) (List, error) { - return fetch(key, func() List { return list.New() }, setnx...) -} - -func fetchSet(key []byte, setnx ...bool) (Set, error) { - return fetch(key, func() Set { return hash.NewZipSet() }, setnx...) -} - -func fetchZSet(key []byte, setnx ...bool) (ZSet, error) { - return fetch(key, func() ZSet { return zset.NewZSet() }, setnx...) -} - -func fetch[T any](key []byte, new func() T, setnx ...bool) (T, error) { - object, ttl := db.dict.Get(b2s(key)) - - if ttl != dict.KEY_NOT_EXIST { - v, ok := object.(T) - if !ok { - return v, errWrongType - } - - // conversion zipped structure - if len(setnx) > 0 && setnx[0] { - switch data := object.(type) { - case *hash.ZipMap: - if data.Len() < 256 { - break - } - db.dict.Set(string(key), data.ToMap()) - - case *hash.ZipSet: - if data.Len() < 512 { - break - } - db.dict.Set(string(key), data.ToSet()) - } - } - return v, nil - } - - v := new() - if len(setnx) > 0 && setnx[0] { - db.dict.Set(string(key), v) - } - - return v, nil -} +package main + +import ( + "fmt" + "strconv" + "strings" + "time" + + "github.com/xgzlucario/rotom/internal/dict" + "github.com/xgzlucario/rotom/internal/hash" + "github.com/xgzlucario/rotom/internal/list" + "github.com/xgzlucario/rotom/internal/zset" + lua "github.com/yuin/gopher-lua" +) + +var ( + WITH_SCORES = "WITHSCORES" + KEEP_TTL = "KEEPTTL" + NX = "NX" + EX = "EX" + PX = "PX" +) + +type Command struct { + // name is lowercase letters command name. + name string + + // handler is this command real database handler function. + handler func(writer *RESPWriter, args []RESP) + + // minArgsNum represents the minimal number of arguments that command accepts. + minArgsNum int + + // persist indicates whether this command needs to be persisted. + // effective when `appendonly` is true. + persist bool +} + +// cmdTable is the list of all available commands. +var cmdTable = []*Command{ + {"set", setCommand, 2, true}, + {"get", getCommand, 1, false}, + {"del", delCommand, 1, true}, + {"incr", incrCommand, 1, true}, + {"hset", hsetCommand, 3, true}, + {"hget", hgetCommand, 2, false}, + {"hdel", hdelCommand, 2, true}, + {"hgetall", hgetallCommand, 1, false}, + {"rpush", rpushCommand, 2, true}, + {"lpush", lpushCommand, 2, true}, + {"rpop", rpopCommand, 1, true}, + {"lpop", lpopCommand, 1, true}, + {"lrange", lrangeCommand, 3, false}, + {"sadd", saddCommand, 2, true}, + {"srem", sremCommand, 2, true}, + {"spop", spopCommand, 1, true}, + {"zadd", zaddCommand, 3, true}, + {"zrem", zremCommand, 2, true}, + {"zrank", zrankCommand, 2, false}, + {"zpopmin", zpopminCommand, 1, true}, + {"zrange", zrangeCommand, 3, false}, + {"eval", evalCommand, 2, true}, + {"ping", pingCommand, 0, false}, + {"flushdb", flushdbCommand, 0, true}, +} + +func equalFold(a, b string) bool { + return len(a) == len(b) && strings.EqualFold(a, b) +} + +func lookupCommand(name string) (*Command, error) { + for _, c := range cmdTable { + if equalFold(name, c.name) { + return c, nil + } + } + return nil, fmt.Errorf("%w '%s'", errUnknownCommand, name) +} + +func (cmd *Command) processCommand(writer *RESPWriter, args []RESP) { + if len(args) < cmd.minArgsNum { + writer.WriteError(errWrongArguments) + return + } + cmd.handler(writer, args) +} + +func pingCommand(writer *RESPWriter, _ []RESP) { + writer.WriteString("PONG") +} + +func setCommand(writer *RESPWriter, args []RESP) { + key := args[0].ToString() + value := args[1].Clone() + extra := args[2:] + var ttl int64 + + for len(extra) > 0 { + arg := extra[0].ToStringUnsafe() + + // EX + if equalFold(arg, EX) && len(extra) >= 2 { + n, err := extra[1].ToInt() + if err != nil { + writer.WriteError(errParseInteger) + return + } + ttl = dict.GetNanoTime() + int64(time.Second)*int64(n) + extra = extra[2:] + + // PX + } else if equalFold(arg, PX) && len(extra) >= 2 { + n, err := extra[1].ToInt() + if err != nil { + writer.WriteError(errParseInteger) + return + } + ttl = dict.GetNanoTime() + int64(time.Millisecond)*int64(n) + extra = extra[2:] + + // KEEPTTL + } else if equalFold(arg, KEEP_TTL) { + extra = extra[1:] + ttl = -1 + + // NX + } else if equalFold(arg, NX) { + if _, ttl := db.dict.Get(key); ttl != dict.KEY_NOT_EXIST { + writer.WriteNull() + return + } + extra = extra[1:] + + } else { + writer.WriteError(errSyntax) + return + } + } + + db.dict.SetWithTTL(key, value, ttl) + writer.WriteString("OK") +} + +func incrCommand(writer *RESPWriter, args []RESP) { + key := args[0].ToStringUnsafe() + + object, ttl := db.dict.Get(key) + if ttl == dict.KEY_NOT_EXIST { + db.dict.Set(strings.Clone(key), 1) + writer.WriteInteger(1) + return + } + + switch v := object.(type) { + case int: + num := v + 1 + writer.WriteInteger(num) + db.dict.Set(strings.Clone(key), num) + + case []byte: + // conv to integer + num, err := RESP(v).ToInt() + if err != nil { + writer.WriteError(errParseInteger) + return + } + num++ + strconv.AppendInt(v[:0], int64(num), 10) + writer.WriteInteger(num) + + default: + writer.WriteError(errWrongType) + } +} + +func getCommand(writer *RESPWriter, args []RESP) { + key := args[0].ToStringUnsafe() + object, ttl := db.dict.Get(key) + if ttl == dict.KEY_NOT_EXIST { + writer.WriteNull() + return + } + switch v := object.(type) { + case int: + writer.WriteBulkString(strconv.Itoa(v)) + case []byte: + writer.WriteBulk(v) + default: + writer.WriteError(errWrongType) + } +} + +func delCommand(writer *RESPWriter, args []RESP) { + var count int + for _, arg := range args { + if db.dict.Delete(arg.ToStringUnsafe()) { + count++ + } + } + writer.WriteInteger(count) +} + +func hsetCommand(writer *RESPWriter, args []RESP) { + hash := args[0] + args = args[1:] + + if len(args)%2 == 1 { + writer.WriteError(errWrongArguments) + return + } + + hmap, err := fetchMap(hash, true) + if err != nil { + writer.WriteError(err) + return + } + + var count int + for i := 0; i < len(args); i += 2 { + key := args[i].ToString() + value := args[i+1].Clone() + if hmap.Set(key, value) { + count++ + } + } + writer.WriteInteger(count) +} + +func hgetCommand(writer *RESPWriter, args []RESP) { + hash := args[0] + key := args[1].ToStringUnsafe() + hmap, err := fetchMap(hash) + if err != nil { + writer.WriteError(errWrongType) + return + } + value, ok := hmap.Get(key) + if ok { + writer.WriteBulk(value) + } else { + writer.WriteNull() + } +} + +func hdelCommand(writer *RESPWriter, args []RESP) { + hash := args[0] + keys := args[1:] + hmap, err := fetchMap(hash) + if err != nil { + writer.WriteError(err) + return + } + var count int + for _, v := range keys { + if hmap.Remove(v.ToStringUnsafe()) { + count++ + } + } + writer.WriteInteger(count) +} + +func hgetallCommand(writer *RESPWriter, args []RESP) { + hash := args[0] + hmap, err := fetchMap(hash) + if err != nil { + writer.WriteError(err) + return + } + writer.WriteArrayHead(hmap.Len() * 2) + hmap.Scan(func(key string, value []byte) { + writer.WriteBulkString(key) + writer.WriteBulk(value) + }) +} + +func lpushCommand(writer *RESPWriter, args []RESP) { + key := args[0] + ls, err := fetchList(key, true) + if err != nil { + writer.WriteError(err) + return + } + for _, arg := range args[1:] { + ls.LPush(arg.ToStringUnsafe()) + } + writer.WriteInteger(ls.Size()) +} + +func rpushCommand(writer *RESPWriter, args []RESP) { + key := args[0] + ls, err := fetchList(key, true) + if err != nil { + writer.WriteError(err) + return + } + for _, arg := range args[1:] { + ls.RPush(arg.ToStringUnsafe()) + } + writer.WriteInteger(ls.Size()) +} + +func lpopCommand(writer *RESPWriter, args []RESP) { + key := args[0] + ls, err := fetchList(key) + if err != nil { + writer.WriteError(err) + return + } + val, ok := ls.LPop() + if ok { + writer.WriteBulkString(val) + } else { + writer.WriteNull() + } +} + +func rpopCommand(writer *RESPWriter, args []RESP) { + key := args[0] + ls, err := fetchList(key) + if err != nil { + writer.WriteError(err) + return + } + val, ok := ls.RPop() + if ok { + writer.WriteBulkString(val) + } else { + writer.WriteNull() + } +} + +func lrangeCommand(writer *RESPWriter, args []RESP) { + key := args[0] + start, err := args[1].ToInt() + if err != nil { + writer.WriteError(err) + return + } + stop, err := args[2].ToInt() + if err != nil { + writer.WriteError(err) + return + } + ls, err := fetchList(key) + if err != nil { + writer.WriteError(err) + return + } + writer.WriteArrayHead(ls.RangeCount(start, stop)) + ls.Range(start, stop, func(data []byte) { + writer.WriteBulk(data) + }) +} + +func saddCommand(writer *RESPWriter, args []RESP) { + key := args[0] + set, err := fetchSet(key, true) + if err != nil { + writer.WriteError(err) + return + } + var count int + for _, arg := range args[1:] { + if set.Add(arg.ToString()) { + count++ + } + } + writer.WriteInteger(count) +} + +func sremCommand(writer *RESPWriter, args []RESP) { + key := args[0] + set, err := fetchSet(key) + if err != nil { + writer.WriteError(err) + return + } + var count int + for _, arg := range args[1:] { + if set.Remove(arg.ToStringUnsafe()) { + count++ + } + } + writer.WriteInteger(count) +} + +func spopCommand(writer *RESPWriter, args []RESP) { + key := args[0] + set, err := fetchSet(key) + if err != nil { + writer.WriteError(err) + return + } + member, ok := set.Pop() + if ok { + writer.WriteBulkString(member) + } else { + writer.WriteNull() + } +} + +func zaddCommand(writer *RESPWriter, args []RESP) { + key := args[0] + args = args[1:] + + zset, err := fetchZSet(key, true) + if err != nil { + writer.WriteError(err) + return + } + + var count int + for i := 0; i < len(args); i += 2 { + score, err := args[i].ToFloat() + if err != nil { + writer.WriteError(err) + return + } + key := args[i+1].ToString() + if zset.Set(key, score) { + count++ + } + } + writer.WriteInteger(count) +} + +func zrankCommand(writer *RESPWriter, args []RESP) { + key := args[0] + member := args[1].ToStringUnsafe() + + zset, err := fetchZSet(key) + if err != nil { + writer.WriteError(err) + return + } + + rank, _ := zset.Rank(member) + if rank < 0 { + writer.WriteNull() + } else { + writer.WriteInteger(rank) + } +} + +func zremCommand(writer *RESPWriter, args []RESP) { + key := args[0] + zset, err := fetchZSet(key) + if err != nil { + writer.WriteError(err) + return + } + var count int + for _, arg := range args[1:] { + if zset.Remove(arg.ToStringUnsafe()) { + count++ + } + } + writer.WriteInteger(count) +} + +func zrangeCommand(writer *RESPWriter, args []RESP) { + key := args[0] + start, err := args[1].ToInt() + if err != nil { + writer.WriteError(err) + return + } + stop, err := args[2].ToInt() + if err != nil { + writer.WriteError(err) + return + } + zset, err := fetchZSet(key) + if err != nil { + writer.WriteError(err) + return + } + + if stop == -1 { + stop = zset.Len() + } + start = min(start, stop) + + withScores := len(args) == 4 && equalFold(args[3].ToStringUnsafe(), WITH_SCORES) + if withScores { + writer.WriteArrayHead((stop - start) * 2) + zset.Range(start, stop, func(key string, score float64) { + writer.WriteBulkString(key) + writer.WriteFloat(score) + }) + + } else { + writer.WriteArrayHead(stop - start) + zset.Range(start, stop, func(key string, _ float64) { + writer.WriteBulkString(key) + }) + } +} + +func zpopminCommand(writer *RESPWriter, args []RESP) { + key := args[0] + count := 1 + var err error + if len(args) > 1 { + count, err = args[1].ToInt() + if err != nil { + writer.WriteError(err) + return + } + } + + zset, err := fetchZSet(key) + if err != nil { + writer.WriteError(err) + return + } + + size := min(zset.Len(), count) + writer.WriteArrayHead(size * 2) + for range size { + key, score := zset.PopMin() + writer.WriteBulkString(key) + writer.WriteFloat(score) + } +} + +func flushdbCommand(writer *RESPWriter, _ []RESP) { + db.dict = dict.New() + writer.WriteString("OK") +} + +func evalCommand(writer *RESPWriter, args []RESP) { + L := server.lua + script := args[0].ToString() + + if err := L.DoString(script); err != nil { + writer.WriteError(err) + return + } + + var serialize func(isRoot bool, ret lua.LValue) + serialize = func(isRoot bool, ret lua.LValue) { + switch res := ret.(type) { + case lua.LString: + writer.WriteBulkString(res.String()) + + case lua.LNumber: + writer.WriteInteger(int(res)) // convert to integer + + case *lua.LTable: + writer.WriteArrayHead(res.Len()) + res.ForEach(func(index, value lua.LValue) { + serialize(false, value) + }) + + default: + writer.WriteNull() + } + + if isRoot && ret.Type() != lua.LTNil { + L.Pop(1) + } + } + serialize(true, L.Get(-1)) +} + +func fetchMap(key []byte, setnx ...bool) (Map, error) { + return fetch(key, func() Map { return hash.NewZipMap() }, setnx...) +} + +func fetchList(key []byte, setnx ...bool) (List, error) { + return fetch(key, func() List { return list.New() }, setnx...) +} + +func fetchSet(key []byte, setnx ...bool) (Set, error) { + return fetch(key, func() Set { return hash.NewZipSet() }, setnx...) +} + +func fetchZSet(key []byte, setnx ...bool) (ZSet, error) { + return fetch(key, func() ZSet { return zset.NewZSet() }, setnx...) +} + +func fetch[T any](key []byte, new func() T, setnx ...bool) (T, error) { + object, ttl := db.dict.Get(b2s(key)) + + if ttl != dict.KEY_NOT_EXIST { + v, ok := object.(T) + if !ok { + return v, errWrongType + } + + // conversion zipped structure + if len(setnx) > 0 && setnx[0] { + switch data := object.(type) { + case *hash.ZipMap: + if data.Len() < 256 { + break + } + db.dict.Set(string(key), data.ToMap()) + + case *hash.ZipSet: + if data.Len() < 512 { + break + } + db.dict.Set(string(key), data.ToSet()) + } + } + return v, nil + } + + v := new() + if len(setnx) > 0 && setnx[0] { + db.dict.Set(string(key), v) + } + + return v, nil +} diff --git a/command_test.go b/command_test.go index 8096ad6..72ec448 100644 --- a/command_test.go +++ b/command_test.go @@ -20,13 +20,12 @@ func startup() { AppendOnly: true, AppendFileName: "test.aof", } - os.Remove(config.AppendFileName) + _ = os.Remove(config.AppendFileName) config4Server(config) printBanner(config) - server.aeLoop.AddRead(server.fd, AcceptHandler, nil) + RegisterAeLoop(&server) // custom server.aeLoop.AddTimeEvent(AE_ONCE, 300, func(loop *AeLoop, id int, extra interface{}) {}, nil) - server.aeLoop.AddTimeEvent(AE_NORMAL, 1000, CronSyncAOF, nil) server.aeLoop.AeMain() } @@ -79,20 +78,20 @@ func testCommand(t *testing.T, rdb *redis.Client, sleepFn func(time.Duration)) { // setex { - res, _ := rdb.Set(ctx, "foo", "bar", time.Second).Result() + res, _ = rdb.Set(ctx, "foo", "bar", time.Second).Result() assert.Equal(res, "OK") res, _ = rdb.Get(ctx, "foo").Result() assert.Equal(res, "bar") - sleepFn(time.Second + time.Millisecond) + sleepFn(time.Second + 10*time.Millisecond) _, err := rdb.Get(ctx, "foo").Result() assert.Equal(err, redis.Nil) } // setpx { - res, _ := rdb.Set(ctx, "foo", "bar", time.Millisecond*100).Result() + res, _ = rdb.Set(ctx, "foo", "bar", time.Millisecond*100).Result() assert.Equal(res, "OK") res, _ = rdb.Get(ctx, "foo").Result() diff --git a/go.mod b/go.mod index f079ec5..3a08640 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/stretchr/testify v1.9.0 github.com/tidwall/mmap v0.3.0 github.com/yuin/gopher-lua v1.1.1 - golang.org/x/sys v0.25.0 + golang.org/x/sys v0.26.0 ) require ( diff --git a/go.sum b/go.sum index 2d6ad42..bb218cd 100644 --- a/go.sum +++ b/go.sum @@ -64,8 +64,8 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= -golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gonum.org/v1/gonum v0.0.0-20181121035319-3f7ecaa7e8ca h1:PupagGYwj8+I4ubCxcmcBRk3VlUWtTg5huQpZR9flmE= gonum.org/v1/gonum v0.0.0-20181121035319-3f7ecaa7e8ca/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= diff --git a/internal/list/bench_test.go b/internal/list/bench_test.go index dfe00dd..e6cbb34 100644 --- a/internal/list/bench_test.go +++ b/internal/list/bench_test.go @@ -39,3 +39,22 @@ func BenchmarkList(b *testing.B) { } }) } + +func BenchmarkListPack(b *testing.B) { + b.Run("next", func(b *testing.B) { + lp := genListPack(0, 100) + b.ResetTimer() + for i := 0; i < b.N; i++ { + for it := lp.Iterator(); !it.IsLast(); it.Next() { + } + } + }) + b.Run("prev", func(b *testing.B) { + lp := genListPack(0, 100) + b.ResetTimer() + for i := 0; i < b.N; i++ { + for it := lp.Iterator().SeekLast(); !it.IsFirst(); it.Prev() { + } + } + }) +} diff --git a/internal/list/list_test.go b/internal/list/list_test.go index d5fa63e..018021f 100644 --- a/internal/list/list_test.go +++ b/internal/list/list_test.go @@ -16,6 +16,14 @@ func genList(start, stop int) *QuickList { return lp } +func genListPack(start, stop int) *ListPack { + lp := NewListPack() + for i := start; i < stop; i++ { + lp.RPush(genKey(i)) + } + return lp +} + func list2slice(ls *QuickList) (res []string) { ls.Range(0, ls.Size(), func(data []byte) { res = append(res, string(data)) diff --git a/internal/list/utils.go b/internal/list/utils.go index fb72e3b..457253b 100644 --- a/internal/list/utils.go +++ b/internal/list/utils.go @@ -20,9 +20,7 @@ func appendUvarint(b []byte, n int, reverse bool) []byte { } // uvarintReverse is the reverse version from binary.Uvarint. -func uvarintReverse(buf []byte) (uint64, int) { - var x uint64 - var s uint +func uvarintReverse(buf []byte) (x uint64, s int) { for i := range buf { b := buf[len(buf)-1-i] if b < 0x80 { diff --git a/main.go b/main.go index 2f457b0..1d290af 100644 --- a/main.go +++ b/main.go @@ -1,81 +1,85 @@ -package main - -import ( - "flag" - "net/http" - _ "net/http/pprof" - "os" - "runtime" - "strconv" - "time" - - "github.com/rs/zerolog" -) - -var ( - log = initLogger() - buildTime string -) - -func initLogger() zerolog.Logger { - return zerolog. - New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.DateTime}). - Level(zerolog.TraceLevel). - With(). - Timestamp(). - Logger() -} - -func config4Server(config *Config) { - if err := initServer(config); err != nil { - log.Fatal().Msgf("init server error: %v", err) - } - if err := InitDB(config); err != nil { - log.Fatal().Msgf("init db error: %v", err) - } -} - -func printBanner(config *Config) { - log.Printf(` -________ _____ -___ __ \_______ /_____________ ___ Rotom %d bit (%s/%s) -__ /_/ / __ \ __/ __ \_ __ '__ \ Port: %d, Pid: %d -_ _, _// /_/ / /_ / /_/ / / / / / / Build: %s -/_/ |_| \____/\__/ \____//_/ /_/ /_/ - `, - strconv.IntSize, runtime.GOARCH, runtime.GOOS, - config.Port, os.Getpid(), - buildTime) -} - -func main() { - var path string - var debug bool - - flag.StringVar(&path, "config", "config.json", "default config file path.") - flag.BoolVar(&debug, "debug", false, "run with debug mode.") - flag.Parse() - - config, err := LoadConfig(path) - if err != nil { - log.Fatal().Msgf("load config error: %v", err) - } - printBanner(config) - - if debug { - go http.ListenAndServe(":6060", nil) - } - - log.Info().Str("config", path).Msg("read config file") - config4Server(config) - - log.Info().Msg("rotom server is ready to accept.") - - // register main aeLoop event - server.aeLoop.AddRead(server.fd, AcceptHandler, nil) - server.aeLoop.AddTimeEvent(AE_NORMAL, 100, CronEvictExpired, nil) - if server.config.AppendOnly { - server.aeLoop.AddTimeEvent(AE_NORMAL, 1000, CronSyncAOF, nil) - } - server.aeLoop.AeMain() -} +package main + +import ( + "flag" + "net/http" + _ "net/http/pprof" + "os" + "runtime" + "strconv" + "time" + + "github.com/rs/zerolog" +) + +var ( + log = initLogger() + buildTime string +) + +func initLogger() zerolog.Logger { + return zerolog. + New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.DateTime}). + Level(zerolog.TraceLevel). + With(). + Timestamp(). + Logger() +} + +func config4Server(config *Config) { + if err := initServer(config); err != nil { + log.Fatal().Msgf("init server error: %v", err) + } + if err := InitDB(config); err != nil { + log.Fatal().Msgf("init db error: %v", err) + } +} + +func printBanner(config *Config) { + log.Printf(` +________ _____ +___ __ \_______ /_____________ ___ Rotom %d bit (%s/%s) +__ /_/ / __ \ __/ __ \_ __ '__ \ Port: %d, Pid: %d +_ _, _// /_/ / /_ / /_/ / / / / / / Build: %s +/_/ |_| \____/\__/ \____//_/ /_/ /_/ + `, + strconv.IntSize, runtime.GOARCH, runtime.GOOS, + config.Port, os.Getpid(), + buildTime) +} + +// RegisterAeLoop register main aeLoop event. +func RegisterAeLoop(server *Server) { + server.aeLoop.AddRead(server.fd, AcceptHandler, nil) + server.aeLoop.AddTimeEvent(AE_NORMAL, 100, CronEvictExpired, nil) + if server.config.AppendOnly { + server.aeLoop.AddTimeEvent(AE_NORMAL, 1000, CronSyncAOF, nil) + } +} + +func main() { + var path string + var debug bool + + flag.StringVar(&path, "config", "config.json", "default config file path.") + flag.BoolVar(&debug, "debug", false, "run with debug mode.") + flag.Parse() + + config, err := LoadConfig(path) + if err != nil { + log.Fatal().Msgf("load config error: %v", err) + } + printBanner(config) + + if debug { + go http.ListenAndServe(":6060", nil) + } + + log.Info().Str("config", path).Msg("read config file") + config4Server(config) + + log.Info().Msg("rotom server is ready to accept.") + + RegisterAeLoop(&server) + server.aeLoop.AeMain() +} diff --git a/rotom.go b/rotom.go index 7259ff5..bcf8f9c 100644 --- a/rotom.go +++ b/rotom.go @@ -100,7 +100,7 @@ func AcceptHandler(loop *AeLoop, fd int, _ interface{}) { loop.AddRead(cfd, ReadQueryFromClient, client) } -func ReadQueryFromClient(loop *AeLoop, fd int, extra interface{}) { +func ReadQueryFromClient(_ *AeLoop, fd int, extra interface{}) { client := extra.(*Client) readSize := 0 @@ -211,7 +211,7 @@ func SendReplyToClient(loop *AeLoop, fd int, extra interface{}) { func initServer(config *Config) (err error) { server.config = config server.clients = make(map[int]*Client) - // init aeloop + // init aeLoop server.aeLoop, err = AeLoopCreate() if err != nil { return err @@ -219,7 +219,7 @@ func initServer(config *Config) (err error) { // init tcp server server.fd, err = TcpServer(config.Port) if err != nil { - Close(server.fd) + _ = Close(server.fd) return err } // init lua state