diff --git a/changelog.d/5-internal/mls-robust-consume b/changelog.d/5-internal/mls-robust-consume new file mode 100644 index 00000000000..b57f1e773e0 --- /dev/null +++ b/changelog.d/5-internal/mls-robust-consume @@ -0,0 +1 @@ +More robust consuming of MLS messages: the behaviour of `sendAndConsumeMessage` and `sendAndConsumeCommitBundle` is changed to actually wait for those messages on the client's websocket diff --git a/integration/test/API/Galley.hs b/integration/test/API/Galley.hs index 87e5042bc6d..66755cc8e70 100644 --- a/integration/test/API/Galley.hs +++ b/integration/test/API/Galley.hs @@ -1,4 +1,5 @@ {-# LANGUAGE OverloadedLabels #-} +{-# OPTIONS_GHC -Wno-ambiguous-fields #-} module API.Galley where diff --git a/integration/test/MLS/Util.hs b/integration/test/MLS/Util.hs index 94c2d520f21..68feacee6db 100644 --- a/integration/test/MLS/Util.hs +++ b/integration/test/MLS/Util.hs @@ -7,9 +7,11 @@ import API.Galley import Control.Concurrent.Async hiding (link) import Control.Monad import Control.Monad.Catch +import Control.Monad.Codensity import Control.Monad.Cont import Control.Monad.Reader import Control.Monad.Trans.Maybe +import Data.Aeson qualified as A import Data.Aeson qualified as Aeson import Data.ByteString qualified as BS import Data.ByteString.Base64 qualified as Base64 @@ -26,6 +28,7 @@ import Data.Traversable import Data.UUID qualified as UUID import Data.UUID.V4 qualified as UUIDV4 import GHC.Stack +import Notifications import System.Directory import System.Exit import System.FilePath @@ -501,15 +504,75 @@ createExternalCommit cid mgi = do groupInfo = Just newPgs } --- | Make all member clients consume a given message. -consumeMessage :: HasCallStack => MessagePackage -> App () -consumeMessage msg = do +data MLSNotificationTag = MLSNotificationMessageTag | MLSNotificationWelcomeTag + deriving (Show, Eq, Ord) + +-- | Extract a conversation ID (including an optional subconversation) from an +-- event object. +eventSubConv :: HasCallStack => MakesValue event => event -> App Value +eventSubConv event = do + sub <- lookupField event "subconv" + conv <- event %. "qualified_conversation" + objSubConvObject $ + object + [ "parent_qualified_id" .= conv, + "subconv_id" .= sub + ] + +consumingMessages :: HasCallStack => MessagePackage -> Codensity App () +consumingMessages mp = Codensity $ \k -> do mls <- getMLSState - for_ (Set.delete msg.sender mls.members) $ \cid -> - consumeMessage1 cid msg.message + -- clients that should receive the message itself + let oldClients = Set.delete mp.sender mls.members + -- clients that should receive a welcome message + let newClients = Set.delete mp.sender mls.newMembers + -- all clients that should receive some MLS notification, together with the + -- expected notification tag + let clients = + map (,MLSNotificationMessageTag) (toList oldClients) + <> map (,MLSNotificationWelcomeTag) (toList newClients) + + let newUsers = + Set.delete mp.sender.user $ + Set.difference + (Set.map (.user) newClients) + (Set.map (.user) oldClients) + withWebSockets (map fst clients) $ \wss -> do + r <- k () + + -- if the conversation is actually MLS (and not mixed), pick one client for + -- each new user and wait for its join event + when (mls.protocol == MLSProtocolMLS) $ + traverse_ + (awaitMatch 10 isMemberJoinNotif) + ( flip Map.restrictKeys newUsers + . Map.mapKeys ((.user) . fst) + . Map.fromList + . toList + $ zip clients wss + ) + + -- at this point we know that every new user has been added to the + -- conversation + for_ (zip clients wss) $ \((cid, t), ws) -> case t of + MLSNotificationMessageTag -> void $ consumeMessage cid (Just mp) ws + MLSNotificationWelcomeTag -> consumeWelcome cid mp ws + pure r + +-- | Get a single MLS message from a websocket and consume it. Return a JSON +-- representation of the message. +consumeMessage :: HasCallStack => ClientIdentity -> Maybe MessagePackage -> WebSocket -> App Value +consumeMessage cid mmp ws = do + mls <- getMLSState + notif <- awaitMatch 10 isNewMLSMessageNotif ws + event <- notif %. "payload.0" + + for_ mmp $ \mp -> do + shouldMatch (eventSubConv event) (fromMaybe A.Null mls.convId) + shouldMatch (event %. "from") mp.sender.user + shouldMatch (event %. "data") (B8.unpack (Base64.encode mp.message)) -consumeMessage1 :: HasCallStack => ClientIdentity -> ByteString -> App () -consumeMessage1 cid msg = + msgData <- event %. "data" & asByteString void $ mlscli cid @@ -520,52 +583,72 @@ consumeMessage1 cid msg = "", "-" ] - (Just msg) + (Just msgData) + showMessage cid msgData --- | Send an MLS message and simulate clients receiving it. If the message is a --- commit, the 'sendAndConsumeCommit' function should be used instead. +-- | Send an MLS message, wait for clients to receive it, then consume it on +-- the client side. If the message is a commit, the +-- 'sendAndConsumeCommitBundle' function should be used instead. sendAndConsumeMessage :: HasCallStack => MessagePackage -> App Value -sendAndConsumeMessage mp = do - r <- postMLSMessage mp.sender mp.message >>= getJSON 201 - consumeMessage mp - pure r +sendAndConsumeMessage mp = lowerCodensity $ do + consumingMessages mp + lift $ postMLSMessage mp.sender mp.message >>= getJSON 201 --- | Send an MLS commit bundle, simulate clients receiving it, and update the --- test state accordingly. +-- | Send an MLS commit bundle, wait for clients to receive it, consume it, and +-- update the test state accordingly. sendAndConsumeCommitBundle :: HasCallStack => MessagePackage -> App Value sendAndConsumeCommitBundle mp = do - resp <- postMLSCommitBundle mp.sender (mkBundle mp) >>= getJSON 201 - consumeMessage mp - traverse_ consumeWelcome mp.welcome - - -- increment epoch and add new clients - modifyMLSState $ \mls -> - mls - { epoch = epoch mls + 1, - members = members mls <> newMembers mls, - newMembers = mempty - } + lowerCodensity $ do + consumingMessages mp + lift $ do + r <- postMLSCommitBundle mp.sender (mkBundle mp) >>= getJSON 201 + + -- if the sender is a new member (i.e. it's an external commit), then + -- process the welcome message directly + do + mls <- getMLSState + when (Set.member mp.sender mls.newMembers) $ + traverse_ (fromWelcome mp.sender) mp.welcome + + -- increment epoch and add new clients + modifyMLSState $ \mls -> + mls + { epoch = epoch mls + 1, + members = members mls <> newMembers mls, + newMembers = mempty + } - pure resp + pure r -consumeWelcome :: HasCallStack => ByteString -> App () -consumeWelcome welcome = do +consumeWelcome :: HasCallStack => ClientIdentity -> MessagePackage -> WebSocket -> App () +consumeWelcome cid mp ws = do mls <- getMLSState - for_ mls.newMembers $ \cid -> do - gs <- getClientGroupState cid - assertBool - "Existing clients in a conversation should not consume welcomes" - (isNothing gs.group) - void $ - mlscli - cid - [ "group", - "from-welcome", - "--group-out", - "", - "-" - ] - (Just welcome) + notif <- awaitMatch 10 isWelcomeNotif ws + event <- notif %. "payload.0" + + shouldMatch (eventSubConv event) (fromMaybe A.Null mls.convId) + shouldMatch (event %. "from") mp.sender.user + shouldMatch (event %. "data") (fmap (B8.unpack . Base64.encode) mp.welcome) + + welcome <- event %. "data" & asByteString + gs <- getClientGroupState cid + assertBool + "Existing clients in a conversation should not consume welcomes" + (isNothing gs.group) + fromWelcome cid welcome + +fromWelcome :: ClientIdentity -> ByteString -> App () +fromWelcome cid welcome = + void $ + mlscli + cid + [ "group", + "from-welcome", + "--group-out", + "", + "-" + ] + (Just welcome) readWelcome :: FilePath -> IO (Maybe ByteString) readWelcome fp = runMaybeT $ do diff --git a/integration/test/Notifications.hs b/integration/test/Notifications.hs index d584407e89e..58edd2ec733 100644 --- a/integration/test/Notifications.hs +++ b/integration/test/Notifications.hs @@ -69,6 +69,9 @@ isNewMessageNotif n = fieldEquals n "payload.0.type" "conversation.otr-message-a isNewMLSMessageNotif :: MakesValue a => a -> App Bool isNewMLSMessageNotif n = fieldEquals n "payload.0.type" "conversation.mls-message-add" +isWelcomeNotif :: MakesValue a => a -> App Bool +isWelcomeNotif n = fieldEquals n "payload.0.type" "conversation.mls-welcome" + isMemberJoinNotif :: MakesValue a => a -> App Bool isMemberJoinNotif n = fieldEquals n "payload.0.type" "conversation.member-join" diff --git a/integration/test/SetupHelpers.hs b/integration/test/SetupHelpers.hs index 26f6db76e18..95694cfaf92 100644 --- a/integration/test/SetupHelpers.hs +++ b/integration/test/SetupHelpers.hs @@ -122,6 +122,8 @@ simpleMixedConversationSetup secondDomain = do bindResponse (putConversationProtocol bob conv "mixed") $ \resp -> do resp.status `shouldMatchInt` 200 + modifyMLSState $ \mls -> mls {protocol = MLSProtocolMixed} + conv' <- getConversation alice conv >>= getJSON 200 pure (alice, bob, conv') diff --git a/integration/test/Test/MLS.hs b/integration/test/Test/MLS.hs index 5a7cc4a9163..f9c966f78e9 100644 --- a/integration/test/Test/MLS.hs +++ b/integration/test/Test/MLS.hs @@ -1,4 +1,4 @@ -{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-} +{-# OPTIONS_GHC -Wno-incomplete-uni-patterns -Wno-ambiguous-fields #-} module Test.MLS where @@ -90,6 +90,7 @@ testMixedProtocolUpgrade secondDomain = do resp.status `shouldMatchInt` 200 resp.json %. "conversation" `shouldMatch` (qcnv %. "id") resp.json %. "data.protocol" `shouldMatch` "mixed" + modifyMLSState $ \mls -> mls {protocol = MLSProtocolMixed} for_ websockets $ \ws -> do n <- awaitMatch 3 (\value -> nPayload value %. "type" `isEqual` "conversation.protocol-update") ws @@ -130,6 +131,7 @@ testMixedProtocolAddUsers secondDomain = do bindResponse (putConversationProtocol bob qcnv "mixed") $ \resp -> do resp.status `shouldMatchInt` 200 + modifyMLSState $ \mls -> mls {protocol = MLSProtocolMixed} [alice1, bob1] <- traverse (createMLSClient def) [alice, bob] @@ -158,6 +160,7 @@ testMixedProtocolUserLeaves secondDomain = do bindResponse (putConversationProtocol bob qcnv "mixed") $ \resp -> do resp.status `shouldMatchInt` 200 + modifyMLSState $ \mls -> mls {protocol = MLSProtocolMixed} [alice1, bob1] <- traverse (createMLSClient def) [alice, bob] @@ -193,6 +196,7 @@ testMixedProtocolAddPartialClients secondDomain = do bindResponse (putConversationProtocol bob qcnv "mixed") $ \resp -> do resp.status `shouldMatchInt` 200 + modifyMLSState $ \mls -> mls {protocol = MLSProtocolMixed} [alice1, bob1, bob2] <- traverse (createMLSClient def) [alice, bob, bob] @@ -231,6 +235,7 @@ testMixedProtocolRemovePartialClients secondDomain = do bindResponse (putConversationProtocol bob qcnv "mixed") $ \resp -> do resp.status `shouldMatchInt` 200 + modifyMLSState $ \mls -> mls {protocol = MLSProtocolMixed} [alice1, bob1, bob2] <- traverse (createMLSClient def) [alice, bob, bob] @@ -256,6 +261,7 @@ testMixedProtocolAppMessagesAreDenied secondDomain = do bindResponse (putConversationProtocol bob qcnv "mixed") $ \resp -> do resp.status `shouldMatchInt` 200 + modifyMLSState $ \mls -> mls {protocol = MLSProtocolMixed} [alice1, bob1] <- traverse (createMLSClient def) [alice, bob] @@ -302,6 +308,7 @@ testMLSProtocolUpgrade secondDomain = do withWebSockets [alice1, bob1] $ \wss -> do bindResponse (putConversationProtocol bob conv "mls") $ \resp -> do resp.status `shouldMatchInt` 200 + modifyMLSState $ \mls -> mls {protocol = MLSProtocolMLS} for_ wss $ \ws -> do n <- awaitMatch 3 isNewMLSMessageNotif ws msg <- asByteString (nPayload n %. "data") >>= showMessage alice1 @@ -399,47 +406,17 @@ testCreateSubConvProteus = do bindResponse (getSubConversation alice conv "conference") $ \resp -> resp.status `shouldMatchInt` 404 --- FUTUREWORK: New clients should be adding themselves via external commits, and --- they shouldn't be added by another client. Change the test so external --- commits are used. testSelfConversation :: App () testSelfConversation = do alice <- randomUser OwnDomain def creator : others <- traverse (createMLSClient def) (replicate 3 alice) traverse_ uploadNewKeyPackage others - (_, cnv) <- createSelfGroup creator - commit <- createAddCommit creator [alice] - welcome <- assertOne (toList commit.welcome) + void $ createSelfGroup creator + void $ createAddCommit creator [alice] >>= sendAndConsumeCommitBundle - withWebSockets others $ \wss -> do - void $ sendAndConsumeCommitBundle commit - let isWelcome n = nPayload n %. "type" `isEqual` "conversation.mls-welcome" - for_ wss $ \ws -> do - n <- awaitMatch 3 isWelcome ws - shouldMatch (nPayload n %. "conversation") (objId cnv) - shouldMatch (nPayload n %. "from") (objId alice) - shouldMatch (nPayload n %. "data") (B8.unpack (Base64.encode welcome)) - -testJoinSubConv :: App () -testJoinSubConv = do - [alice, bob] <- createAndConnectUsers [OwnDomain, OwnDomain] - [alice1, bob1, bob2] <- traverse (createMLSClient def) [alice, bob, bob] - traverse_ uploadNewKeyPackage [bob1, bob2] - (_, qcnv) <- createNewGroup alice1 - void $ createAddCommit alice1 [bob] >>= sendAndConsumeCommitBundle - void $ createSubConv bob1 "conference" - - -- bob adds his first client to the subconversation - void $ createPendingProposalCommit bob1 >>= sendAndConsumeCommitBundle - sub' <- getSubConversation bob qcnv "conference" >>= getJSON 200 - do - tm <- sub' %. "epoch_timestamp" - assertBool "Epoch timestamp should not be null" (tm /= Null) - - -- now alice joins with her own client - void $ - createExternalCommit alice1 Nothing - >>= sendAndConsumeCommitBundle + newClient <- createMLSClient def alice + void $ uploadNewKeyPackage newClient + void $ createExternalCommit newClient Nothing >>= sendAndConsumeCommitBundle -- | FUTUREWORK: Don't allow partial adds, not even in the first commit testFirstCommitAllowsPartialAdds :: HasCallStack => App () @@ -505,6 +482,7 @@ testAdminRemovesUserFromConv :: HasCallStack => App () testAdminRemovesUserFromConv = do [alice, bob] <- createAndConnectUsers [OwnDomain, OwnDomain] [alice1, bob1, bob2] <- traverse (createMLSClient def) [alice, bob, bob] + void $ createWireClient bob traverse_ uploadNewKeyPackage [bob1, bob2] (gid, qcnv) <- createNewGroup alice1 @@ -520,15 +498,16 @@ testAdminRemovesUserFromConv = do bobQid <- bob %. "qualified_id" shouldMatch members [bobQid] - convs <- getAllConvs bob - convIds <- traverse (%. "qualified_id") convs - clients <- bindResponse (getGroupClients alice gid) $ \resp -> do - resp.status `shouldMatchInt` 200 - resp.json %. "client_ids" & asList - void $ assertOne clients - assertBool - "bob is not longer part of conversation after the commit" - (qcnv `notElem` convIds) + do + convs <- getAllConvs bob + convIds <- traverse (%. "qualified_id") convs + clients <- bindResponse (getGroupClients alice gid) $ \resp -> do + resp.status `shouldMatchInt` 200 + resp.json %. "client_ids" & asList + void $ assertOne clients + assertBool + "bob is not longer part of conversation after the commit" + (qcnv `notElem` convIds) testLocalWelcome :: HasCallStack => App () testLocalWelcome = do diff --git a/integration/test/Test/MLS/Message.hs b/integration/test/Test/MLS/Message.hs index e36115ab934..762278f8de7 100644 --- a/integration/test/Test/MLS/Message.hs +++ b/integration/test/Test/MLS/Message.hs @@ -19,6 +19,7 @@ module Test.MLS.Message where +import API.Galley import API.Gundeck import MLS.Util import Notifications @@ -69,7 +70,10 @@ testAppMessageSomeReachable = do awaitMatch 10 isMemberJoinNotif ws pure alice1 - void $ createApplicationMessage alice1 "hi, bob!" >>= sendAndConsumeMessage + -- charlie isn't able to receive this message, so we make sure we can post it + -- successfully, but not attempt to consume it + mp <- createApplicationMessage alice1 "hi, bob!" + void $ postMLSMessage mp.sender mp.message >>= getJSON 201 testMessageNotifications :: HasCallStack => Domain -> App () testMessageNotifications bobDomain = do diff --git a/integration/test/Test/MLS/SubConversation.hs b/integration/test/Test/MLS/SubConversation.hs index ed5aa95c3d4..4ce961bab03 100644 --- a/integration/test/Test/MLS/SubConversation.hs +++ b/integration/test/Test/MLS/SubConversation.hs @@ -12,11 +12,11 @@ testJoinSubConv = do [alice1, bob1, bob2] <- traverse (createMLSClient def) [alice, bob, bob] traverse_ uploadNewKeyPackage [bob1, bob2] (_, qcnv) <- createNewGroup alice1 + void $ createAddCommit alice1 [bob] >>= sendAndConsumeCommitBundle createSubConv bob1 "conference" -- bob adds his first client to the subconversation - void $ createPendingProposalCommit bob1 >>= sendAndConsumeCommitBundle sub' <- getSubConversation bob qcnv "conference" >>= getJSON 200 do tm <- sub' %. "epoch_timestamp" @@ -36,13 +36,10 @@ testDeleteParentOfSubConv secondDomain = do [alice1, bob1] <- traverse (createMLSClient def) [alice, bob] traverse_ uploadNewKeyPackage [alice1, bob1] (_, qcnv) <- createNewGroup alice1 - withWebSocket bob $ \ws -> do - void $ createAddCommit alice1 [bob] >>= sendAndConsumeCommitBundle - void $ awaitMatch 10 isMemberJoinNotif ws + void $ createAddCommit alice1 [bob] >>= sendAndConsumeCommitBundle -- bob creates a subconversation and adds his own client createSubConv bob1 "conference" - void $ createPendingProposalCommit bob1 >>= sendAndConsumeCommitBundle -- alice joins with her own client void $ createExternalCommit alice1 Nothing >>= sendAndConsumeCommitBundle @@ -136,12 +133,9 @@ testLeaveSubConv variant = do leaveCurrentConv firstLeaver for_ (zip others wss) $ \(cid, ws) -> do - notif <- awaitMatch 10 isNewMLSMessageNotif ws - msgData <- notif %. "payload.0.data" & asByteString - msg <- showMessage alice1 msgData + msg <- consumeMessage cid Nothing ws msg %. "message.content.body.Proposal.Remove.removed" `shouldMatchInt` idxFirstLeaver msg %. "message.content.sender.External" `shouldMatchInt` 0 - consumeMessage1 cid msgData withWebSockets (tail others) $ \wss -> do -- a member commits the pending proposal @@ -164,12 +158,9 @@ testLeaveSubConv variant = do leaveCurrentConv charlie1 for_ (zip others' wss) $ \(cid, ws) -> do - notif <- awaitMatch 10 isNewMLSMessageNotif ws - msgData <- notif %. "payload.0.data" & asByteString - msg <- showMessage alice1 msgData + msg <- consumeMessage cid Nothing ws msg %. "message.content.body.Proposal.Remove.removed" `shouldMatchInt` idxCharlie1 msg %. "message.content.sender.External" `shouldMatchInt` 0 - consumeMessage1 cid msgData -- a member commits the pending proposal void $ createPendingProposalCommit (head others') >>= sendAndConsumeCommitBundle diff --git a/integration/test/Testlib/Env.hs b/integration/test/Testlib/Env.hs index 40fd56adc8a..4a9b680be80 100644 --- a/integration/test/Testlib/Env.hs +++ b/integration/test/Testlib/Env.hs @@ -136,5 +136,6 @@ mkMLSState = Codensity $ \k -> convId = Nothing, clientGroupState = mempty, epoch = 0, - ciphersuite = def + ciphersuite = def, + protocol = MLSProtocolMLS } diff --git a/integration/test/Testlib/JSON.hs b/integration/test/Testlib/JSON.hs index 5aeba81073e..debb96fbac7 100644 --- a/integration/test/Testlib/JSON.hs +++ b/integration/test/Testlib/JSON.hs @@ -3,6 +3,7 @@ module Testlib.JSON where import Control.Monad import Control.Monad.Catch import Control.Monad.IO.Class +import Control.Monad.Trans.Class import Control.Monad.Trans.Maybe import Data.Aeson hiding ((.=)) import Data.Aeson qualified as Aeson @@ -17,6 +18,7 @@ import Data.Foldable import Data.Function import Data.Functor import Data.List.Split (splitOn) +import Data.Maybe import Data.Scientific qualified as Sci import Data.String import Data.Text qualified as T @@ -69,6 +71,14 @@ noValue = Nothing (.=?) :: ToJSON a => String -> Maybe a -> Maybe Aeson.Pair (.=?) k v = (Aeson..=) (fromString k) <$> v +-- | Convert JSON null to Nothing. +asOptional :: HasCallStack => MakesValue a => a -> App (Maybe Value) +asOptional x = do + v <- make x + pure $ case v of + Null -> Nothing + _ -> Just v + asString :: HasCallStack => MakesValue a => a -> App String asString x = make x >>= \case @@ -360,17 +370,14 @@ objDomain x = do -- is also supported. objSubConv :: (HasCallStack, MakesValue a) => a -> App (Value, Maybe String) objSubConv x = do - mParent <- lookupField x "parent_qualified_id" - case mParent of - Nothing -> do - obj <- objQidObject x - subValue <- lookupField x "subconv_id" - sub <- traverse asString subValue - pure (obj, sub) - Just parent -> do - obj <- objQidObject parent - sub <- x %. "subconv_id" & asString - pure (obj, Just sub) + v <- make x + mParent <- lookupField v "parent_qualified_id" + obj <- objQidObject $ fromMaybe v mParent + sub <- runMaybeT $ do + sub <- MaybeT $ lookupField v "subconv_id" + sub' <- MaybeT $ asOptional sub + lift $ asString sub' + pure (obj, sub) -- | Turn an object parseable by 'objSubConv' into a canonical flat representation. objSubConvObject :: (HasCallStack, MakesValue a) => a -> App Value diff --git a/integration/test/Testlib/Types.hs b/integration/test/Testlib/Types.hs index ca5ac7043f5..557c5d327d4 100644 --- a/integration/test/Testlib/Types.hs +++ b/integration/test/Testlib/Types.hs @@ -215,6 +215,9 @@ data ClientGroupState = ClientGroupState } deriving (Show) +data MLSProtocol = MLSProtocolMLS | MLSProtocolMixed + deriving (Eq, Show) + data MLSState = MLSState { baseDir :: FilePath, members :: Set ClientIdentity, @@ -224,7 +227,8 @@ data MLSState = MLSState convId :: Maybe Value, clientGroupState :: Map ClientIdentity ClientGroupState, epoch :: Word64, - ciphersuite :: Ciphersuite + ciphersuite :: Ciphersuite, + protocol :: MLSProtocol } deriving (Show) @@ -290,7 +294,7 @@ appToIOKleisli k = do getServiceMap :: HasCallStack => String -> App ServiceMap getServiceMap fedDomain = do env <- ask - assertJust ("Could not find service map for federation domain: " <> fedDomain) (Map.lookup fedDomain (env.serviceMap)) + assertJust ("Could not find service map for federation domain: " <> fedDomain) (Map.lookup fedDomain env.serviceMap) getMLSState :: App MLSState getMLSState = do