Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consume MLS messages from websocket #3671

Merged
merged 4 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/5-internal/mls-robust-consume
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
More robust consuming of MLS messages: the behaviour of `sendAndConsumeMessage` and `sendAndConsumeCommitBundle` is changed to actually wait for those messages on the client's websocket
1 change: 1 addition & 0 deletions integration/test/API/Galley.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE OverloadedLabels #-}
{-# OPTIONS_GHC -Wno-ambiguous-fields #-}

module API.Galley where

Expand Down
173 changes: 128 additions & 45 deletions integration/test/MLS/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import API.Galley
import Control.Concurrent.Async hiding (link)
import Control.Monad
import Control.Monad.Catch
import Control.Monad.Codensity
import Control.Monad.Cont
import Control.Monad.Reader
import Control.Monad.Trans.Maybe
import Data.Aeson qualified as A
import Data.Aeson qualified as Aeson
import Data.ByteString qualified as BS
import Data.ByteString.Base64 qualified as Base64
Expand All @@ -26,6 +28,7 @@ import Data.Traversable
import Data.UUID qualified as UUID
import Data.UUID.V4 qualified as UUIDV4
import GHC.Stack
import Notifications
import System.Directory
import System.Exit
import System.FilePath
Expand Down Expand Up @@ -501,15 +504,75 @@ createExternalCommit cid mgi = do
groupInfo = Just newPgs
}

-- | Make all member clients consume a given message.
consumeMessage :: HasCallStack => MessagePackage -> App ()
consumeMessage msg = do
data MLSNotificationTag = MLSNotificationMessageTag | MLSNotificationWelcomeTag
deriving (Show, Eq, Ord)

-- | Extract a conversation ID (including an optional subconversation) from an
-- event object.
eventSubConv :: HasCallStack => MakesValue event => event -> App Value
eventSubConv event = do
sub <- lookupField event "subconv"
conv <- event %. "qualified_conversation"
objSubConvObject $
object
[ "parent_qualified_id" .= conv,
"subconv_id" .= sub
]

consumingMessages :: HasCallStack => MessagePackage -> Codensity App ()
consumingMessages mp = Codensity $ \k -> do
mls <- getMLSState
for_ (Set.delete msg.sender mls.members) $ \cid ->
consumeMessage1 cid msg.message
-- clients that should receive the message itself
let oldClients = Set.delete mp.sender mls.members
-- clients that should receive a welcome message
let newClients = Set.delete mp.sender mls.newMembers
-- all clients that should receive some MLS notification, together with the
-- expected notification tag
let clients =
map (,MLSNotificationMessageTag) (toList oldClients)
<> map (,MLSNotificationWelcomeTag) (toList newClients)

let newUsers =
Set.delete mp.sender.user $
Set.difference
(Set.map (.user) newClients)
(Set.map (.user) oldClients)
withWebSockets (map fst clients) $ \wss -> do
r <- k ()

-- if the conversation is actually MLS (and not mixed), pick one client for
-- each new user and wait for its join event
when (mls.protocol == MLSProtocolMLS) $
traverse_
(awaitMatch 10 isMemberJoinNotif)
( flip Map.restrictKeys newUsers
. Map.mapKeys ((.user) . fst)
. Map.fromList
. toList
$ zip clients wss
)

-- at this point we know that every new user has been added to the
-- conversation
for_ (zip clients wss) $ \((cid, t), ws) -> case t of
MLSNotificationMessageTag -> void $ consumeMessage cid (Just mp) ws
MLSNotificationWelcomeTag -> consumeWelcome cid mp ws
pure r

-- | Get a single MLS message from a websocket and consume it. Return a JSON
-- representation of the message.
consumeMessage :: HasCallStack => ClientIdentity -> Maybe MessagePackage -> WebSocket -> App Value
consumeMessage cid mmp ws = do
mls <- getMLSState
notif <- awaitMatch 10 isNewMLSMessageNotif ws
event <- notif %. "payload.0"

for_ mmp $ \mp -> do
shouldMatch (eventSubConv event) (fromMaybe A.Null mls.convId)
shouldMatch (event %. "from") mp.sender.user
shouldMatch (event %. "data") (B8.unpack (Base64.encode mp.message))

