Skip to content

Commit

Permalink
add conversation type to group ID serialisation (#3344)
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanwire authored Jun 15, 2023
1 parent 9ba52c5 commit 241a588
Show file tree
Hide file tree
Showing 18 changed files with 144 additions and 78 deletions.
1 change: 1 addition & 0 deletions changelog.d/5-internal/WPB-1925
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add conversation type to group ID serialisation
17 changes: 17 additions & 0 deletions integration/test/API/Galley.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
module API.Galley where

import qualified Data.Aeson as Aeson
import qualified Data.ByteString.Base64 as B64
import qualified Data.ByteString.Base64.URL as B64U
import qualified Data.ByteString.Char8 as BS
import Testlib.Prelude

data CreateConv = CreateConv
Expand Down Expand Up @@ -212,3 +215,17 @@ getMLSOne2OneConversation self other = do
baseRequest self Galley Versioned $
joinHttpPath ["conversations", "one2one", domain, uid]
submit "GET" req

getGroupClients ::
(HasCallStack, MakesValue user) =>
user ->
String ->
App Response
getGroupClients user groupId = do
req <-
baseRequest
user
Galley
Unversioned
(joinHttpPath ["i", "group", BS.unpack . B64U.encodeUnpadded . B64.decodeLenient $ BS.pack groupId])
submit "GET" req
29 changes: 29 additions & 0 deletions integration/test/Test/MLS.hs
Original file line number Diff line number Diff line change
Expand Up @@ -484,3 +484,32 @@ testRemoveClientsIncomplete = do

err <- postMLSCommitBundle mp.sender (mkBundle mp) >>= getJSON 409
err %. "label" `shouldMatch` "mls-client-mismatch"

testAdminRemovesUserFromConv :: HasCallStack => App ()
testAdminRemovesUserFromConv = do
[alice, bob] <- createAndConnectUsers [OwnDomain, OwnDomain]
[alice1, bob1, bob2] <- traverse createMLSClient [alice, bob, bob]
void $ createWireClient bob
traverse_ uploadNewKeyPackage [bob1, bob2]
(gid, qcnv) <- createNewGroup alice1
void $ createAddCommit alice1 [bob] >>= sendAndConsumeCommitBundle
events <- createRemoveCommit alice1 [bob1, bob2] >>= sendAndConsumeCommitBundle

do
event <- assertOne =<< asList (events %. "events")
event %. "qualified_conversation" `shouldMatch` qcnv
event %. "type" `shouldMatch` "conversation.member-leave"
event %. "from" `shouldMatch` objId alice
members <- event %. "data" %. "qualified_user_ids" & asList
bobQid <- bob %. "qualified_id"
shouldMatch members [bobQid]

convs <- getAllConvs bob
convIds <- traverse (%. "qualified_id") convs
clients <- bindResponse (getGroupClients alice gid) $ \resp -> do
resp.status `shouldMatchInt` 200
resp.json %. "client_ids" & asList
void $ assertOne clients
assertBool
"bob is not longer part of conversation after the commit"
(qcnv `notElem` convIds)
2 changes: 1 addition & 1 deletion libs/wire-api/src/Wire/API/Conversation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ data ConvType
| SelfConv
| One2OneConv
| ConnectConv
deriving stock (Eq, Show, Generic)
deriving stock (Eq, Show, Enum, Generic)
deriving (Arbitrary) via (GenericUniform ConvType)
deriving (FromJSON, ToJSON, S.ToSchema) via Schema ConvType

Expand Down
2 changes: 2 additions & 0 deletions libs/wire-api/src/Wire/API/MLS/Group.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ import Data.Json.Util
import Data.Schema
import qualified Data.Swagger as S
import Imports
import Servant
import Wire.API.MLS.Serialisation
import Wire.Arbitrary

newtype GroupId = GroupId {unGroupId :: ByteString}
deriving (Eq, Show, Generic, Ord)
deriving (Arbitrary) via (GenericUniform GroupId)
deriving (FromHttpApiData, ToHttpApiData, S.ToParamSchema) via Base64ByteString
deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema GroupId)

