Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix message acks on wrong rabbitmq channels #4358

Merged
merged 15 commits into from
Dec 11, 2024
Merged
1 change: 1 addition & 0 deletions changelog.d/3-bug-fixes/rabbitmq-acks
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cannon no longer attempts to restore a RabbitMQ channel after that channel has been closed. This fixes a potential issue where a client could ack a message on the wrong channel.
3 changes: 2 additions & 1 deletion charts/integration/templates/configmap.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ data:

rabbitmq:
host: rabbitmq
adminPort: 15671
port: 15671
adminPort: 15672
pcapriotti marked this conversation as resolved.
Show resolved Hide resolved

backendTwo:

Expand Down
113 changes: 97 additions & 16 deletions integration/test/Test/Events.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@ import API.Galley
import API.Gundeck
import qualified Control.Concurrent.Timeout as Timeout
import Control.Monad.Codensity
import Control.Monad.RWS (asks)
import Control.Monad.Trans.Class
import Control.Retry
import Data.ByteString.Char8 as B8
import Data.ByteString.Conversion (toByteString')
import qualified Data.Text as Text
import Data.Timeout
import qualified Network.HTTP.Client as HTTP
import qualified Network.WebSockets as WS
import Notifications
import SetupHelpers
import System.Environment (getEnv)
import Testlib.Prelude hiding (assertNoEvent)
import Testlib.ResourcePool (backendA)
import UnliftIO hiding (handle)

testConsumeEventsOneWebSocket :: (HasCallStack) => App ()
Expand All @@ -38,10 +43,10 @@ testConsumeEventsOneWebSocket = do
e %. "data.event.payload.0.type" `shouldMatch` "user.client-add"
e %. "data.event.payload.0.client.id" `shouldMatch` clientId
e %. "data.delivery_tag"
assertNoEvent ws
assertNoEvent_ ws

sendAck ws deliveryTag False
assertNoEvent ws
assertNoEvent_ ws

handle <- randomHandle
putHandle alice handle >>= assertSuccess
Expand Down Expand Up @@ -80,7 +85,7 @@ testConsumeEventsForDifferentUsers = do
e %. "data.event.payload.0.type" `shouldMatch` "user.client-add"
e %. "data.event.payload.0.client.id" `shouldMatch` clientId
e %. "data.delivery_tag"
assertNoEvent ws
assertNoEvent_ ws
sendAck ws deliveryTag False

testConsumeEventsWhileHavingLegacyClients :: (HasCallStack) => App ()
Expand Down Expand Up @@ -137,7 +142,7 @@ testConsumeEventsAcks = do
sendAck ws deliveryTag False

runCodensity (createEventsWebSocket alice clientId) $ \ws -> do
assertNoEvent ws
assertNoEvent_ ws

testConsumeEventsMultipleAcks :: (HasCallStack) => App ()
testConsumeEventsMultipleAcks = do
Expand All @@ -161,7 +166,7 @@ testConsumeEventsMultipleAcks = do
sendAck ws deliveryTag True

runCodensity (createEventsWebSocket alice clientId) $ \ws -> do
assertNoEvent ws
assertNoEvent_ ws

testConsumeEventsAckNewEventWithoutAckingOldOne :: (HasCallStack) => App ()
testConsumeEventsAckNewEventWithoutAckingOldOne = do
Expand Down Expand Up @@ -195,7 +200,7 @@ testConsumeEventsAckNewEventWithoutAckingOldOne = do
sendAck ws deliveryTagClientAdd False

runCodensity (createEventsWebSocket alice clientId) $ \ws -> do
assertNoEvent ws
assertNoEvent_ ws

testEventsDeadLettered :: (HasCallStack) => App ()
testEventsDeadLettered = do
Expand Down Expand Up @@ -229,7 +234,7 @@ testEventsDeadLettered = do
ackEvent ws e

-- We've consumed the whole queue.
assertNoEvent ws
assertNoEvent_ ws

testTransientEventsDoNotTriggerDeadLetters :: (HasCallStack) => App ()
testTransientEventsDoNotTriggerDeadLetters = do
Expand Down Expand Up @@ -257,7 +262,7 @@ testTransientEventsDoNotTriggerDeadLetters = do
sendTypingStatus alice selfConvId "started" >>= assertSuccess

runCodensity (createEventsWebSocket alice clientId) $ \ws -> do
assertNoEvent ws
assertNoEvent_ ws

testTransientEvents :: (HasCallStack) => App ()
testTransientEvents = do
Expand Down Expand Up @@ -296,7 +301,7 @@ testTransientEvents = do
e %. "data.event.payload.0.user.handle" `shouldMatch` handle
ackEvent ws e

assertNoEvent ws
assertNoEvent_ ws

testChannelLimit :: (HasCallStack) => App ()
testChannelLimit = withModifiedBackend
Expand All @@ -318,16 +323,43 @@ testChannelLimit = withModifiedBackend
lowerCodensity $ do
for_ clients $ \c -> do
ws <- createEventsWebSocket alice c
e <- Codensity $ \k -> assertEvent ws k
lift $ do
lift $ assertEvent ws $ \e -> do
e %. "data.event.payload.0.type" `shouldMatch` "user.client-add"
e %. "data.event.payload.0.client.id" `shouldMatch` c
e %. "data.delivery_tag"

-- the first client fails to connect because the server runs out of channels
do
ws <- createEventsWebSocket alice client0
lift $ assertNoEvent ws
lift $ assertNoEvent_ ws

-- | Check that an events websocket is reported as dead once the underlying
-- RabbitMQ connection is killed.
--
-- The test registers two clients for the same user, opens an events websocket
-- for the first one, acks the first @user.client-add@ event and leaves the
-- second one un-acked. It then kills the RabbitMQ connection through
-- 'killConnection' and expects 'assertNoEvent' to return 'WebSocketDied'
-- rather than a redelivered event or a plain timeout.
testChannelKilled :: (HasCallStack) => App ()
testChannelKilled = do
  alice <- randomUser OwnDomain def
  [c1, c2] <-
    replicateM 2
      $ addClient alice def {acapabilities = Just ["consumable-notifications"]}
      >>= getJSON 201
      >>= (%. "id")
      >>= asString

  lowerCodensity $ do
    ws <- createEventsWebSocket alice c1
    lift $ do
      -- First event: creation of c1 itself; acknowledge it.
      assertEvent ws $ \e -> do
        e %. "data.event.payload.0.type" `shouldMatch` "user.client-add"
        e %. "data.event.payload.0.client.id" `shouldMatch` c1
        ackEvent ws e

      -- Second event: creation of c2; deliberately left un-acked.
      assertEvent ws $ \e -> do
        e %. "data.event.payload.0.type" `shouldMatch` "user.client-add"
        e %. "data.event.payload.0.client.id" `shouldMatch` c2

      -- Killing the connection is retried (up to 10 retries, 0.5s apart)
      -- whenever 'killConnection' throws.
      recoverAll
        (constantDelay 500_000 <> limitRetries 10)
        (const $ killConnection backendA)

      noEvent <- assertNoEvent ws
      noEvent `shouldMatch` WebSocketDied

----------------------------------------------------------------------
-- helpers
Expand Down Expand Up @@ -422,15 +454,24 @@ assertEvent ws expectations = do
addFailureContext ("event:\n" <> pretty)
$ expectations e

-- | The two ways in which waiting on an events websocket can end without an
-- event being delivered.
data NoEvent
  = -- | Nothing arrived on the channel before the timeout expired.
    NoEvent
  | -- | The channel yielded a 'Left' instead of an event.
    WebSocketDied

instance ToJSON NoEvent where
  toJSON NoEvent = toJSON "no-event"
  toJSON WebSocketDied = toJSON "web-socket-died"

-- | Wait up to one second (1_000_000 microseconds) for something to arrive on
-- the websocket's event channel and assert that it is not an event.
--
-- Returns 'NoEvent' when the wait times out and 'WebSocketDied' when the
-- channel yields a 'Left'. Fails the test, including the pretty-printed event
-- in the message, when an actual event is received.
assertNoEvent :: (HasCallStack) => EventWebSocket -> App NoEvent
assertNoEvent ws = do
  timeout 1_000_000 (readChan ws.events) >>= \case
    Nothing -> pure NoEvent
    Just (Left _) -> pure WebSocketDied
    Just (Right e) -> do
      eventJSON <- prettyJSON e
      assertFailure $ "Did not expect event: \n" <> eventJSON

-- | Variant of 'assertNoEvent' for callers that do not care which kind of
-- non-event was observed: the result is discarded.
assertNoEvent_ :: (HasCallStack) => EventWebSocket -> App ()
assertNoEvent_ ws = void (assertNoEvent ws)

consumeAllEvents :: EventWebSocket -> App ()
consumeAllEvents ws = do
timeout 1_000_000 (readChan ws.events) >>= \case
Expand All @@ -442,3 +483,43 @@ consumeAllEvents ws = do
Just (Right e) -> do
ackEvent ws e
consumeAllEvents ws

-- | Close the RabbitMQ connection named @"pool 0"@ on the given backend's
-- vhost via the RabbitMQ management HTTP API.
--
-- The API is reached over plain HTTP at the configured @host@ and
-- @adminPort@. Credentials are read from the @RABBITMQ_USERNAME@ and
-- @RABBITMQ_PASSWORD@ environment variables; 'getEnv' throws if either one is
-- unset.
--
-- The function first lists all connections (expecting status 200), selects
-- the connection whose @user_provided_name@ is @"pool 0"@ and whose @vhost@
-- matches @backend.berVHost@, and then issues a @DELETE@ for that connection
-- (expecting status 204). The selection goes through 'assertOne', so the
-- call presumably fails unless exactly one connection matches.
killConnection :: BackendResource -> App ()
killConnection backend = do
  rabbitMqConfig <- asks (.rabbitMQConfig)
  let url = "http://" <> rabbitMqConfig.host <> ":" <> show rabbitMqConfig.adminPort <> "/api/connections/"
  userName <- liftIO $ getEnv "RABBITMQ_USERNAME"
  password <- liftIO $ getEnv "RABBITMQ_PASSWORD"
  -- Both requests below authenticate in the same way.
  let withAuth = HTTP.applyBasicAuth (B8.pack userName) (B8.pack password)

  name <- do
    req <- liftIO $ HTTP.parseRequest url
    bindResponse (submit "GET" $ withAuth req) $ \resp -> do
      resp.status `shouldMatchInt` 200
      connections <- asList resp.json
      connection <-
        assertOne
          =<< filterM
            ( \c -> do
                providedName <- traverse asString =<< lookupField c "user_provided_name"
                vhost <- c %. "vhost" & asString
                -- We assume that there is only one connection, which is why we use "pool 0"
                pure $ providedName == Just "pool 0" && vhost == backend.berVHost
            )
            connections
      connection %. "name" & asString

  do
    req <- liftIO $ HTTP.parseRequest (url <> name)
    bindResponse (submit "DELETE" $ withAuth req) $ \resp -> do
      resp.status `shouldMatchInt` 204
2 changes: 1 addition & 1 deletion integration/test/Testlib/ResourcePool.hs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ deleteAllRabbitMQQueues rc resource = do
RabbitMqAdminOpts
{ host = rc.host,
port = 0,
adminPort = fromIntegral rc.adminPort,
adminPort = fromIntegral rc.port,
vHost = T.pack resource.berVHost,
tls = Just $ RabbitMqTlsOpts Nothing True
}
Expand Down
2 changes: 2 additions & 0 deletions integration/test/Testlib/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ instance FromJSON DynamicBackendConfig

-- RabbitMQ settings used by the integration test suite.
data RabbitMQConfig = RabbitMQConfig
-- Host name of the RabbitMQ server.
{ host :: String,
-- Port that deleteAllRabbitMQQueues passes on as the admin port of its
-- TLS-enabled admin client.
-- NOTE(review): despite the generic name this looks like the TLS management
-- port rather than the AMQP port; confirm.
port :: Word16,
-- Port that killConnection uses to reach the management API over plain HTTP.
adminPort :: Word16
}
deriving (Show)
Expand All @@ -99,6 +100,7 @@ instance FromJSON RabbitMQConfig where
withObject "RabbitMQConfig" $ \ob ->
RabbitMQConfig
<$> ob .: fromString "host"
<*> ob .: fromString "port"
<*> ob .: fromString "adminPort"

-- | Initialised once per testsuite.
Expand Down
17 changes: 12 additions & 5 deletions services/cannon/src/Cannon/RabbitMq.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import Control.Retry
import Data.ByteString.Conversion
import Data.List.Extra
import Data.Map qualified as Map
import Data.Text qualified as T
import Data.Timeout
import Imports hiding (threadDelay)
import Network.AMQP qualified as Q
Expand Down Expand Up @@ -59,7 +60,8 @@ data RabbitMqPool key = RabbitMqPool
-- | Static configuration of a 'RabbitMqPool'.
data RabbitMqPoolOptions = RabbitMqPoolOptions
  { -- | Presumably the upper bound on simultaneously open connections in the
    -- pool; confirm against the pool implementation.
    maxConnections :: Int,
    -- | Presumably the upper bound on channels; whether it applies per
    -- connection or to the whole pool cannot be told from here.
    maxChannels :: Int,
    -- | AMQP endpoint the pool connects to; its @vHost@ is used when opening
    -- connections.
    endpoint :: AmqpEndpoint,
    -- | Whether a channel is re-created after it has been closed. When
    -- 'False', a closed channel is not restored.
    retryEnabled :: Bool
  }

createRabbitMqPool :: (Ord key) => RabbitMqPoolOptions -> Logger -> Codensity IO (RabbitMqPool key)
Expand Down Expand Up @@ -176,6 +178,7 @@ createConnection pool = mask_ $ do

openConnection :: RabbitMqPool key -> IO Q.Connection
openConnection pool = do
numConnections <- atomically $ length <$> readTVar pool.connections
(username, password) <- readCredsFromEnv
recovering
rabbitMqRetryPolicy
Expand All @@ -199,7 +202,9 @@ openConnection pool = do
],
Q.coVHost = pool.opts.endpoint.vHost,
Q.coAuth = [Q.plain username password],
Q.coTLSSettings = fmap Q.TLSCustom mTlsSettings
Q.coTLSSettings = fmap Q.TLSCustom mTlsSettings,
-- the name is used by tests to identify pool connections
Q.coName = Just ("pool " <> T.pack (show numConnections))
pcapriotti marked this conversation as resolved.
Show resolved Hide resolved
}
)