consumeMessage1 :: HasCallStack => ClientIdentity -> ByteString -> App ()
consumeMessage1 cid msg =
msgData <- event %. "data" & asByteString
void $
mlscli
cid
Expand All @@ -520,52 +583,72 @@ consumeMessage1 cid msg =
"<group-out>",
"-"
]
(Just msg)
(Just msgData)
showMessage cid msgData

-- | Send an MLS message and simulate clients receiving it. If the message is a
-- commit, the 'sendAndConsumeCommit' function should be used instead.
-- | Send an MLS message, wait for clients to receive it, then consume it on
-- the client side. If the message is a commit, the
-- 'sendAndConsumeCommitBundle' function should be used instead.
sendAndConsumeMessage :: HasCallStack => MessagePackage -> App Value
sendAndConsumeMessage mp = do
r <- postMLSMessage mp.sender mp.message >>= getJSON 201
consumeMessage mp
pure r
sendAndConsumeMessage mp = lowerCodensity $ do
consumingMessages mp
lift $ postMLSMessage mp.sender mp.message >>= getJSON 201

-- | Send an MLS commit bundle, simulate clients receiving it, and update the
-- test state accordingly.
-- | Send an MLS commit bundle, wait for clients to receive it, consume it, and
-- update the test state accordingly.
sendAndConsumeCommitBundle :: HasCallStack => MessagePackage -> App Value
sendAndConsumeCommitBundle mp = do
resp <- postMLSCommitBundle mp.sender (mkBundle mp) >>= getJSON 201
consumeMessage mp
traverse_ consumeWelcome mp.welcome

-- increment epoch and add new clients
modifyMLSState $ \mls ->
mls
{ epoch = epoch mls + 1,
members = members mls <> newMembers mls,
newMembers = mempty
}
lowerCodensity $ do
consumingMessages mp
lift $ do
r <- postMLSCommitBundle mp.sender (mkBundle mp) >>= getJSON 201

