From a3e1797dc489acecd4ded0aacde12ffd48a89dba Mon Sep 17 00:00:00 2001 From: yowcow Date: Mon, 28 May 2018 14:00:50 +0900 Subject: [PATCH] server to only manage tcp/unix socket --- main.go | 16 ++++- server/server.go | 35 ++++------ server/server_test.go | 150 ++++++++++-------------------------------- 3 files changed, 59 insertions(+), 142 deletions(-) diff --git a/main.go b/main.go index c267ec6..140b2b9 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "flag" "fmt" "log" + "net" "os" "github.com/yowcow/goromdb/handler" @@ -90,8 +91,19 @@ func main() { os.Getpid(), addr, protoBackend, handlerBackend, storageBackend, file, ) - svr := server.New("tcp", addr, proto, h, logger) - err = svr.Start() + svr := server.New("tcp", addr, logger) + err = svr.Start(server.OnReadCallbackFunc(func(conn net.Conn, line []byte, logger *log.Logger) { + if keys, err := proto.Parse(line); err != nil { + logger.Printf("server failed parsing a line: %s", err) + } else { + for _, k := range keys { + if v, _ := h.Get(k); v != nil { + proto.Reply(conn, k, v) + } + } + } + proto.Finish(conn) + })) if err != nil { logger.Printf("failed booting goromdb: %s", err.Error()) os.Exit(1) diff --git a/server/server.go b/server/server.go index f3a44c3..b38a123 100644 --- a/server/server.go +++ b/server/server.go @@ -5,27 +5,24 @@ import ( "io" "log" "net" - - "github.com/yowcow/goromdb/handler" - "github.com/yowcow/goromdb/protocol" ) +type OnReadCallbackFunc func(net.Conn, []byte, *log.Logger) + // Server represents a server type Server struct { - network string - addr string - protocol protocol.Protocol - handler handler.Handler - logger *log.Logger + network string + addr string + logger *log.Logger } // New creates a new server -func New(network, addr string, p protocol.Protocol, h handler.Handler, logger *log.Logger) *Server { - return &Server{network, addr, p, h, logger} +func New(network, addr string, logger *log.Logger) *Server { + return &Server{network, addr, logger} } // Start starts a server and spawns a goroutine when a new connection is accepted -func (s Server) Start() error { +func (s Server) Start(callback OnReadCallbackFunc) error { ln, err := net.Listen(s.network, s.addr) if err != nil { return err @@ -35,13 +32,13 @@ func (s Server) Start() error { if err != nil { s.logger.Printf("server failed accepting a conn: %s", err.Error()) } else { - go s.HandleConn(conn) + go s.HandleConn(conn, callback) } } } // HandleConn handles a net.Conn -func (s Server) HandleConn(conn net.Conn) { +func (s Server) HandleConn(conn net.Conn, callback OnReadCallbackFunc) { defer conn.Close() r := bufio.NewReader(conn) for { @@ -53,15 +50,7 @@ func (s Server) HandleConn(conn net.Conn) { s.logger.Printf("server failed reading a line: %s", err) return } - if keys, err := s.protocol.Parse(line); err != nil { - s.logger.Printf("server failed parsing a line: %s", err) - } else { - for _, k := range keys { - if v, _ := s.handler.Get(k); v != nil { - s.protocol.Reply(conn, k, v) - } - } - } - s.protocol.Finish(conn) + + callback(conn, line, s.logger) } } diff --git a/server/server_test.go b/server/server_test.go index ad85dc0..0d7cbc1 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1,10 +1,7 @@ package server import ( - "bufio" "bytes" - "fmt" - "io" "log" "net" "os" @@ -12,144 +9,63 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/yowcow/goromdb/handler/simplehandler" - "github.com/yowcow/goromdb/protocol" - "github.com/yowcow/goromdb/storage" "github.com/yowcow/goromdb/testutil" ) -type TestKeywords map[string][][]byte - -type TestProtocol struct { - keywords TestKeywords -} - -func createTestProtocol() protocol.Protocol { - keywords := TestKeywords{ - "hoge": {[]byte("foo"), []byte("bar")}, - } - return &TestProtocol{keywords} -} - -func (p TestProtocol) Parse(line []byte) ([][]byte, error) { - if words, ok := p.keywords[string(line)]; ok { - return words, nil - } - return [][]byte{}, fmt.Errorf("invalid command") -} - -func (p TestProtocol) Reply(w io.Writer, key, value []byte) { - w.Write(key) - w.Write([]byte(" ")) - w.Write(value) - w.Write([]byte("\r\n")) -} - -func (p TestProtocol) Finish(w io.Writer) { - w.Write([]byte("BYE\r\n")) -} - -type TestData map[string]string - -type TestStorage struct { - data TestData - logger *log.Logger -} - -func createTestStorage(logger *log.Logger) storage.Storage { - data := TestData{ - "foo": "foo!", - "bar": "bar!!", - } - return &TestStorage{data, logger} -} - -func (s TestStorage) Load(file string) error { - return nil -} - -func (s TestStorage) Get(key []byte) ([]byte, error) { - if v, ok := s.data[string(key)]; ok { - return []byte(v), nil - } - return nil, storage.KeyNotFoundError(key) -} - func TestHandleConn(t *testing.T) { dir := testutil.CreateTmpDir() defer os.RemoveAll(dir) logbuf := new(bytes.Buffer) logger := log.New(logbuf, "", 0) - p := createTestProtocol() - stg := createTestStorage(logger) - h := simplehandler.New(stg, logger) sock := filepath.Join(dir, "test.sock") - svr := New("unix", sock, p, h, logger) - - done := make(chan bool) - ln, err := net.Listen("unix", sock) - if err != nil { - panic(err) - } - go func() { - defer close(done) - for { - conn, err := ln.Accept() - if err != nil { - break - } - svr.HandleConn(conn) - } - }() + svr := New("unix", sock, logger) type Case struct { - input string - expected []string - subtest string + subtest string + input []byte + expectedLine []byte } cases := []Case{ { - "hoge\r\n", - []string{ - "foo foo!", - "bar bar!!", - "BYE", - }, - "hoge returns 3 lines of message", - }, - { - "fuga\r\n", - []string{ - "BYE", - }, - "fuga returns 1 line of message", + "a line that end with \\r\\n", + []byte("hello world\r\n"), + []byte("hello world"), }, } for _, c := range cases { - t.Run(c.subtest, func(t *testing.T) { - conn, err := net.Dial("unix", sock) - if err != nil { - panic(err) + done := make(chan bool) + + ln, err := net.Listen("unix", sock) + if err != nil { + panic(err) + } + go func() { + defer close(done) + for { + conn, err := ln.Accept() + if err != nil { + return + } + svr.HandleConn(conn, OnReadCallbackFunc(func(conn net.Conn, line []byte, logger *log.Logger) { + assert.Equal(t, c.expectedLine, line) + })) } - defer conn.Close() + }() - r := bufio.NewReader(conn) - _, err = conn.Write([]byte(c.input)) + conn, err := net.Dial("unix", sock) + if err != nil { + panic(err) + } - assert.Nil(t, err) + _, err = conn.Write(c.input) - for _, row := range c.expected { - actual, _, err := r.ReadLine() + conn.Close() + ln.Close() + <-done - assert.Nil(t, err) - assert.Equal(t, row, string(actual)) - } - }) + assert.Nil(t, err) } - - ln.Close() - <-done }