
[FIXED] MQTT: rapid load-balanced (re-)CONNECT to cluster causes races #4734

Merged: 1 commit, Nov 3, 2023
115 changes: 52 additions & 63 deletions server/mqtt.go
@@ -155,6 +155,7 @@ const (
// while "$MQTT.JSA.<node id>.SL.<number>" is for a stream lookup, etc...
mqttJSAIdTokenPos = 3
mqttJSATokenPos = 4
mqttJSAClientIDPos = 5
mqttJSAStreamCreate = "SC"
mqttJSAStreamUpdate = "SU"
mqttJSAStreamLookup = "SL"
@@ -237,10 +238,9 @@ type mqttAccountSessionManager struct {
sl *Sublist // sublist allowing to find retained messages for given subscription
retmsgs map[string]*mqttRetainedMsgRef // retained messages
jsa mqttJSA
rrmLastSeq uint64 // Restore retained messages expected last sequence
rrmDoneCh chan struct{} // To notify the caller that all retained messages have been loaded
sp *ipQueue[uint64] // Used for cluster-wide processing of session records being persisted
domainTk string // Domain (with trailing "."), or possibly empty. This is added to session subject.
rrmLastSeq uint64 // Restore retained messages expected last sequence
rrmDoneCh chan struct{} // To notify the caller that all retained messages have been loaded
domainTk string // Domain (with trailing "."), or possibly empty. This is added to session subject.
}

type mqttJSA struct {
@@ -1109,7 +1109,6 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc
nuid: nuid.New(),
quitCh: quitCh,
},
sp: newIPQueue[uint64](s, qname+"sp"),
}
// TODO record domain name in as here

@@ -1170,14 +1169,15 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc
// This is a subscription that will process all JS API replies. We could split to
// individual subscriptions if needed, but since there is a bit of common code,
// that seemed like a good idea to be all in one place.
if err := as.createSubscription(jsa.rplyr+"*.*",
if err := as.createSubscription(jsa.rplyr+">",
as.processJSAPIReplies, &sid, &subs); err != nil {
return nil, err
}

// We will listen for replies to session persist requests so that we can
// detect the use of a session with the same client ID anywhere in the cluster.
if err := as.createSubscription(mqttJSARepliesPrefix+"*."+mqttJSASessPersist+".*",
// `$MQTT.JSA.{js-id}.SP.{client-id-hash}.{uuid}`
if err := as.createSubscription(mqttJSARepliesPrefix+"*."+mqttJSASessPersist+".*.*",
kozlovic (Member) commented:

So does that mean that a server with this fix will not be able to cooperate with other servers (say, a current v2.10.3)? That is, a current server that persists a session with a reply subject of $MQTT.JSA.<serverId>.SP.<nuid> would not be picked up by a server with this fix. Maybe it's OK, but I am raising this to make sure you thought about it.

levb (Contributor Author) commented, Nov 2, 2023:

Right, thanks for raising it. I did consider the compatibility issue and somewhat punted on it, because of the "edge condition" nature of the use case. I considered serializing the client ID into the UUID token using a different separator, but that felt too hacky for a permanent solution to a temporary edge case. Nothing else I could think of would make this PR backwards-operable, i.e. broadcasting to the pre-2.10.N servers in a way that they would understand. I could easily add another listening subscription to pick up their messages, but that would be one-way only.

All in all, 1/5: leave as is and require that all servers in an MQTT cluster be upgraded/downgraded at approximately the same time. (Note for others: this does not affect the session store itself, just the ACK change notifications.)

levb (Contributor Author) commented:

Having said that, I think using, say, $MQTT.JSA.{js-id}.SP.{client-id-hash}_{uuid} would work just fine.

levb (Contributor Author) commented, Nov 2, 2023:

@kozlovic do you agree with the above (leaving as is)?

kozlovic (Member) commented:

The advantage is that you could revert some of the createSubscription changes to keep the same number of tokens, but you would need more processing to extract the client ID from the last token. Up to you.

as.processSessionPersist, &sid, &subs); err != nil {
return nil, err
}
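For readers following the subject change above: a minimal sketch of how tokens are pulled out of the new `$MQTT.JSA.{js-id}.SP.{client-id-hash}.{uuid}` reply subject. The helper and constant names here are illustrative, not the server's actual implementation, though the token positions mirror the constants added in this diff.

```go
package main

import (
	"fmt"
	"strings"
)

// Token positions in the reply subject, mirroring the constants in this diff:
// $MQTT.JSA.<js-id>.SP.<client-id-hash>.<uuid>
//   1     2     3    4         5           6
const (
	jsaIDTokenPos       = 3
	jsaKindTokenPos     = 4
	jsaClientIDTokenPos = 5
)

// tokenAt returns the 1-based dot-separated token of subject,
// or "" when pos is out of range.
func tokenAt(subject string, pos int) string {
	toks := strings.Split(subject, ".")
	if pos < 1 || pos > len(toks) {
		return ""
	}
	return toks[pos-1]
}

func main() {
	reply := "$MQTT.JSA.srv1.SP.cidhash42.NUID123"
	fmt.Println(tokenAt(reply, jsaIDTokenPos))       // srv1
	fmt.Println(tokenAt(reply, jsaKindTokenPos))     // SP
	fmt.Println(tokenAt(reply, jsaClientIDTokenPos)) // cidhash42
}
```

This also shows why the subscription needed one more `*`: the client-ID hash adds a fifth token before the UUID.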
@@ -1203,12 +1203,6 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc
as.sendJSAPIrequests(s, c, accName, closeCh)
})

// Start the go routine that will handle network updates regarding sessions
s.startGoRoutine(func() {
defer s.grWG.Done()
as.sessPersistProcessing(closeCh)
})

lookupStream := func(stream, txt string) (*StreamInfo, error) {
si, err := jsa.lookupStream(stream)
if err != nil {
@@ -1454,7 +1448,7 @@ func (s *Server) mqttDetermineReplicas() int {
//////////////////////////////////////////////////////////////////////////////

func (jsa *mqttJSA) newRequest(kind, subject string, hdr int, msg []byte) (interface{}, error) {
return jsa.newRequestEx(kind, subject, hdr, msg, mqttJSAPITimeout)
return jsa.newRequestEx(kind, subject, _EMPTY_, hdr, msg, mqttJSAPITimeout)
}

func (jsa *mqttJSA) prefixDomain(subject string) string {
@@ -1467,19 +1461,24 @@ func (jsa *mqttJSA) prefixDomain(subject string) string {
return subject
}

func (jsa *mqttJSA) newRequestEx(kind, subject string, hdr int, msg []byte, timeout time.Duration) (interface{}, error) {
func (jsa *mqttJSA) newRequestEx(kind, subject, cidHash string, hdr int, msg []byte, timeout time.Duration) (interface{}, error) {
var sb strings.Builder
jsa.mu.Lock()
// Either we use nuid.Next() which uses a global lock, or our own nuid object, but
// then it needs to be "write" protected. This approach will reduce across account
// contention since we won't use the global nuid's lock.
var sb strings.Builder
sb.WriteString(jsa.rplyr)
sb.WriteString(kind)
sb.WriteByte(btsep)
if cidHash != _EMPTY_ {
sb.WriteString(cidHash)
sb.WriteByte(btsep)
}
sb.WriteString(jsa.nuid.Next())
reply := sb.String()
jsa.mu.Unlock()

reply := sb.String()

ch := make(chan interface{}, 1)
jsa.replies.Store(reply, ch)

@@ -1646,6 +1645,25 @@ func (jsa *mqttJSA) storeMsgWithKind(kind, subject string, headers int, msg []by
return smr, smr.ToError()
}

func (jsa *mqttJSA) storeSessionMsg(domainTk, cidHash string, hdr int, msg []byte) (*JSPubAckResponse, error) {
// Compute subject where the session is being stored
subject := mqttSessStreamSubjectPrefix + domainTk + cidHash

// Passing cidHash will add it to the JS reply subject, so that we can use
// it in processSessionPersist.
smri, err := jsa.newRequestEx(mqttJSASessPersist, subject, cidHash, hdr, msg, mqttJSAPITimeout)
if err != nil {
return nil, err
}
smr := smri.(*JSPubAckResponse)
return smr, smr.ToError()
}

func (jsa *mqttJSA) loadSessionMsg(domainTk, cidHash string) (*StoredMsg, error) {
subject := mqttSessStreamSubjectPrefix + domainTk + cidHash
return jsa.loadLastMsgFor(mqttSessStreamName, subject)
}
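As a rough sketch of what these two new helpers share: the session subject is just prefix + optional domain token + client-ID hash. The prefix value below is an assumption standing in for the server's `mqttSessStreamSubjectPrefix` constant.

```go
package main

import "fmt"

// Assumed value standing in for the server's mqttSessStreamSubjectPrefix.
const sessSubjectPrefix = "$MQTT.sess."

// sessSubject builds the per-client session subject the way storeSessionMsg
// and loadSessionMsg do. domainTk is either empty or a domain name with a
// trailing ".", per the comment on the domainTk field earlier in this diff.
func sessSubject(domainTk, cidHash string) string {
	return sessSubjectPrefix + domainTk + cidHash
}

func main() {
	fmt.Println(sessSubject("", "abc123"))     // $MQTT.sess.abc123
	fmt.Println(sessSubject("hub.", "abc123")) // $MQTT.sess.hub.abc123
}
```

Centralizing this in one pair of helpers removes the duplicated subject construction that previously appeared at each store/load call site.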

func (jsa *mqttJSA) deleteMsg(stream string, seq uint64, wait bool) error {
dreq := JSApiMsgDeleteRequest{Seq: seq, NoErase: true}
req, _ := json.Marshal(dreq)
@@ -1817,6 +1835,7 @@ func (as *mqttAccountSessionManager) processSessionPersist(_ *subscription, pc *
if tokenAt(subject, mqttJSAIdTokenPos) == as.jsa.id {
return
}
cIDHash := tokenAt(subject, mqttJSAClientIDPos)
_, msg := pc.msgParts(rmsg)
if len(msg) < LEN_CR_LF {
return
@@ -1839,18 +1858,6 @@
if ignore {
return
}
// We would need to lookup the message and that would be a request/reply,
// which we can't do in place here. So move that to a long-running routine
// that will process the session persist record.
as.sp.push(par.Sequence)
}

func (as *mqttAccountSessionManager) processSessPersistRecord(seq uint64) {
smsg, err := as.jsa.loadMsg(mqttSessStreamName, seq)
if err != nil {
return
}
cIDHash := strings.TrimPrefix(smsg.Subject, mqttSessStreamSubjectPrefix+as.domainTk)

as.mu.Lock()
defer as.mu.Unlock()
@@ -1861,7 +1868,7 @@
// If our current session's stream sequence is higher, it means that this
// update is stale, so we don't do anything here.
sess.mu.Lock()
ignore := seq < sess.seq
ignore = par.Sequence < sess.seq
sess.mu.Unlock()
if ignore {
return
@@ -1881,28 +1888,6 @@
sess.mu.Unlock()
}

func (as *mqttAccountSessionManager) sessPersistProcessing(closeCh chan struct{}) {
as.mu.RLock()
sp := as.sp
quitCh := as.jsa.quitCh
as.mu.RUnlock()

for {
select {
case <-sp.ch:
seqs := sp.pop()
for _, seq := range seqs {
as.processSessPersistRecord(seq)
}
sp.recycle(&seqs)
case <-closeCh:
return
case <-quitCh:
return
}
}
}
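With the long-running routine gone, the staleness guard that used to live in processSessPersistRecord reduces to a single sequence comparison done in-place. A hypothetical sketch (the `session` type here is a stand-in for the server's `mqttSession`, keeping only the field that matters):

```go
package main

import "fmt"

// session stands in for the server's mqttSession; only the stream-sequence
// field matters for this sketch.
type session struct{ seq uint64 }

// isStale mirrors the check in processSessionPersist: a persist notification
// carrying a stream sequence lower than the session's current one is an
// out-of-order (stale) update and must be ignored.
func (s *session) isStale(notifSeq uint64) bool {
	return notifSeq < s.seq
}

func main() {
	sess := &session{seq: 10}
	fmt.Println(sess.isStale(9))  // true: older record, ignore it
	fmt.Println(sess.isStale(10)) // false
	fmt.Println(sess.isStale(11)) // false: newer record, process it
}
```

Because the client-ID hash now travels in the reply subject, no message lookup (request/reply) is needed to make this decision, which is what allowed the ipQueue and its goroutine to be deleted.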

// Adds this client ID to the flappers map, and if needed start the timer
// for map cleanup.
//
@@ -2417,8 +2402,7 @@ func (as *mqttAccountSessionManager) createOrRestoreSession(clientID string, opt
}

hash := getHash(clientID)
subject := mqttSessStreamSubjectPrefix + as.domainTk + hash
smsg, err := jsa.loadLastMsgFor(mqttSessStreamName, subject)
smsg, err := jsa.loadSessionMsg(as.domainTk, hash)
if err != nil {
if isErrorOtherThan(err, JSNoMessageFoundErr) {
return formatError("loading session record", err)
@@ -2434,6 +2418,7 @@
if err := json.Unmarshal(smsg.Data, ps); err != nil {
return formatError(fmt.Sprintf("unmarshal of session record at sequence %v", smsg.Sequence), err)
}

// Restore this session (even if we don't own it), the caller will do the right thing.
sess := mqttSessionCreate(jsa, clientID, hash, smsg.Sequence, opts)
sess.domainTk = as.domainTk
@@ -2479,7 +2464,7 @@ func (as *mqttAccountSessionManager) transferUniqueSessStreamsToMuxed(log *Serve
}()

jsa := &as.jsa
sni, err := jsa.newRequestEx(mqttJSAStreamNames, JSApiStreams, 0, nil, 5*time.Second)
sni, err := jsa.newRequestEx(mqttJSAStreamNames, JSApiStreams, _EMPTY_, 0, nil, 5*time.Second)
if err != nil {
log.Errorf("Unable to transfer MQTT session streams: %v", err)
return
@@ -2514,10 +2499,8 @@ func (as *mqttAccountSessionManager) transferUniqueSessStreamsToMuxed(log *Serve
log.Warnf(" Unable to unmarshal the content of this stream, may not be a legitimate MQTT session stream, skipping")
continue
}
// Compute subject where the session is being stored
subject := mqttSessStreamSubjectPrefix + as.domainTk + getHash(ps.ID)
// Store record to MQTT session stream
if _, err := jsa.storeMsgWithKind(mqttJSASessPersist, subject, 0, smsg.Data); err != nil {
if _, err := jsa.storeSessionMsg(as.domainTk, getHash(ps.ID), 0, smsg.Data); err != nil {
log.Errorf(" Unable to transfer the session record: %v", err)
return
}
@@ -2553,7 +2536,8 @@ func (as *mqttAccountSessionManager) transferRetainedToPerKeySubjectStream(log *
}
// Store the message again, this time with the new per-key subject.
subject := mqttRetainedMsgsStreamSubject + rmsg.Subject
if _, err := jsa.storeMsgWithKind(mqttJSASessPersist, subject, 0, smsg.Data); err != nil {

if _, err := jsa.storeMsg(subject, 0, smsg.Data); err != nil {
log.Errorf(" Unable to transfer the retained message with sequence %d: %v", smsg.Sequence, err)
errors++
continue
@@ -2619,7 +2603,7 @@ func (sess *mqttSession) save() error {
}
b, _ := json.Marshal(&ps)

subject := mqttSessStreamSubjectPrefix + sess.domainTk + sess.idHash
domainTk, cidHash := sess.domainTk, sess.idHash
seq := sess.seq
sess.mu.Unlock()

@@ -2637,7 +2621,7 @@
b = bb.Bytes()
}

resp, err := sess.jsa.storeMsgWithKind(mqttJSASessPersist, subject, hdr, b)
resp, err := sess.jsa.storeSessionMsg(domainTk, cidHash, hdr, b)
if err != nil {
return fmt.Errorf("unable to persist session %q (seq=%v): %v", ps.ID, seq, err)
}
@@ -2691,8 +2675,13 @@ func (sess *mqttSession) clear() error {
}

if seq > 0 {
if err := sess.jsa.deleteMsg(mqttSessStreamName, seq, true); err != nil {
return fmt.Errorf("unable to delete session %q record at sequence %v", id, seq)
err := sess.jsa.deleteMsg(mqttSessStreamName, seq, true)
// Ignore the various errors indicating that the message (or sequence)
// is already deleted, can happen in a cluster.
if isErrorOtherThan(err, JSSequenceNotFoundErrF) {
if isErrorOtherThan(err, JSStreamMsgDeleteFailedF) || !strings.Contains(err.Error(), ErrStoreMsgNotFound.Error()) {
return fmt.Errorf("unable to delete session %q record at sequence %v: %v", id, seq, err)
}
}
}
return nil
61 changes: 60 additions & 1 deletion server/mqtt_test.go
@@ -2992,6 +2992,58 @@ func TestMQTTCluster(t *testing.T) {
}
}

func testMQTTConnectDisconnect(t *testing.T, o *Options, clientID string, clean bool, found bool) {
t.Helper()
start := time.Now()
mc, r := testMQTTConnect(t, &mqttConnInfo{clientID: clientID, cleanSess: clean}, o.MQTT.Host, o.MQTT.Port)
testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, found)
testMQTTDisconnectEx(t, mc, nil, false)
mc.Close()
if clean {
t.Logf("OK with server %v:%v, elapsed %v -- clean", o.MQTT.Host, o.MQTT.Port, time.Since(start))
} else {
t.Logf("OK with server %v:%v, elapsed %v", o.MQTT.Host, o.MQTT.Port, time.Since(start))
}
}

func TestMQTTClusterConnectDisconnectClean(t *testing.T) {
nServers := 3
cl := createJetStreamClusterWithTemplate(t, testMQTTGetClusterTemplaceNoLeaf(), "MQTT", nServers)
defer cl.shutdown()

clientID := nuid.Next()

// test runs a connect/disconnect against a random server in the cluster, as
// specified.
N := 100
for n := 0; n < N; n++ {
testMQTTConnectDisconnect(t, cl.opts[rand.Intn(nServers)], clientID, true, false)
}
}

func TestMQTTClusterConnectDisconnectPersist(t *testing.T) {
nServers := 3
cl := createJetStreamClusterWithTemplate(t, testMQTTGetClusterTemplaceNoLeaf(), "MQTT", nServers)
defer cl.shutdown()

clientID := nuid.Next()

// test runs a connect/disconnect against a random server in the cluster, as
// specified.
N := 20
for n := 0; n < N; n++ {
// First clean sessions on all servers
for i := 0; i < nServers; i++ {
testMQTTConnectDisconnect(t, cl.opts[i], clientID, true, false)
}

testMQTTConnectDisconnect(t, cl.opts[0], clientID, false, false)
testMQTTConnectDisconnect(t, cl.opts[1], clientID, false, true)
testMQTTConnectDisconnect(t, cl.opts[2], clientID, false, true)
testMQTTConnectDisconnect(t, cl.opts[0], clientID, false, true)
}
}

func TestMQTTClusterRetainedMsg(t *testing.T) {
cl := createJetStreamClusterWithTemplate(t, testMQTTGetClusterTemplaceNoLeaf(), "MQTT", 2)
defer cl.shutdown()
@@ -3859,6 +3911,11 @@ func TestMQTTPublishTopicErrors(t *testing.T) {
}

func testMQTTDisconnect(t testing.TB, c net.Conn, bw *bufio.Writer) {
t.Helper()
testMQTTDisconnectEx(t, c, bw, true)
}

func testMQTTDisconnectEx(t testing.TB, c net.Conn, bw *bufio.Writer, wait bool) {
t.Helper()
w := &mqttWriter{}
w.WriteByte(mqttPacketDisconnect)
@@ -3869,7 +3926,9 @@
} else {
c.Write(w.Bytes())
}
testMQTTExpectDisconnect(t, c)
if wait {
testMQTTExpectDisconnect(t, c)
}
}

func TestMQTTWill(t *testing.T) {