-- if the sender is a new member (i.e. it's an external commit), then
-- process the welcome message directly
do
mls <- getMLSState
when (Set.member mp.sender mls.newMembers) $
traverse_ (fromWelcome mp.sender) mp.welcome

-- increment epoch and add new clients
modifyMLSState $ \mls ->
mls
{ epoch = epoch mls + 1,
members = members mls <> newMembers mls,
newMembers = mempty
}

pure resp
pure r

consumeWelcome :: HasCallStack => ByteString -> App ()
consumeWelcome welcome = do
consumeWelcome :: HasCallStack => ClientIdentity -> MessagePackage -> WebSocket -> App ()
consumeWelcome cid mp ws = do
mls <- getMLSState
for_ mls.newMembers $ \cid -> do
gs <- getClientGroupState cid
assertBool
"Existing clients in a conversation should not consume welcomes"
(isNothing gs.group)
void $
mlscli
cid
[ "group",
"from-welcome",
"--group-out",
"<group-out>",
"-"
]
(Just welcome)
notif <- awaitMatch 10 isWelcomeNotif ws
event <- notif %. "payload.0"

shouldMatch (eventSubConv event) (fromMaybe A.Null mls.convId)
shouldMatch (event %. "from") mp.sender.user
shouldMatch (event %. "data") (fmap (B8.unpack . Base64.encode) mp.welcome)

welcome <- event %. "data" & asByteString
gs <- getClientGroupState cid
assertBool
"Existing clients in a conversation should not consume welcomes"
(isNothing gs.group)
fromWelcome cid welcome

fromWelcome :: ClientIdentity -> ByteString -> App ()
fromWelcome cid welcome =
void $
mlscli
cid
[ "group",
"from-welcome",
"--group-out",
"<group-out>",
"-"
]
(Just welcome)

readWelcome :: FilePath -> IO (Maybe ByteString)
readWelcome fp = runMaybeT $ do
Expand Down
3 changes: 3 additions & 0 deletions integration/test/Notifications.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ isNewMessageNotif n = fieldEquals n "payload.0.type" "conversation.otr-message-a
isNewMLSMessageNotif :: MakesValue a => a -> App Bool
isNewMLSMessageNotif n = fieldEquals n "payload.0.type" "conversation.mls-message-add"

isWelcomeNotif :: MakesValue a => a -> App Bool
isWelcomeNotif n = fieldEquals n "payload.0.type" "conversation.mls-welcome"

isMemberJoinNotif :: MakesValue a => a -> App Bool
isMemberJoinNotif n = fieldEquals n "payload.0.type" "conversation.member-join"

Expand Down
2 changes: 2 additions & 0 deletions integration/test/SetupHelpers.hs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ simpleMixedConversationSetup secondDomain = do
bindResponse (putConversationProtocol bob conv "mixed") $ \resp -> do
resp.status `shouldMatchInt` 200

modifyMLSState $ \mls -> mls {protocol = MLSProtocolMixed}

conv' <- getConversation alice conv >>= getJSON 200

pure (alice, bob, conv')
Expand Down
69 changes: 24 additions & 45 deletions integration/test/Test/MLS.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns -Wno-ambiguous-fields #-}

module Test.MLS where

Expand Down Expand Up @@ -90,6 +90,7 @@ testMixedProtocolUpgrade secondDomain = do
resp.status `shouldMatchInt` 200
resp.json %. "conversation" `shouldMatch` (qcnv %. "id")
resp.json %. "data.protocol" `shouldMatch` "mixed"
modifyMLSState $ \mls -> mls {protocol = MLSProtocolMixed}

for_ websockets $ \ws -> do
n <- awaitMatch 3 (\value -> nPayload value %. "type" `isEqual` "conversation.protocol-update") ws
Expand Down Expand Up @@ -130,6 +131,7 @@ testMixedProtocolAddUsers secondDomain = do

bindResponse (putConversationProtocol bob qcnv "mixed") $ \resp -> do
resp.status `shouldMatchInt` 200
modifyMLSState $ \mls -> mls {protocol = MLSProtocolMixed}

[alice1, bob1] <- traverse (createMLSClient def) [alice, bob]

Expand Down Expand Up @@ -158,6 +160,7 @@ testMixedProtocolUserLeaves secondDomain = do

bindResponse (putConversationProtocol bob qcnv "mixed") $ \resp -> do
resp.status `shouldMatchInt` 200
modifyMLSState $ \mls -> mls {protocol = MLSProtocolMixed}

[alice1, bob1] <- traverse (createMLSClient def) [alice, bob]

Expand Down Expand Up @@ -193,6 +196,7 @@ testMixedProtocolAddPartialClients secondDomain = do

bindResponse (putConversationProtocol bob qcnv "mixed") $ \resp -> do
resp.status `shouldMatchInt` 200
modifyMLSState $ \mls -> mls {protocol = MLSProtocolMixed}

[alice1, bob1, bob2] <- traverse (createMLSClient def) [alice, bob, bob]

Expand Down Expand Up @@ -231,6 +235,7 @@ testMixedProtocolRemovePartialClients secondDomain = do

bindResponse (putConversationProtocol bob qcnv "mixed") $ \resp -> do
resp.status `shouldMatchInt` 200
modifyMLSState $ \mls -> mls {protocol = MLSProtocolMixed}

[alice1, bob1, bob2] <- traverse (createMLSClient def) [alice, bob, bob]

Expand All @@ -256,6 +261,7 @@ testMixedProtocolAppMessagesAreDenied secondDomain = do

bindResponse (putConversationProtocol bob qcnv "mixed") $ \resp -> do
resp.status `shouldMatchInt` 200
modifyMLSState $ \mls -> mls {protocol = MLSProtocolMixed}

[alice1, bob1] <- traverse (createMLSClient def) [alice, bob]

Expand Down Expand Up @@ -302,6 +308,7 @@ testMLSProtocolUpgrade secondDomain = do
withWebSockets [alice1, bob1] $ \wss -> do
bindResponse (putConversationProtocol bob conv "mls") $ \resp -> do
resp.status `shouldMatchInt` 200
modifyMLSState $ \mls -> mls {protocol = MLSProtocolMLS}
for_ wss $ \ws -> do
n <- awaitMatch 3 isNewMLSMessageNotif ws
msg <- asByteString (nPayload n %. "data") >>= showMessage alice1
Expand Down Expand Up @@ -399,47 +406,17 @@ testCreateSubConvProteus = do
bindResponse (getSubConversation alice conv "conference") $ \resp ->
resp.status `shouldMatchInt` 404

-- FUTUREWORK: New clients should be adding themselves via external commits, and
-- they shouldn't be added by another client. Change the test so external
-- commits are used.
testSelfConversation :: App ()
testSelfConversation = do
alice <- randomUser OwnDomain def
creator : others <- traverse (createMLSClient def) (replicate 3 alice)
traverse_ uploadNewKeyPackage others
(_, cnv) <- createSelfGroup creator
commit <- createAddCommit creator [alice]
welcome <- assertOne (toList commit.welcome)
void $ createSelfGroup creator
void $ createAddCommit creator [alice] >>= sendAndConsumeCommitBundle

withWebSockets others $ \wss -> do
void $ sendAndConsumeCommitBundle commit
let isWelcome n = nPayload n %. "type" `isEqual` "conversation.mls-welcome"
for_ wss $ \ws -> do
n <- awaitMatch 3 isWelcome ws
shouldMatch (nPayload n %. "conversation") (objId cnv)
shouldMatch (nPayload n %. "from") (objId alice)
shouldMatch (nPayload n %. "data") (B8.unpack (Base64.encode welcome))

testJoinSubConv :: App ()
testJoinSubConv = do
[alice, bob] <- createAndConnectUsers [OwnDomain, OwnDomain]
[alice1, bob1, bob2] <- traverse (createMLSClient def) [alice, bob, bob]
traverse_ uploadNewKeyPackage [bob1, bob2]
(_, qcnv) <- createNewGroup alice1
void $ createAddCommit alice1 [bob] >>= sendAndConsumeCommitBundle
void $ createSubConv bob1 "conference"

-- bob adds his first client to the subconversation
void $ createPendingProposalCommit bob1 >>= sendAndConsumeCommitBundle
sub' <- getSubConversation bob qcnv "conference" >>= getJSON 200
do
tm <- sub' %. "epoch_timestamp"
assertBool "Epoch timestamp should not be null" (tm /= Null)

-- now alice joins with her own client
void $
createExternalCommit alice1 Nothing
>>= sendAndConsumeCommitBundle
newClient <- createMLSClient def alice
void $ uploadNewKeyPackage newClient
void $ createExternalCommit newClient Nothing >>= sendAndConsumeCommitBundle

-- | FUTUREWORK: Don't allow partial adds, not even in the first commit
testFirstCommitAllowsPartialAdds :: HasCallStack => App ()
Expand Down Expand Up @@ -505,6 +482,7 @@ testAdminRemovesUserFromConv :: HasCallStack => App ()
testAdminRemovesUserFromConv = do
[alice, bob] <- createAndConnectUsers [OwnDomain, OwnDomain]
[alice1, bob1, bob2] <- traverse (createMLSClient def) [alice, bob, bob]

void $ createWireClient bob
traverse_ uploadNewKeyPackage [bob1, bob2]
(gid, qcnv) <- createNewGroup alice1
Expand All @@ -520,15 +498,16 @@ testAdminRemovesUserFromConv = do
bobQid <- bob %. "qualified_id"
shouldMatch members [bobQid]

convs <- getAllConvs bob
convIds <- traverse (%. "qualified_id") convs
clients <- bindResponse (getGroupClients alice gid) $ \resp -> do
resp.status `shouldMatchInt` 200
resp.json %. "client_ids" & asList
void $ assertOne clients
assertBool
"bob is not longer part of conversation after the commit"
(qcnv `notElem` convIds)
do
convs <- getAllConvs bob
convIds <- traverse (%. "qualified_id") convs
clients <- bindResponse (getGroupClients alice gid) $ \resp -> do
resp.status `shouldMatchInt` 200
resp.json %. "client_ids" & asList
void $ assertOne clients
assertBool
"bob is not longer part of conversation after the commit"
(qcnv `notElem` convIds)

testLocalWelcome :: HasCallStack => App ()
testLocalWelcome = do
Expand Down
Loading