From 6626091b76b82d215d8eeb2232a7a2c8de6d6617 Mon Sep 17 00:00:00 2001 From: Lev Brouk Date: Wed, 1 Nov 2023 17:14:49 -0700 Subject: [PATCH] [FIXED] MQTT rapid cluster CONNECT race to delete session Inline persistent sess notification processing PR feedback: nit _EMPTY_ PR feedback: more robust error handling, _EMPTY_ PR feedback: error handling --- server/mqtt.go | 115 ++++++++++++++++++++------------------------ server/mqtt_test.go | 61 ++++++++++++++++++++++- 2 files changed, 112 insertions(+), 64 deletions(-) diff --git a/server/mqtt.go b/server/mqtt.go index 7302722007d..4c0c96c66b4 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -155,6 +155,7 @@ const ( // while "$MQTT.JSA..SL." is for a stream lookup, etc... mqttJSAIdTokenPos = 3 mqttJSATokenPos = 4 + mqttJSAClientIDPos = 5 mqttJSAStreamCreate = "SC" mqttJSAStreamUpdate = "SU" mqttJSAStreamLookup = "SL" @@ -237,10 +238,9 @@ type mqttAccountSessionManager struct { sl *Sublist // sublist allowing to find retained messages for given subscription retmsgs map[string]*mqttRetainedMsgRef // retained messages jsa mqttJSA - rrmLastSeq uint64 // Restore retained messages expected last sequence - rrmDoneCh chan struct{} // To notify the caller that all retained messages have been loaded - sp *ipQueue[uint64] // Used for cluster-wide processing of session records being persisted - domainTk string // Domain (with trailing "."), or possibly empty. This is added to session subject. + rrmLastSeq uint64 // Restore retained messages expected last sequence + rrmDoneCh chan struct{} // To notify the caller that all retained messages have been loaded + domainTk string // Domain (with trailing "."), or possibly empty. This is added to session subject. } type mqttJSA struct { @@ -1109,7 +1109,6 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc nuid: nuid.New(), quitCh: quitCh, }, - sp: newIPQueue[uint64](s, qname+"sp"), } // TODO record domain name in as here @@ -1170,14 +1169,15 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc // This is a subscription that will process all JS API replies. We could split to // individual subscriptions if needed, but since there is a bit of common code, // that seemed like a good idea to be all in one place. - if err := as.createSubscription(jsa.rplyr+"*.*", + if err := as.createSubscription(jsa.rplyr+">", as.processJSAPIReplies, &sid, &subs); err != nil { return nil, err } // We will listen for replies to session persist requests so that we can // detect the use of a session with the same client ID anywhere in the cluster. - if err := as.createSubscription(mqttJSARepliesPrefix+"*."+mqttJSASessPersist+".*", + // `$MQTT.JSA.{js-id}.SP.{client-id-hash}.{uuid}` + if err := as.createSubscription(mqttJSARepliesPrefix+"*."+mqttJSASessPersist+".*.*", as.processSessionPersist, &sid, &subs); err != nil { return nil, err } @@ -1203,12 +1203,6 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc as.sendJSAPIrequests(s, c, accName, closeCh) }) - // Start the go routine that will handle network updates regarding sessions - s.startGoRoutine(func() { - defer s.grWG.Done() - as.sessPersistProcessing(closeCh) - }) - lookupStream := func(stream, txt string) (*StreamInfo, error) { si, err := jsa.lookupStream(stream) if err != nil { @@ -1454,7 +1448,7 @@ func (s *Server) mqttDetermineReplicas() int { ////////////////////////////////////////////////////////////////////////////// func (jsa *mqttJSA) newRequest(kind, subject string, hdr int, msg []byte) (interface{}, error) { - return jsa.newRequestEx(kind, subject, hdr, msg, mqttJSAPITimeout) + return jsa.newRequestEx(kind, subject, _EMPTY_, hdr, msg, mqttJSAPITimeout) } func (jsa *mqttJSA) prefixDomain(subject string) string { @@ -1467,19 +1461,24 @@ func (jsa *mqttJSA) prefixDomain(subject string) string { return subject } -func (jsa *mqttJSA) newRequestEx(kind, subject string, hdr int, msg []byte, timeout time.Duration) (interface{}, error) { +func (jsa *mqttJSA) newRequestEx(kind, subject, cidHash string, hdr int, msg []byte, timeout time.Duration) (interface{}, error) { + var sb strings.Builder jsa.mu.Lock() // Either we use nuid.Next() which uses a global lock, or our own nuid object, but // then it needs to be "write" protected. This approach will reduce across account // contention since we won't use the global nuid's lock. - var sb strings.Builder sb.WriteString(jsa.rplyr) sb.WriteString(kind) sb.WriteByte(btsep) + if cidHash != _EMPTY_ { + sb.WriteString(cidHash) + sb.WriteByte(btsep) + } sb.WriteString(jsa.nuid.Next()) - reply := sb.String() jsa.mu.Unlock() + reply := sb.String() + ch := make(chan interface{}, 1) jsa.replies.Store(reply, ch) @@ -1646,6 +1645,25 @@ func (jsa *mqttJSA) storeMsgWithKind(kind, subject string, headers int, msg []by return smr, smr.ToError() } +func (jsa *mqttJSA) storeSessionMsg(domainTk, cidHash string, hdr int, msg []byte) (*JSPubAckResponse, error) { + // Compute subject where the session is being stored + subject := mqttSessStreamSubjectPrefix + domainTk + cidHash + + // Passing cidHash will add it to the JS reply subject, so that we can use + // it in processSessionPersist. + smri, err := jsa.newRequestEx(mqttJSASessPersist, subject, cidHash, hdr, msg, mqttJSAPITimeout) + if err != nil { + return nil, err + } + smr := smri.(*JSPubAckResponse) + return smr, smr.ToError() +} + +func (jsa *mqttJSA) loadSessionMsg(domainTk, cidHash string) (*StoredMsg, error) { + subject := mqttSessStreamSubjectPrefix + domainTk + cidHash + return jsa.loadLastMsgFor(mqttSessStreamName, subject) +} + func (jsa *mqttJSA) deleteMsg(stream string, seq uint64, wait bool) error { dreq := JSApiMsgDeleteRequest{Seq: seq, NoErase: true} req, _ := json.Marshal(dreq) @@ -1817,6 +1835,7 @@ func (as *mqttAccountSessionManager) processSessionPersist(_ *subscription, pc * if tokenAt(subject, mqttJSAIdTokenPos) == as.jsa.id { return } + cIDHash := tokenAt(subject, mqttJSAClientIDPos) _, msg := pc.msgParts(rmsg) if len(msg) < LEN_CR_LF { return @@ -1839,18 +1858,6 @@ func (as *mqttAccountSessionManager) processSessionPersist(_ *subscription, pc * if ignore { return } - // We would need to lookup the message and that would be a request/reply, - // which we can't do in place here. So move that to a long-running routine - // that will process the session persist record. - as.sp.push(par.Sequence) -} - -func (as *mqttAccountSessionManager) processSessPersistRecord(seq uint64) { - smsg, err := as.jsa.loadMsg(mqttSessStreamName, seq) - if err != nil { - return - } - cIDHash := strings.TrimPrefix(smsg.Subject, mqttSessStreamSubjectPrefix+as.domainTk) as.mu.Lock() defer as.mu.Unlock() @@ -1861,7 +1868,7 @@ func (as *mqttAccountSessionManager) processSessPersistRecord(seq uint64) { // If our current session's stream sequence is higher, it means that this // update is stale, so we don't do anything here. sess.mu.Lock() - ignore := seq < sess.seq + ignore = par.Sequence < sess.seq sess.mu.Unlock() if ignore { return @@ -1881,28 +1888,6 @@ func (as *mqttAccountSessionManager) processSessPersistRecord(seq uint64) { sess.mu.Unlock() } -func (as *mqttAccountSessionManager) sessPersistProcessing(closeCh chan struct{}) { - as.mu.RLock() - sp := as.sp - quitCh := as.jsa.quitCh - as.mu.RUnlock() - - for { - select { - case <-sp.ch: - seqs := sp.pop() - for _, seq := range seqs { - as.processSessPersistRecord(seq) - } - sp.recycle(&seqs) - case <-closeCh: - return - case <-quitCh: - return - } - } -} - // Adds this client ID to the flappers map, and if needed start the timer // for map cleanup. // @@ -2417,8 +2402,7 @@ func (as *mqttAccountSessionManager) createOrRestoreSession(clientID string, opt } hash := getHash(clientID) - subject := mqttSessStreamSubjectPrefix + as.domainTk + hash - smsg, err := jsa.loadLastMsgFor(mqttSessStreamName, subject) + smsg, err := jsa.loadSessionMsg(as.domainTk, hash) if err != nil { if isErrorOtherThan(err, JSNoMessageFoundErr) { return formatError("loading session record", err) @@ -2434,6 +2418,7 @@ func (as *mqttAccountSessionManager) createOrRestoreSession(clientID string, opt if err := json.Unmarshal(smsg.Data, ps); err != nil { return formatError(fmt.Sprintf("unmarshal of session record at sequence %v", smsg.Sequence), err) } + // Restore this session (even if we don't own it), the caller will do the right thing. sess := mqttSessionCreate(jsa, clientID, hash, smsg.Sequence, opts) sess.domainTk = as.domainTk @@ -2479,7 +2464,7 @@ func (as *mqttAccountSessionManager) transferUniqueSessStreamsToMuxed(log *Serve }() jsa := &as.jsa - sni, err := jsa.newRequestEx(mqttJSAStreamNames, JSApiStreams, 0, nil, 5*time.Second) + sni, err := jsa.newRequestEx(mqttJSAStreamNames, JSApiStreams, _EMPTY_, 0, nil, 5*time.Second) if err != nil { log.Errorf("Unable to transfer MQTT session streams: %v", err) return @@ -2514,10 +2499,8 @@ func (as *mqttAccountSessionManager) transferUniqueSessStreamsToMuxed(log *Serve log.Warnf(" Unable to unmarshal the content of this stream, may not be a legitimate MQTT session stream, skipping") continue } - // Compute subject where the session is being stored - subject := mqttSessStreamSubjectPrefix + as.domainTk + getHash(ps.ID) // Store record to MQTT session stream - if _, err := jsa.storeMsgWithKind(mqttJSASessPersist, subject, 0, smsg.Data); err != nil { + if _, err := jsa.storeSessionMsg(as.domainTk, getHash(ps.ID), 0, smsg.Data); err != nil { log.Errorf(" Unable to transfer the session record: %v", err) return } @@ -2553,7 +2536,8 @@ func (as *mqttAccountSessionManager) transferRetainedToPerKeySubjectStream(log * } // Store the message again, this time with the new per-key subject. subject := mqttRetainedMsgsStreamSubject + rmsg.Subject - if _, err := jsa.storeMsgWithKind(mqttJSASessPersist, subject, 0, smsg.Data); err != nil { + + if _, err := jsa.storeMsg(subject, 0, smsg.Data); err != nil { log.Errorf(" Unable to transfer the retained message with sequence %d: %v", smsg.Sequence, err) errors++ continue @@ -2619,7 +2603,7 @@ func (sess *mqttSession) save() error { } b, _ := json.Marshal(&ps) - subject := mqttSessStreamSubjectPrefix + sess.domainTk + sess.idHash + domainTk, cidHash := sess.domainTk, sess.idHash seq := sess.seq sess.mu.Unlock() @@ -2637,7 +2621,7 @@ func (sess *mqttSession) save() error { b = bb.Bytes() } - resp, err := sess.jsa.storeMsgWithKind(mqttJSASessPersist, subject, hdr, b) + resp, err := sess.jsa.storeSessionMsg(domainTk, cidHash, hdr, b) if err != nil { return fmt.Errorf("unable to persist session %q (seq=%v): %v", ps.ID, seq, err) } @@ -2691,8 +2675,13 @@ func (sess *mqttSession) clear() error { } if seq > 0 { - if err := sess.jsa.deleteMsg(mqttSessStreamName, seq, true); err != nil { - return fmt.Errorf("unable to delete session %q record at sequence %v", id, seq) + err := sess.jsa.deleteMsg(mqttSessStreamName, seq, true) + // Ignore the various errors indicating that the message (or sequence) + // is already deleted, can happen in a cluster. + if isErrorOtherThan(err, JSSequenceNotFoundErrF) { + if isErrorOtherThan(err, JSStreamMsgDeleteFailedF) || !strings.Contains(err.Error(), ErrStoreMsgNotFound.Error()) { + return fmt.Errorf("unable to delete session %q record at sequence %v: %v", id, seq, err) + } } } return nil diff --git a/server/mqtt_test.go b/server/mqtt_test.go index c8a0de3fdaa..fb208249232 100644 --- a/server/mqtt_test.go +++ b/server/mqtt_test.go @@ -2992,6 +2992,58 @@ func TestMQTTCluster(t *testing.T) { } } +func testMQTTConnectDisconnect(t *testing.T, o *Options, clientID string, clean bool, found bool) { + t.Helper() + start := time.Now() + mc, r := testMQTTConnect(t, &mqttConnInfo{clientID: clientID, cleanSess: clean}, o.MQTT.Host, o.MQTT.Port) + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, found) + testMQTTDisconnectEx(t, mc, nil, false) + mc.Close() + if clean { + t.Logf("OK with server %v:%v, elapsed %v -- clean", o.MQTT.Host, o.MQTT.Port, time.Since(start)) + } else { + t.Logf("OK with server %v:%v, elapsed %v", o.MQTT.Host, o.MQTT.Port, time.Since(start)) + } +} + +func TestMQTTClusterConnectDisconnectClean(t *testing.T) { + nServers := 3 + cl := createJetStreamClusterWithTemplate(t, testMQTTGetClusterTemplaceNoLeaf(), "MQTT", nServers) + defer cl.shutdown() + + clientID := nuid.Next() + + // test runs a connect/disconnect against a random server in the cluster, as + // specified. + N := 100 + for n := 0; n < N; n++ { + testMQTTConnectDisconnect(t, cl.opts[rand.Intn(nServers)], clientID, true, false) + } +} + +func TestMQTTClusterConnectDisconnectPersist(t *testing.T) { + nServers := 3 + cl := createJetStreamClusterWithTemplate(t, testMQTTGetClusterTemplaceNoLeaf(), "MQTT", nServers) + defer cl.shutdown() + + clientID := nuid.Next() + + // test runs a connect/disconnect against a random server in the cluster, as + // specified. + N := 20 + for n := 0; n < N; n++ { + // First clean sessions on all servers + for i := 0; i < nServers; i++ { + testMQTTConnectDisconnect(t, cl.opts[i], clientID, true, false) + } + + testMQTTConnectDisconnect(t, cl.opts[0], clientID, false, false) + testMQTTConnectDisconnect(t, cl.opts[1], clientID, false, true) + testMQTTConnectDisconnect(t, cl.opts[2], clientID, false, true) + testMQTTConnectDisconnect(t, cl.opts[0], clientID, false, true) + } +} + func TestMQTTClusterRetainedMsg(t *testing.T) { cl := createJetStreamClusterWithTemplate(t, testMQTTGetClusterTemplaceNoLeaf(), "MQTT", 2) defer cl.shutdown() @@ -3859,6 +3911,11 @@ func TestMQTTPublishTopicErrors(t *testing.T) { } func testMQTTDisconnect(t testing.TB, c net.Conn, bw *bufio.Writer) { + t.Helper() + testMQTTDisconnectEx(t, c, bw, true) +} + +func testMQTTDisconnectEx(t testing.TB, c net.Conn, bw *bufio.Writer, wait bool) { t.Helper() w := &mqttWriter{} w.WriteByte(mqttPacketDisconnect) @@ -3869,7 +3926,9 @@ func testMQTTDisconnect(t testing.TB, c net.Conn, bw *bufio.Writer) { } else { c.Write(w.Bytes()) } - testMQTTExpectDisconnect(t, c) + if wait { + testMQTTExpectDisconnect(t, c) + } } func TestMQTTWill(t *testing.T) {