Expand Down Expand Up @@ -233,11 +238,11 @@ createChannel pool queue key = do
(_, Just (Q.ConnectionClosedException {})) -> do
Log.info pool.logger $
Log.msg (Log.val "RabbitMQ connection was closed unexpectedly")
pure True
pure pool.opts.retryEnabled
_ -> do
unless (fromException e == Just AsyncCancelled) $
logException pool.logger "RabbitMQ channel closed" e
pure True
pure pool.opts.retryEnabled
putMVar closedVar retry

let manageChannel = do
Expand All @@ -258,7 +263,9 @@ createChannel pool queue key = do
putMVar inner chan
void $ liftIO $ Q.consumeMsgs chan queue Q.Ack $ \(message, envelope) -> do
putMVar msgVar (Just (message, envelope))
takeMVar closedVar
retry <- takeMVar closedVar
void $ takeMVar inner
pure retry

when retry manageChannel

Expand Down
3 changes: 2 additions & 1 deletion services/cannon/src/Cannon/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ mkEnv external o cs l d conns p g t endpoint = do
RabbitMqPoolOptions
{ endpoint = endpoint,
maxConnections = o ^. rabbitMqMaxConnections,
maxChannels = o ^. rabbitMqMaxChannels
maxChannels = o ^. rabbitMqMaxChannels,
retryEnabled = False
}
pool <- createRabbitMqPool poolOpts l
let wsEnv =
Expand Down
3 changes: 2 additions & 1 deletion services/integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ dynamicBackends:

rabbitmq:
host: localhost
adminPort: 15671
port: 15671
adminPort: 15672

cassandra:
host: 127.0.0.1
Expand Down
Loading