Skip to content

Commit

Permalink
refactor server keepalive for hook access (#220)
Browse files Browse the repository at this point in the history
  • Loading branch information
mochi-co authored May 6, 2023
1 parent a734a0d commit 5225a35
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 64 deletions.
40 changes: 16 additions & 24 deletions clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,19 @@ type Will struct {

// State tracks the state of the client.
type ClientState struct {
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
outbound chan *packets.Packet // queue for pending outbound packets
endOnce sync.Once // only end once
isTakenOver uint32 // used to identify orphaned clients
packetID uint32 // the current highest packetID
open context.Context // indicate that the client is open for packet exchange
outboundQty int32 // number of messages currently in the outbound queue
keepalive uint16 // the number of seconds the connection can wait
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
outbound chan *packets.Packet // queue for pending outbound packets
endOnce sync.Once // only end once
isTakenOver uint32 // used to identify orphaned clients
packetID uint32 // the current highest packetID
open context.Context // indicate that the client is open for packet exchange
outboundQty int32 // number of messages currently in the outbound queue
Keepalive uint16 // the number of seconds the connection can wait
ServerKeepalive bool // keepalive was set by the server
}

// newClient returns a new instance of Client. This is almost exclusively used by Server
Expand All @@ -158,8 +159,8 @@ func newClient(c net.Conn, o *ops) *Client {
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
keepalive: defaultKeepalive,
open: context.Background(),
Keepalive: defaultKeepalive,
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
},
Properties: ClientProperties{
Expand All @@ -179,8 +180,6 @@ func newClient(c net.Conn, o *ops) *Client {
}
}

cl.refreshDeadline(cl.State.keepalive)

return cl
}

Expand All @@ -203,9 +202,9 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Properties.Clean = pk.Connect.Clean
cl.Properties.Props = pk.Properties.Copy(false)

cl.State.Keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max

cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum)

cl.ID = pk.Connect.ClientIdentifier
Expand All @@ -214,11 +213,6 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Properties.Props.AssignedClientID = cl.ID
}

cl.State.keepalive = cl.ops.options.Capabilities.ServerKeepAlive
if pk.Connect.Keepalive > 0 {
cl.State.keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
}

if pk.Connect.WillFlag {
cl.Properties.Will = Will{
Qos: pk.Connect.WillQos,
Expand All @@ -236,8 +230,6 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Properties.Will.Flag = 1 // atomic for checking
}
}

cl.refreshDeadline(cl.State.keepalive)
}

// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
Expand Down Expand Up @@ -336,7 +328,7 @@ func (cl *Client) Read(packetHandler ReadFn) error {
return nil
}