instance IsString GroupId where
Expand Down
61 changes: 42 additions & 19 deletions libs/wire-api/src/Wire/API/MLS/Group/Serialisation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Wire.API.MLS.Group.Serialisation
( convToGroupId,
convToGroupId',
( GroupIdParts (..),
groupIdParts,
convToGroupId,
groupIdToConv,
nextGenGroupId,
)
Expand All @@ -36,47 +37,69 @@ import qualified Data.Text.Encoding as T
import qualified Data.UUID as UUID
import Imports hiding (cs)
import Web.HttpApiData (FromHttpApiData (parseHeader))
import Wire.API.Conversation
import Wire.API.MLS.Group
import Wire.API.MLS.SubConversation

data GroupIdParts = GroupIdParts
{ convType :: ConvType,
qConvId :: Qualified ConvOrSubConvId,
gidGen :: GroupIdGen
}
deriving (Show, Eq)

groupIdParts :: ConvType -> Qualified ConvOrSubConvId -> GroupIdParts
groupIdParts ct qcs =
GroupIdParts
{ convType = ct,
qConvId = qcs,
gidGen = GroupIdGen 0
}

-- | Return the group ID associated to a conversation ID. Note that is not
-- assumed to be stable over time or even consistent among different backends.
convToGroupId :: Qualified ConvOrSubConvId -> GroupIdGen -> GroupId
convToGroupId qcs gen = GroupId . L.toStrict . runPut $ do
let cs = qUnqualified qcs
convToGroupId :: GroupIdParts -> GroupId
convToGroupId parts = GroupId . L.toStrict . runPut $ do
let cs = qUnqualified parts.qConvId
subId = foldMap unSubConvId cs.subconv
putWord64be 1 -- Version 1 of the GroupId format
putWord32be (fromIntegral $ fromEnum parts.convType)
putLazyByteString . UUID.toByteString . toUUID $ cs.conv
putWord8 $ fromIntegral (T.length subId)
putByteString $ T.encodeUtf8 subId
maybe (pure ()) (const $ putWord32be (unGroupIdGen gen)) cs.subconv
putLazyByteString . toByteString $ qDomain qcs
maybe (pure ()) (const $ putWord32be (unGroupIdGen parts.gidGen)) cs.subconv
putLazyByteString . toByteString $ qDomain parts.qConvId

convToGroupId' :: Qualified ConvOrSubConvId -> GroupId
convToGroupId' = flip convToGroupId (GroupIdGen 0)

groupIdToConv :: GroupId -> Either String (Qualified ConvOrSubConvId, GroupIdGen)
groupIdToConv :: GroupId -> Either String GroupIdParts
groupIdToConv gid = do
(rem', _, (conv, gen)) <- first (\(_, _, msg) -> msg) $ runGetOrFail readConv (L.fromStrict (unGroupId gid))
(rem', _, (ct, conv, gen)) <- first (\(_, _, msg) -> msg) $ runGetOrFail readConv (L.fromStrict (unGroupId gid))
domain <- first displayException . T.decodeUtf8' . L.toStrict $ rem'
pure $ (Qualified conv (Domain domain), gen)
pure
GroupIdParts
{ convType = toEnum $ fromIntegral ct,
qConvId = Qualified conv (Domain domain),
gidGen = gen
}
where
readConv = do
version <- getWord64be
ct <- getWord32be
unless (version == 1) $ fail "unsupported groupId version"
mUUID <- UUID.fromByteString . L.fromStrict <$> getByteString 16
uuid <- maybe (fail "invalid conversation UUID in groupId") pure mUUID
n <- getWord8
if n == 0
then pure $ (Conv (Id uuid), GroupIdGen 0)
then pure $ (ct, Conv (Id uuid), GroupIdGen 0)
else do
subConvIdBS <- getByteString $ fromIntegral n
subConvId <- either (fail . T.unpack) pure $ parseHeader subConvIdBS
gen <- getWord32be
pure $ (SubConv (Id uuid) (SubConvId subConvId), GroupIdGen gen)
pure $ (ct, SubConv (Id uuid) (SubConvId subConvId), GroupIdGen gen)

nextGenGroupId :: GroupId -> Either String GroupId
nextGenGroupId gid =
uncurry convToGroupId
. second (GroupIdGen . succ . unGroupIdGen)
<$> groupIdToConv gid
nextGenGroupId gid = convToGroupId . succGen <$> groupIdToConv gid
where
succGen parts =
parts
{ gidGen = GroupIdGen (succ $ unGroupIdGen parts.gidGen)
}
6 changes: 3 additions & 3 deletions libs/wire-api/src/Wire/API/Routes/Internal/Galley.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import Servant hiding (JSON, WithStatus)
import qualified Servant hiding (WithStatus)
import Servant.Swagger
import Wire.API.ApplyMods
import Wire.API.Conversation
import Wire.API.Conversation.Role
import Wire.API.Error
import Wire.API.Error.Galley
Expand Down Expand Up @@ -216,10 +217,9 @@ type InternalAPIBase =
:<|> Named
"get-conversation-clients"
( Summary "Get mls conversation client list"
:> ZLocalUser
:> CanThrow 'ConvNotFound
:> "conversation"
:> Capture "cnv" ConvId
:> "group"
:> Capture "gid" GroupId
:> MultiVerb1
'GET
'[Servant.JSON]
Expand Down
21 changes: 18 additions & 3 deletions libs/wire-api/test/unit/Test/Wire/API/MLS/Group.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import Imports
import Test.QuickCheck
import Test.Tasty
import Test.Tasty.QuickCheck
import Wire.API.Conversation
import Wire.API.MLS.Group
import Wire.API.MLS.Group.Serialisation
import Wire.API.MLS.SubConversation
Expand All @@ -33,9 +34,23 @@ tests =
[ testProperty "roundtrip serialise and parse groupId" $ roundtripGroupId
]

roundtripGroupId :: Qualified ConvOrSubConvId -> GroupIdGen -> Property
roundtripGroupId convId gen =
roundtripGroupId :: ConvType -> Qualified ConvOrSubConvId -> GroupIdGen -> Property
roundtripGroupId ct convId gen =
let gen' = case qUnqualified convId of
(Conv _) -> GroupIdGen 0
(SubConv _ _) -> gen
in groupIdToConv (convToGroupId convId gen) === Right (convId, gen')
in groupIdToConv
( convToGroupId
GroupIdParts
{ convType = ct,
qConvId = convId,
gidGen = gen
}
)
=== Right
( GroupIdParts
{ convType = ct,
qConvId = convId,
gidGen = gen'
}
)
2 changes: 1 addition & 1 deletion services/galley/src/Galley/API/Action.hs
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ performAction tag origUser lconv action = do
SConversationUpdateProtocolTag -> do
case (protocolTag (convProtocol (tUnqualified lconv)), action, convTeam (tUnqualified lconv)) of
(ProtocolProteusTag, ProtocolMixedTag, Just _) -> do
E.updateToMixedProtocol lcnv MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519
E.updateToMixedProtocol lcnv (convType (tUnqualified lconv)) MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519
pure (mempty, action)
(ProtocolMixedTag, ProtocolMLSTag, Just tid) -> do
mig <- getFeatureStatus @MlsMigrationConfig DontDoAuth tid
Expand Down
9 changes: 3 additions & 6 deletions services/galley/src/Galley/API/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ import Wire.API.Event.Conversation
import Wire.API.Federation.API
import Wire.API.Federation.API.Galley
import Wire.API.Federation.Error
import Wire.API.MLS.Group.Serialisation
import Wire.API.MLS.SubConversation
import Wire.API.Provider.Service hiding (Service)
import Wire.API.Routes.API
import Wire.API.Routes.Internal.Galley
Expand Down Expand Up @@ -484,9 +482,8 @@ iGetMLSClientListForConv ::
ErrorS 'ConvNotFound
]
r =>
Local UserId ->
ConvId ->
GroupId ->
Sem r ClientList
iGetMLSClientListForConv lusr cnv = do
cm <- E.lookupMLSClients (convToGroupId' (Conv <$> tUntagged (qualifyAs lusr cnv)))
iGetMLSClientListForConv gid = do
cm <- E.lookupMLSClients gid
pure $ ClientList (concatMap (Map.keys . snd) (Map.assocs cm))
2 changes: 1 addition & 1 deletion services/galley/src/Galley/API/MLS/One2One.hs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ localMLSOne2OneConversationMetadata self convId =
)
{ cnvmType = One2OneConv
}
groupId = convToGroupId' (fmap Conv convId)
groupId = convToGroupId $ groupIdParts One2OneConv (fmap Conv convId)
mlsData =
ConversationMLSData
{ cnvmlsGroupId = groupId,
Expand Down
14 changes: 12 additions & 2 deletions services/galley/src/Galley/API/MLS/SubConversation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,10 @@ getLocalSubConversation qusr lconv sconv = do

-- deriving this detemernistically to prevent race condition between
-- multiple threads creating the subconversation
let groupId = convToGroupId' $ flip SubConv sconv <$> tUntagged lconv
let groupId =
convToGroupId
. groupIdParts (Data.convType c)
$ flip SubConv sconv <$> tUntagged lconv
epoch = Epoch 0
suite = cnvmlsCipherSuite mlsMeta
Eff.createSubConversation (tUnqualified lconv) sconv suite epoch groupId Nothing
Expand Down Expand Up @@ -294,7 +297,14 @@ deleteLocalSubConversation qusr lcnvId scnvId dsc = do
Eff.removeAllMLSClients gid

-- swallowing the error and starting with GroupIdGen 0 if nextGenGroupId
let newGid = fromRight (convToGroupId' (flip SubConv scnvId <$> tUntagged lcnvId)) $ nextGenGroupId gid
let newGid =
fromRight
( convToGroupId $
groupIdParts
(Data.convType cnv)
(flip SubConv scnvId <$> tUntagged lcnvId)
)
$ nextGenGroupId gid

-- the following overwrites any prior information about the subconversation
Eff.createSubConversation cnvId scnvId cs (Epoch 0) newGid Nothing
Expand Down
2 changes: 1 addition & 1 deletion services/galley/src/Galley/API/MLS/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,4 @@ withCommitLock lConvOrSubId gid epoch action =
ttl = fromIntegral (600 :: Int) -- 10 minutes

getConvFromGroupId :: Member (Error MLSProtocolError) r => GroupId -> Sem r (Qualified ConvOrSubConvId)
getConvFromGroupId = either (throw . mlsProtocolError . T.pack) (pure . fst) . groupIdToConv
getConvFromGroupId = either (throw . mlsProtocolError . T.pack) (pure . qConvId) . groupIdToConv
11 changes: 6 additions & 5 deletions services/galley/src/Galley/Cassandra/Conversation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ createMLSSelfConversation lusr = do
ncProtocol = ProtocolCreateMLSTag
}
meta = ncMetadata nc
gid = convToGroupId' . fmap Conv . tUntagged . qualifyAs lusr $ cnv
gid = convToGroupId . groupIdParts meta.cnvmType . fmap Conv . tUntagged . qualifyAs lusr $ cnv
-- FUTUREWORK: Stop hard-coding the cipher suite
--
-- 'CipherSuite 1' corresponds to
Expand Down Expand Up @@ -123,7 +123,7 @@ createConversation lcnv nc = do
(proto, mgid, mep, mcs) = case ncProtocol nc of
ProtocolCreateProteusTag -> (ProtocolProteus, Nothing, Nothing, Nothing)
ProtocolCreateMLSTag ->
let gid = convToGroupId' $ Conv <$> tUntagged lcnv
let gid = convToGroupId . groupIdParts meta.cnvmType $ Conv <$> tUntagged lcnv
ep = Epoch 0
cs = MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519
in ( ProtocolMLS
Expand Down Expand Up @@ -412,10 +412,11 @@ updateToMixedProtocol ::
]
r =>
Local ConvId ->
ConvType ->
CipherSuiteTag ->
Sem r ()
updateToMixedProtocol lcnv cs = do
let gid = convToGroupId' $ Conv <$> tUntagged lcnv
updateToMixedProtocol lcnv ct cs = do
let gid = convToGroupId . groupIdParts ct $ Conv <$> tUntagged lcnv
epoch = Epoch 0
embedClient . retry x5 . batch $ do
setType BatchLogged
Expand Down Expand Up @@ -464,5 +465,5 @@ interpretConversationStoreToCassandra = interpret $ \case
SetGroupInfo cid gib -> embedClient $ setGroupInfo cid gib
AcquireCommitLock gId epoch ttl -> embedClient $ acquireCommitLock gId epoch ttl
ReleaseCommitLock gId epoch -> embedClient $ releaseCommitLock gId epoch
UpdateToMixedProtocol cid cs -> updateToMixedProtocol cid cs
UpdateToMixedProtocol cid ct cs -> updateToMixedProtocol cid ct cs
UpdateToMLSProtocol cid -> updateToMLSProtocol cid
2 changes: 1 addition & 1 deletion services/galley/src/Galley/Effects/ConversationStore.hs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ data ConversationStore m a where
SetGroupInfo :: ConvId -> GroupInfoData -> ConversationStore m ()
AcquireCommitLock :: GroupId -> Epoch -> NominalDiffTime -> ConversationStore m LockAcquired
ReleaseCommitLock :: GroupId -> Epoch -> ConversationStore m ()
UpdateToMixedProtocol :: Local ConvId -> CipherSuiteTag -> ConversationStore m ()
UpdateToMixedProtocol :: Local ConvId -> ConvType -> CipherSuiteTag -> ConversationStore m ()
UpdateToMLSProtocol :: Local ConvId -> ConversationStore m ()

makeSem ''ConversationStore
Expand Down
Loading

0 comments on commit 241a588

Please sign in to comment.