Skip to content

Commit

Permalink
increase compatibility of PacketHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
fish-tennis committed Jan 10, 2024
1 parent d0679c5 commit 0054fd9
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 35 deletions.
3 changes: 3 additions & 0 deletions codec_protobuf.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
// proto.Message ctor func
type ProtoMessageCreator func() proto.Message

// Packet ctor func
type PacketCreator func() Packet

type ProtoRegister interface {
Register(command PacketCommand, protoMessage proto.Message)
}
Expand Down
2 changes: 1 addition & 1 deletion codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func TestHandler(t *testing.T) {
defaultHandler := NewDefaultConnectionHandler(nil)
defaultHandler.GetCodec()
defaultHandler.CreateHeartBeatPacket(nil)
defaultHandler.SetUnRegisterHandler(func(connection Connection, packet *ProtoPacket) {
defaultHandler.SetUnRegisterHandler(func(connection Connection, packet Packet) {

})
defaultHandler.OnRecvPacket(nil, NewProtoPacket(123, nil))
Expand Down
17 changes: 8 additions & 9 deletions echo_proto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"
"github.com/fish-tennis/gnet/example/pb"
"google.golang.org/protobuf/proto"
"testing"
"time"
)
Expand Down Expand Up @@ -46,7 +45,7 @@ func TestEchoProto(t *testing.T) {
// 注册服务器的消息回调
serverHandler.Register(PacketCommand(pb.CmdTest_Cmd_HeartBeat), onHeartBeatReq, new(pb.HeartBeatReq))
serverHandler.Register(PacketCommand(pb.CmdTest_Cmd_TestMessage), onTestMessageServer, new(pb.TestMessage))
serverHandler.SetUnRegisterHandler(func(connection Connection, packet *ProtoPacket) {
serverHandler.SetUnRegisterHandler(func(connection Connection, packet Packet) {
logger.Warn("%v", packet)
})
serverHandler.GetPacketHandler(PacketCommand(pb.CmdTest_Cmd_TestMessage))
Expand All @@ -61,8 +60,8 @@ func TestEchoProto(t *testing.T) {
DefaultConnectionHandler: *NewDefaultConnectionHandler(clientCodec),
}
// 客户端作为connector,需要设置心跳包
clientHandler.RegisterHeartBeat(PacketCommand(pb.CmdTest_Cmd_HeartBeat), func() proto.Message {
return &pb.HeartBeatReq{}
clientHandler.RegisterHeartBeat(func() Packet {
return NewProtoPacket(PacketCommand(pb.CmdTest_Cmd_HeartBeat),&pb.HeartBeatReq{})
})
// 注册客户端的消息回调
clientHandler.Register(PacketCommand(pb.CmdTest_Cmd_HeartBeat), clientHandler.onHeartBeatRes, new(pb.HeartBeatRes))
Expand Down Expand Up @@ -116,7 +115,7 @@ func echoProtoOnConnected(connection Connection, success bool) {
}

// 服务器收到客户端的心跳包
func onHeartBeatReq(connection Connection, packet *ProtoPacket) {
func onHeartBeatReq(connection Connection, packet Packet) {
req := packet.Message().(*pb.HeartBeatReq)
logger.Debug(fmt.Sprintf("Server onHeartBeatReq: %v", req))
connection.Send(PacketCommand(pb.CmdTest_Cmd_HeartBeat), &pb.HeartBeatRes{
Expand All @@ -126,7 +125,7 @@ func onHeartBeatReq(connection Connection, packet *ProtoPacket) {
}

// 服务器收到客户端的TestMessage
func onTestMessageServer(connection Connection, packet *ProtoPacket) {
func onTestMessageServer(connection Connection, packet Packet) {
req := packet.Message().(*pb.TestMessage)
logger.Debug(fmt.Sprintf("Server onTestMessage: %v", req))
}
Expand All @@ -138,12 +137,12 @@ type echoProtoClientHandler struct {
}

// 收到心跳包回复
func (e *echoProtoClientHandler) onHeartBeatRes(connection Connection, packet *ProtoPacket) {
func (e *echoProtoClientHandler) onHeartBeatRes(connection Connection, packet Packet) {
res := packet.Message().(*pb.HeartBeatRes)
logger.Debug(fmt.Sprintf("client onHeartBeatRes: %v", res))
}

func (e *echoProtoClientHandler) onTestMessage(connection Connection, packet *ProtoPacket) {
func (e *echoProtoClientHandler) onTestMessage(connection Connection, packet Packet) {
res := packet.Message().(*pb.TestMessage)
logger.Debug(fmt.Sprintf("client onTestMessage: %v", res))
e.echoCount++
Expand All @@ -155,7 +154,7 @@ func (e *echoProtoClientHandler) onTestMessage(connection Connection, packet *Pr
}

// 测试没有注册proto.Message的消息
func (e *echoProtoClientHandler) onTestDataMessage(connection Connection, packet *ProtoPacket) {
func (e *echoProtoClientHandler) onTestDataMessage(connection Connection, packet Packet) {
logger.Debug(fmt.Sprintf("client onTestDataMessage: %v", string(packet.GetStreamData())))
e.echoCount++
connection.Send(PacketCommand(pb.CmdTest_Cmd_TestMessage),
Expand Down
36 changes: 18 additions & 18 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ type PacketHandlerRegister interface {
Register(packetCommand PacketCommand, handler PacketHandler, protoMessage proto.Message)
}

// handler for ProtoPacket
type PacketHandler func(connection Connection, packet *ProtoPacket)
// handler for Packet
type PacketHandler func(connection Connection, packet Packet)

// default ConnectionHandler for ProtoPacket
// default ConnectionHandler for Proto
type DefaultConnectionHandler struct {
// 注册消息的处理函数map
// registered map of PacketCommand and PacketHandler
Expand All @@ -59,6 +59,9 @@ type DefaultConnectionHandler struct {
// 心跳包构造函数(只对connector有效)
// heartBeat packet generator(only valid for connector)
heartBeatCreator ProtoMessageCreator
// 心跳包构造函数(只对connector有效)
// heartBeat packet generator(only valid for connector)
heartBeatPacketCreator PacketCreator
}

func (this *DefaultConnectionHandler) OnConnected(connection Connection, success bool) {
Expand All @@ -80,22 +83,20 @@ func (this *DefaultConnectionHandler) OnRecvPacket(connection Connection, packet
LogStack()
}
}()
if protoPacket, ok := packet.(*ProtoPacket); ok {
if packetHandler, ok2 := this.PacketHandlers[protoPacket.command]; ok2 {
if packetHandler != nil {
packetHandler(connection, protoPacket)
return
}
}
if this.UnRegisterHandler != nil {
this.UnRegisterHandler(connection, protoPacket)
if packetHandler, ok2 := this.PacketHandlers[packet.Command()]; ok2 {
if packetHandler != nil {
packetHandler(connection, packet)
return
}
}
if this.UnRegisterHandler != nil {
this.UnRegisterHandler(connection, packet)
}
}

func (this *DefaultConnectionHandler) CreateHeartBeatPacket(connection Connection) Packet {
if this.heartBeatCreator != nil {
return NewProtoPacket(this.heartBeatCommand, this.heartBeatCreator())
if this.heartBeatPacketCreator != nil {
return this.heartBeatPacketCreator()
}
return nil
}
Expand Down Expand Up @@ -130,10 +131,9 @@ func (this *DefaultConnectionHandler) GetPacketHandler(packetCommand PacketComma

// 注册心跳包(只对connector有效)
//
// register heartbeat PacketCommand and ProtoMessageCreator, only valid for connector
func (this *DefaultConnectionHandler) RegisterHeartBeat(packetCommand PacketCommand, creator ProtoMessageCreator) {
this.heartBeatCommand = packetCommand
this.heartBeatCreator = creator
// register heartBeatPacketCreator, only valid for connector
func (this *DefaultConnectionHandler) RegisterHeartBeat(heartBeatPacketCreator PacketCreator) {
this.heartBeatPacketCreator = heartBeatPacketCreator
}

// 未注册消息的处理函数
Expand Down
2 changes: 1 addition & 1 deletion packet_size_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestPacketSize(t *testing.T) {
listenAddress := "127.0.0.1:10002"
defaultCodec := NewProtoCodec(nil)
serverHandler := NewDefaultConnectionHandler(defaultCodec)
serverHandler.Register(PacketCommand(123), func(connection Connection, packet *ProtoPacket) {
serverHandler.Register(PacketCommand(123), func(connection Connection, packet Packet) {
testMessage := packet.Message().(*pb.TestMessage)
logger.Info("recv%v:%s", testMessage.I32, testMessage.Name)
}, new(pb.TestMessage))
Expand Down
2 changes: 1 addition & 1 deletion tcp_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (this *TcpConnection) Connect(address string) bool {
conn, err := net.DialTimeout("tcp", address, time.Second)
if err != nil {
atomic.StoreInt32(&this.isConnected, 0)
logger.Error("Connect failed %v: %v", this.GetConnectionId(), err.Error())
logger.Error("Connect failed %v: %v %v", this.GetConnectionId(), address, err.Error())
if this.handler != nil {
this.handler.OnConnected(this, false)
}
Expand Down
9 changes: 4 additions & 5 deletions tcp_connection_simple_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package gnet
import (
"context"
"github.com/fish-tennis/gnet/example/pb"
"google.golang.org/protobuf/proto"
"net"
"testing"
"time"
Expand Down Expand Up @@ -36,12 +35,12 @@ func TestTcpConnectionSimple(t *testing.T) {
codec.Register(PacketCommand(10086), nil)

connectionHandler := NewDefaultConnectionHandler(codec)
connectionHandler.RegisterHeartBeat(PacketCommand(pb.CmdTest_Cmd_HeartBeat), func() proto.Message {
return &pb.HeartBeatReq{
connectionHandler.RegisterHeartBeat(func() Packet {
return NewProtoPacket(PacketCommand(pb.CmdTest_Cmd_HeartBeat),&pb.HeartBeatReq{
Timestamp: GetCurrentTimeStamp(),
}
})
})
connectionHandler.SetUnRegisterHandler(func(connection Connection, packet *ProtoPacket) {
connectionHandler.SetUnRegisterHandler(func(connection Connection, packet Packet) {
streamStr := ""
if packet.GetStreamData() != nil {
streamStr = string(packet.GetStreamData())
Expand Down

0 comments on commit 0054fd9

Please sign in to comment.