cl.refreshDeadline(cl.State.keepalive)
cl.refreshDeadline(cl.State.Keepalive)
fh := new(packets.FixedHeader)
err = cl.ReadFixedHeader(fh)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions clients_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func TestNewClient(t *testing.T) {
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.NotNil(t, cl.State.TopicAliases)
require.Equal(t, defaultKeepalive, cl.State.keepalive)
require.Equal(t, defaultKeepalive, cl.State.Keepalive)
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
require.NotNil(t, cl.Net.Conn)
require.NotNil(t, cl.Net.bconn)
Expand Down Expand Up @@ -165,7 +165,7 @@ func TestClientParseConnect(t *testing.T) {

cl.ParseConnect("tcp1", pk)
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
require.Equal(t, pk.Connect.Keepalive, cl.State.keepalive)
require.Equal(t, pk.Connect.Keepalive, cl.State.Keepalive)
require.Equal(t, pk.Connect.Clean, cl.Properties.Clean)
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName)
Expand Down
8 changes: 4 additions & 4 deletions examples/debug/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ func main() {
l := server.Log.Level(zerolog.DebugLevel)
server.Log = &l

err := server.AddHook(new(auth.AllowHook), nil)
err := server.AddHook(new(debug.Hook), &debug.Options{
// ShowPacketData: true,
})
if err != nil {
log.Fatal(err)
}

err = server.AddHook(new(debug.Hook), &debug.Options{
// ShowPacketData: true,
})
err = server.AddHook(new(auth.AllowHook), nil)
if err != nil {
log.Fatal(err)
}
Expand Down
10 changes: 9 additions & 1 deletion examples/paho.testing/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ func main() {
}()

server := mqtt.New(nil)
server.Options.Capabilities.ServerKeepAlive = 60
server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true

Expand Down Expand Up @@ -61,6 +60,7 @@ func (h *pahoAuthHook) ID() string {
func (h *pahoAuthHook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnConnectAuthenticate,
mqtt.OnConnect,
mqtt.OnACLCheck,
}, []byte{b})
}
Expand All @@ -72,3 +72,11 @@ func (h *pahoAuthHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet)
func (h *pahoAuthHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
return topic != "test/nosubscribe"
}

func (h *pahoAuthHook) OnConnect(cl *mqtt.Client, pk packets.Packet) {
// Handle paho test_server_keep_alive
if pk.Connect.Keepalive == 120 && pk.Connect.Clean {
cl.State.Keepalive = 60
cl.State.ServerKeepalive = true
}
}
62 changes: 37 additions & 25 deletions packets/tpackets.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ const (
TConnackAcceptedAdjustedExpiryInterval
TConnackMinMqtt5
TConnackMinCleanMqtt5
TConnackServerKeepalive
TConnackInvalidMinMqtt5
TConnackBadProtocolVersion
TConnackProtocolViolationNoSession
Expand Down Expand Up @@ -1085,25 +1086,22 @@ var TPacketData = map[byte]TPacketCases{
Desc: "accepted, no session, adjusted expiry interval mqtt5",
Primary: true,
RawBytes: []byte{
Connack << 4, 11, // fixed header
Connack << 4, 8, // fixed header
0, // Session present
CodeSuccess.Code,
8, // length
5, // length
17, 0, 0, 0, 120, // Session Expiry Interval (17)
19, 0, 10, // Server Keep Alive (19)
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 11,
Remaining: 8,
},
ReasonCode: CodeSuccess.Code,
Properties: Properties{
SessionExpiryInterval: uint32(120),
SessionExpiryIntervalFlag: true,
ServerKeepAlive: uint16(10),
ServerKeepAliveFlag: true,
},
},
},
Expand Down Expand Up @@ -1190,28 +1188,25 @@ var TPacketData = map[byte]TPacketCases{
Desc: "accepted min properties mqtt5",
Primary: true,
RawBytes: []byte{
Connack << 4, 16, // fixed header
Connack << 4, 13, // fixed header
1, // existing session
CodeSuccess.Code,
13, // Properties length
10, // Properties length
18, 0, 5, 'm', 'o', 'c', 'h', 'i', // Assigned Client ID (18)
19, 0, 20, // Server Keep Alive (19)
36, 1, // Maximum Qos (36)
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 16,
Remaining: 13,
},
SessionPresent: true,
ReasonCode: CodeSuccess.Code,
Properties: Properties{
ServerKeepAlive: uint16(20),
ServerKeepAliveFlag: true,
AssignedClientID: "mochi",
MaximumQos: byte(1),
MaximumQosFlag: true,
AssignedClientID: "mochi",
MaximumQos: byte(1),
MaximumQosFlag: true,
},
},
},
Expand All @@ -1220,11 +1215,10 @@ var TPacketData = map[byte]TPacketCases{
Desc: "accepted min properties mqtt5b",
Primary: true,
RawBytes: []byte{
Connack << 4, 6, // fixed header
Connack << 4, 3, // fixed header
0, // existing session
CodeSuccess.Code,
3, // Properties length
19, 0, 10, // server keepalive
0, // Properties length
},
Packet: &Packet{
ProtocolVersion: 5,
Expand All @@ -1234,6 +1228,27 @@ var TPacketData = map[byte]TPacketCases{
},
SessionPresent: false,
ReasonCode: CodeSuccess.Code,
},
},
{
Case: TConnackServerKeepalive,
Desc: "server set keepalive",
Primary: true,
RawBytes: []byte{
Connack << 4, 6, // fixed header
1, // existing session
CodeSuccess.Code,
3, // Properties length
19, 0, 10, // server keepalive
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 6,
},
SessionPresent: true,
ReasonCode: CodeSuccess.Code,
Properties: Properties{
ServerKeepAlive: uint16(10),
ServerKeepAliveFlag: true,
Expand All @@ -1245,26 +1260,23 @@ var TPacketData = map[byte]TPacketCases{
Desc: "failure min properties mqtt5",
Primary: true,
RawBytes: append([]byte{
Connack << 4, 26, // fixed header
Connack << 4, 23, // fixed header
0, // No existing session
ErrUnspecifiedError.Code,
// Properties
23, // length
19, 0, 20, // Server Keep Alive (19)
20, // length
31, 0, 17, // Reason String (31)
}, []byte(ErrUnspecifiedError.Reason)...),
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 25,
Remaining: 23,
},
SessionPresent: false,
ReasonCode: ErrUnspecifiedError.Code,
Properties: Properties{
ServerKeepAlive: uint16(20),
ServerKeepAliveFlag: true,
ReasonString: ErrUnspecifiedError.Reason,
ReasonString: ErrUnspecifiedError.Reason,
},
},
},
Expand Down
12 changes: 7 additions & 5 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ var (
WildcardSubAvailable: 1, // wildcard subscriptions are available
SubIDAvailable: 1, // subscription identifiers are available
SharedSubAvailable: 1, // shared subscriptions are available
ServerKeepAlive: 10, // default keepalive for clients
MinimumProtocolVersion: 3, // minimum supported mqtt version (3.0.0)
MaximumClientWritesPending: 1024 * 8, // maximum number of pending message writes for a client
}
Expand All @@ -61,7 +60,6 @@ type Capabilities struct {
maximumPacketID uint32 // unexported, used for testing only
ReceiveMaximum uint16
TopicAliasMaximum uint16
ServerKeepAlive uint16
SharedSubAvailable byte
MinimumProtocolVersion byte
Compatibilities Compatibilities
Expand Down Expand Up @@ -331,6 +329,7 @@ func (s *Server) attachClient(cl *Client, listener string) error {
}

s.hooks.OnConnect(cl, pk)
cl.refreshDeadline(cl.State.Keepalive)

if !s.hooks.OnConnectAuthenticate(cl, pk) { // [MQTT-3.1.4-2]
err := s.sendConnack(cl, packets.ErrBadUsernameOrPassword, false)
Expand Down Expand Up @@ -498,9 +497,12 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
// sendConnack returns a Connack packet to a client.
func (s *Server) sendConnack(cl *Client, reason packets.Code, present bool) error {
properties := packets.Properties{
ServerKeepAlive: s.Options.Capabilities.ServerKeepAlive, // [MQTT-3.1.2-21]
ServerKeepAliveFlag: true,
ReceiveMaximum: s.Options.Capabilities.ReceiveMaximum, // 3.2.2.3.3 Receive Maximum
ReceiveMaximum: s.Options.Capabilities.ReceiveMaximum, // 3.2.2.3.3 Receive Maximum
}

if cl.State.ServerKeepalive { // You can set this dynamically using the OnConnect hook.
properties.ServerKeepAlive = cl.State.Keepalive // [MQTT-3.1.2-21]
properties.ServerKeepAliveFlag = true
}

if reason.Code >= packets.ErrUnspecifiedError.Code {
Expand Down
21 changes: 18 additions & 3 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func TestServerNewClient(t *testing.T) {
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.NotNil(t, cl.State.TopicAliases)
require.Equal(t, defaultKeepalive, cl.State.keepalive)
require.Equal(t, defaultKeepalive, cl.State.Keepalive)
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
require.NotNil(t, cl.Net.Conn)
require.NotNil(t, cl.Net.bconn)
Expand Down Expand Up @@ -821,7 +821,6 @@ func TestServerSendConnack(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5
s.Options.Capabilities.ServerKeepAlive = 20
s.Options.Capabilities.MaximumQos = 1
cl.Properties.Props = packets.Properties{
AssignedClientID: "mochi",
Expand All @@ -841,7 +840,6 @@ func TestServerSendConnackFailureReason(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5
s.Options.Capabilities.ServerKeepAlive = 20
go func() {
err := s.sendConnack(cl, packets.ErrUnspecifiedError, true)
require.NoError(t, err)
Expand All @@ -853,6 +851,23 @@ func TestServerSendConnackFailureReason(t *testing.T) {
require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackInvalidMinMqtt5).RawBytes, buf)
}

func TestServerSendConnackWithServerKeepalive(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5
cl.State.Keepalive = 10
cl.State.ServerKeepalive = true
go func() {
err := s.sendConnack(cl, packets.CodeSuccess, true)
require.NoError(t, err)
w.Close()
}()

buf, err := io.ReadAll(r)
require.NoError(t, err)
require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackServerKeepalive).RawBytes, buf)
}

func TestServerValidateConnect(t *testing.T) {
packet := *packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).Packet
invalidBitPacket := packet
Expand Down

0 comments on commit 5225a35

Please sign in to comment.