Skip to content

Commit

Permalink
modify client to interface type (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
Patrick0308 authored Jul 5, 2022
1 parent 91ecf7b commit 4f3dabe
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 46 deletions.
82 changes: 52 additions & 30 deletions go/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,37 @@ var (
errConnClosed = errors.New("client conn closed")
)

// client is an socket client interface
type Client interface {
// Dial using to dial with server
Dial(ctx context.Context, u string, handshake *protocol.Handshake, opts ...DialOption) error
// AuthInfo return authorization information
AuthInfo() *control.AuthResponse
// Do will do request to server
Do(ctx context.Context, req *Request, opts ...RequestOption) (*protocol.Packet, error)
// Subscribe using to register handle of push data
Subscribe(cmd uint32, sub func(*protocol.Packet))
// AfterReconnected using to handle client after reconnected
AfterReconnected(fn func())
// OnPing using to custom handle ping packet
OnPing(fn func(*protocol.Packet))
// OnPong using to custom handle pong packet
OnPong(fn func(*protocol.Packet))
// OnClose using to handle client close
OnClose(fn func(err error))
// Close used to close conn between server
Close(err error) error
}

// Request represents an socket request to server
type Request struct {
Cmd uint32
Body proto.Message
}

// New returns a new client instance
func New(opts ...ClientOption) *Client {
c := &Client{
func New(opts ...ClientOption) Client {
c := &client{
closeCh: make(chan struct{}),
subs: make(map[uint32][]func(*protocol.Packet)),
recvs: make(map[uint32]chan *protocol.Packet),
Expand All @@ -49,8 +71,8 @@ func New(opts ...ClientOption) *Client {
return c
}

// Client is an socket client
type Client struct {
// client is an socket client
type client struct {
sync.RWMutex

Context context.Context
Expand Down Expand Up @@ -88,7 +110,7 @@ type Client struct {
}

// Dial using to dial with server
func (c *Client) Dial(ctx context.Context, u string, handshake *protocol.Handshake, opts ...DialOption) error {
func (c *client) Dial(ctx context.Context, u string, handshake *protocol.Handshake, opts ...DialOption) error {
c.handshake = handshake

uri, err := url.Parse(u)
Expand Down Expand Up @@ -123,19 +145,19 @@ func (c *Client) Dial(ctx context.Context, u string, handshake *protocol.Handsha
return c.auth()
}

func (c *Client) AuthInfo() *control.AuthResponse {
func (c *client) AuthInfo() *control.AuthResponse {
return c.authInfo
}

func (c *Client) dial(ctx context.Context, dialer DialConnFunc) (err error) {
if c.conn, err = dialer(ctx, c, c.addr, c.handshake, c.dialOptions); err == nil {
func (c *client) dial(ctx context.Context, dialer DialConnFunc) (err error) {
if c.conn, err = dialer(ctx, c.Logger, c.addr, c.handshake, c.dialOptions); err == nil {
c.conn.OnPacket(c.onPacket)
}

return
}

func (c *Client) auth() error {
func (c *client) auth() error {
if c.dialOptions.AuthTokenGetter == nil {
return nil
}
Expand Down Expand Up @@ -163,7 +185,7 @@ func (c *Client) auth() error {
return nil
}

func (c *Client) reconnecting() {
func (c *client) reconnecting() {
c.Lock()
if c.doReconnectting {
return
Expand Down Expand Up @@ -213,7 +235,7 @@ func (c *Client) reconnecting() {
c.Unlock()
}

func (c *Client) reconnect() error {
func (c *client) reconnect() error {
if c.dialOptions.MaxReconnect > 0 {
if c.reconnectCount >= c.dialOptions.MaxReconnect {
return ErrHitMaxReconnect
Expand Down Expand Up @@ -252,7 +274,7 @@ func (c *Client) reconnect() error {
return c.reconnectDial()
}

func (c *Client) reconnectDial() error {
func (c *client) reconnectDial() error {
res, err := c.Do(c.Context, &Request{Cmd: uint32(control.Command_CMD_RECONNECT), Body: &control.ReconnectRequest{
SessionId: c.authInfo.SessionId,
}}, RequestTimeout(c.dialOptions.AuthTimeout))
Expand All @@ -275,7 +297,7 @@ func (c *Client) reconnectDial() error {
return nil
}

func (c *Client) isAuthExpired() bool {
func (c *client) isAuthExpired() bool {
if c.authInfo == nil {
return true
}
Expand All @@ -286,7 +308,7 @@ func (c *Client) isAuthExpired() bool {
}

// Do will do request to server
func (c *Client) Do(ctx context.Context, req *Request, opts ...RequestOption) (res *protocol.Packet, err error) {
func (c *client) Do(ctx context.Context, req *Request, opts ...RequestOption) (res *protocol.Packet, err error) {
rp, e := protocol.NewRequest(c.conn.Context(), req.Cmd, req.Body)

if e != nil {
Expand Down Expand Up @@ -317,7 +339,7 @@ func (c *Client) Do(ctx context.Context, req *Request, opts ...RequestOption) (r

// Subscribe using to register handle of push data
// concurrency unsafe, please sub at first time
func (c *Client) Subscribe(cmd uint32, sub func(*protocol.Packet)) {
func (c *client) Subscribe(cmd uint32, sub func(*protocol.Packet)) {
var (
subs []func(*protocol.Packet)
ok bool
Expand All @@ -333,27 +355,27 @@ func (c *Client) Subscribe(cmd uint32, sub func(*protocol.Packet)) {
}

// OnPing using to custom handle ping packet
func (c *Client) OnPing(fn func(*protocol.Packet)) {
func (c *client) OnPing(fn func(*protocol.Packet)) {
c.onPing = fn
}

// OnPong using to custom handle pong packet
func (c *Client) OnPong(fn func(*protocol.Packet)) {
func (c *client) OnPong(fn func(*protocol.Packet)) {
c.onPong = fn
}

// OnClose using to handle client close
func (c *Client) OnClose(fn func(err error)) {
func (c *client) OnClose(fn func(err error)) {
c.onClose = fn
}

// AfterReconnected using to handle client after reconnected
func (c *Client) AfterReconnected(fn func()) {
func (c *client) AfterReconnected(fn func()) {
c.afterReconnected = fn
}

// Close used to close conn between server
func (c *Client) Close(err error) error {
func (c *client) Close(err error) error {
c.Logger.Info("close client")

c.conn.Close(errors.New("close by client"))
Expand All @@ -367,7 +389,7 @@ func (c *Client) Close(err error) error {
return nil
}

func (c *Client) closeByServer(packet *protocol.Packet) {
func (c *client) closeByServer(packet *protocol.Packet) {
var reason control.Close

if err := packet.Unmarshal(&reason); err != nil {
Expand All @@ -381,7 +403,7 @@ func (c *Client) closeByServer(packet *protocol.Packet) {
c.reconnecting()
}

func (c *Client) keepalive() {
func (c *client) keepalive() {
t := time.NewTicker(c.dialOptions.Keepalive)

now := time.Now()
Expand Down Expand Up @@ -439,7 +461,7 @@ func (c *Client) keepalive() {
}
}

func (c *Client) onPacket(packet *protocol.Packet, err error) {
func (c *client) onPacket(packet *protocol.Packet, err error) {
if err != nil {
c.Logger.Errorf("conn receive packet error: %v", err)
c.reconnecting()
Expand All @@ -466,7 +488,7 @@ func (c *Client) onPacket(packet *protocol.Packet, err error) {
c.Logger.Warnf("client did't support request now, cmd: %d", packet.Metadata.CmdCode)
}

func (c *Client) handlePush(packet *protocol.Packet) {
func (c *client) handlePush(packet *protocol.Packet) {
subs, ok := c.subs[packet.CMD()]

if !ok || len(subs) == 0 {
Expand All @@ -478,7 +500,7 @@ func (c *Client) handlePush(packet *protocol.Packet) {
}
}

func (c *Client) handleControl(packet *protocol.Packet) {
func (c *client) handleControl(packet *protocol.Packet) {
if packet.IsPing() {
c.handlePing(packet)
return
Expand All @@ -499,7 +521,7 @@ func (c *Client) handleControl(packet *protocol.Packet) {
}
}

func (c *Client) handleResponse(packet *protocol.Packet) {
func (c *client) handleResponse(packet *protocol.Packet) {
c.recvsMu.RLock()
defer c.recvsMu.RUnlock()

Expand All @@ -514,7 +536,7 @@ func (c *Client) handleResponse(packet *protocol.Packet) {
c.Logger.Warnf("no receiver for req %d", packet.Metadata.RequestId)
}

func (c *Client) handlePing(packet *protocol.Packet) {
func (c *client) handlePing(packet *protocol.Packet) {
if c.onPing != nil {
c.onPing(packet)
}
Expand All @@ -530,7 +552,7 @@ func (c *Client) handlePing(packet *protocol.Packet) {
}
}

func (c *Client) handlePong(packet *protocol.Packet) {
func (c *client) handlePong(packet *protocol.Packet) {
if c.onPong != nil {
c.onPong(packet)
}
Expand All @@ -540,7 +562,7 @@ func (c *Client) handlePong(packet *protocol.Packet) {
}
}

func (c *Client) recv(ctx context.Context, rid uint32) (res *protocol.Packet, err error) {
func (c *client) recv(ctx context.Context, rid uint32) (res *protocol.Packet, err error) {
ch := make(chan *protocol.Packet, 1)

defer func() {
Expand All @@ -564,6 +586,6 @@ func (c *Client) recv(ctx context.Context, rid uint32) (res *protocol.Packet, er
return
}

func (c *Client) write(p *protocol.Packet) error {
func (c *client) write(p *protocol.Packet) error {
return c.conn.Write(p, protocol.GzipSize(c.dialOptions.MinGzipSize))
}
2 changes: 1 addition & 1 deletion go/client/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func GetDialer(scheme string) (DialConnFunc, bool) {
return f, ok
}

type DialConnFunc func(ctx context.Context, cli *Client, uri *url.URL, handshake *protocol.Handshake, opts *DialOptions) (ClientConn, error)
type DialConnFunc func(ctx context.Context, logger protocol.Logger, uri *url.URL, handshake *protocol.Handshake, opts *DialOptions) (ClientConn, error)

// ClientConn is socket conn abstract
type ClientConn interface {
Expand Down
23 changes: 15 additions & 8 deletions go/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func init() {
RegisterDialer("mock", DialConnFunc(mockDialer))
}

func mockDialer(ctx context.Context, _ *Client, uri *url.URL, handshake *protocol.Handshake, opts *DialOptions) (ClientConn, error) {
func mockDialer(ctx context.Context, _ protocol.Logger, uri *url.URL, handshake *protocol.Handshake, opts *DialOptions) (ClientConn, error) {
qctx := protocol.NewContext(ctx, protocol.ClientSide)

qctx.Platform = handshake.Platform
Expand Down Expand Up @@ -163,10 +163,12 @@ func (c *mockConn) Write(p *protocol.Packet, opts ...protocol.PackOption) error

return nil
}

func (c *mockConn) Context() *protocol.Context {
return c.ctx
}
func newClientAndDial(opts ...DialOption) (*Client, error) {

func newClientAndDial(opts ...DialOption) (Client, error) {
c := New()

err := c.Dial(context.Background(), "mock://127.0.0.1", &protocol.Handshake{
Expand All @@ -184,11 +186,13 @@ func TestClientDialer(t *testing.T) {
}

func TestClientAuth(t *testing.T) {
cli, _ := newClientAndDial()
c, _ := newClientAndDial()
cli := c.(*client)
mc := cli.conn.(*mockConn)

cli.dialOptions.AuthToken = "test auth"

cli.dialOptions.AuthTokenGetter = func() (string, error) {
return "test auth", nil
}
info := &control.AuthResponse{
SessionId: "session",
Expires: time.Now().Add(time.Minute*2).UnixNano() / int64(time.Millisecond),
Expand All @@ -204,7 +208,8 @@ func TestClientAuth(t *testing.T) {
}

func TestClientReconnect(t *testing.T) {
cli, _ := newClientAndDial()
c, _ := newClientAndDial()
cli := c.(*client)

cli.dialOptions.MaxReconnect = 1

Expand All @@ -224,7 +229,8 @@ func TestClientReconnect(t *testing.T) {
}

func TestClientSubscribe(t *testing.T) {
cli, _ := newClientAndDial()
c, _ := newClientAndDial()
cli := c.(*client)
mc := cli.conn.(*mockConn)

testCmd := uint32(111)
Expand All @@ -249,7 +255,8 @@ func TestClientSubscribe(t *testing.T) {
}

func TestClientOnPingPong(t *testing.T) {
cli, _ := newClientAndDial(Keepalive(time.Second * 3))
c, _ := newClientAndDial(Keepalive(time.Second * 3))
cli := c.(*client)
mc := cli.conn.(*mockConn)

waitCh := make(chan struct{}, 1)
Expand Down
6 changes: 3 additions & 3 deletions go/client/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,18 +176,18 @@ func RequestTimeout(d time.Duration) RequestOption {
}

// ClientOption is func to set Client Context and Logger
type ClientOption func(*Client)
type ClientOption func(*client)

// WithContext set context of client
func WithContext(ctx context.Context) ClientOption {
return func(c *Client) {
return func(c *client) {
c.Context = ctx
}
}

// WithLogger set Logger of client
func WithLogger(l protocol.Logger) ClientOption {
return func(c *Client) {
return func(c *client) {
c.Logger = l
}
}
4 changes: 2 additions & 2 deletions go/client/tcp_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func parseHostAndPort(p string) (host, port string) {
return parts[0], parts[1]
}

func dialTCPConn(ctx context.Context, cli *Client, uri *url.URL, handshake *protocol.Handshake, o *DialOptions) (ClientConn, error) {
func dialTCPConn(ctx context.Context, logger protocol.Logger, uri *url.URL, handshake *protocol.Handshake, o *DialOptions) (ClientConn, error) {

ver := handshake.Version
p, err := protocol.GetProtocol(ver)
Expand All @@ -49,7 +49,7 @@ func dialTCPConn(ctx context.Context, cli *Client, uri *url.URL, handshake *prot
qctx.Version = handshake.Version

c := &tcpConn{
logger: cli.Logger,
logger: logger,
qctx: qctx,
p: p,
conn: conn,
Expand Down
4 changes: 2 additions & 2 deletions go/client/ws_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func init() {

var ErrInvalidMessage = errors.New("invlaid websocket message")

func dialWSConn(ctx context.Context, cli *Client, uri *url.URL, handshake *protocol.Handshake, o *DialOptions) (ClientConn, error) {
func dialWSConn(ctx context.Context, logger protocol.Logger, uri *url.URL, handshake *protocol.Handshake, o *DialOptions) (ClientConn, error) {
ver := handshake.Version
p, err := protocol.GetProtocol(ver)

Expand All @@ -51,7 +51,7 @@ func dialWSConn(ctx context.Context, cli *Client, uri *url.URL, handshake *proto
qctx.Version = handshake.Version

c := &wsConn{
logger: cli.Logger,
logger: logger,
qctx: qctx,
p: p,
conn: conn,
Expand Down

0 comments on commit 4f3dabe

Please sign in to comment.