Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WPB-4981] replace unclaimed keypackages atomically #3654

Merged
merged 5 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/1-api-changes/mls-replace-kps
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
New endpoint for replacing MLS key packages in bulk: `PUT /mls/key-packages/self/:client`. It replaces all existing key packages that match the given ciphersuites with the new key packages provided in the body.
10 changes: 10 additions & 0 deletions integration/test/API/Brig.hs
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,16 @@ deleteKeyPackages cid kps = do
req <- baseRequest cid Brig Versioned ("/mls/key-packages/self/" <> cid.client)
submit "DELETE" $ req & addJSONObject ["key_packages" .= kps]

replaceKeyPackages :: ClientIdentity -> [Ciphersuite] -> [ByteString] -> App Response
replaceKeyPackages cid suites kps = do
req <-
baseRequest cid Brig Versioned $
"/mls/key-packages/self/" <> cid.client
submit "PUT" $
req
& addQueryParams [("ciphersuites", intercalate "," (map (.code) suites))]
& addJSONObject ["key_packages" .= map (T.decodeUtf8 . Base64.encode) kps]

getSelf :: HasCallStack => String -> String -> App Response
getSelf domain uid = do
let user = object ["domain" .= domain, "id" .= uid]
Expand Down
10 changes: 5 additions & 5 deletions integration/test/MLS/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -174,18 +174,18 @@ createMLSClient opts u = do
-- | create and upload to backend
uploadNewKeyPackage :: HasCallStack => ClientIdentity -> App String
uploadNewKeyPackage cid = do
(kp, ref) <- generateKeyPackage cid
mls <- getMLSState
(kp, ref) <- generateKeyPackage cid mls.ciphersuite

-- upload key package
bindResponse (uploadKeyPackages cid [kp]) $ \resp ->
resp.status `shouldMatchInt` 201

pure ref

generateKeyPackage :: HasCallStack => ClientIdentity -> App (ByteString, String)
generateKeyPackage cid = do
mls <- getMLSState
kp <- mlscli cid ["key-package", "create", "--ciphersuite", mls.ciphersuite.code] Nothing
generateKeyPackage :: HasCallStack => ClientIdentity -> Ciphersuite -> App (ByteString, String)
generateKeyPackage cid suite = do
kp <- mlscli cid ["key-package", "create", "--ciphersuite", suite.code] Nothing
ref <- B8.unpack . Base64.encode <$> mlscli cid ["key-package", "ref", "-"] (Just kp)
fp <- keyPackageFile cid ref
liftIO $ BS.writeFile fp kp
Expand Down
77 changes: 72 additions & 5 deletions integration/test/Test/MLS/KeyPackage.hs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,8 @@ testKeyPackageCount cs = do
resp.status `shouldMatchInt` 200
resp.json %. "count" `shouldMatchInt` 0

setMLSCiphersuite cs

let count = 10
kps <- map fst <$> replicateM count (generateKeyPackage alice1)
kps <- map fst <$> replicateM count (generateKeyPackage alice1 cs)
void $ uploadKeyPackages alice1 kps >>= getBody 201

bindResponse (countKeyPackages cs alice1) $ \resp -> do
Expand All @@ -68,10 +66,79 @@ testKeyPackageCount cs = do

testUnsupportedCiphersuite :: HasCallStack => App ()
testUnsupportedCiphersuite = do
setMLSCiphersuite (Ciphersuite "0x0002")
let suite = Ciphersuite "0x0002"
setMLSCiphersuite suite
bob <- randomUser OwnDomain def
bob1 <- createMLSClient def bob
(kp, _) <- generateKeyPackage bob1
(kp, _) <- generateKeyPackage bob1 suite
bindResponse (uploadKeyPackages bob1 [kp]) $ \resp -> do
resp.status `shouldMatchInt` 400
resp.json %. "label" `shouldMatch` "mls-protocol-error"

testReplaceKeyPackages :: HasCallStack => App ()
testReplaceKeyPackages = do
alice <- randomUser OwnDomain def
[alice1, alice2] <- replicateM 2 $ createMLSClient def alice
let suite = Ciphersuite "0xf031"

let checkCount cs n =
bindResponse (countKeyPackages cs alice1) $ \resp -> do
resp.status `shouldMatchInt` 200
resp.json %. "count" `shouldMatchInt` n

-- setup: upload a batch of key packages for each ciphersuite
void $
replicateM 4 (fmap fst (generateKeyPackage alice1 def))
>>= uploadKeyPackages alice1
>>= getBody 201
setMLSCiphersuite suite
void $
replicateM 5 (fmap fst (generateKeyPackage alice1 suite))
>>= uploadKeyPackages alice1
>>= getBody 201

checkCount def 4
checkCount suite 5

do
-- generate a new batch of key packages
(kps, refs) <- unzip <$> replicateM 3 (generateKeyPackage alice1 suite)

-- replace old key packages with new
void $ replaceKeyPackages alice1 [suite] kps >>= getBody 201

checkCount def 4
checkCount suite 3

-- claim all key packages one by one
claimed <-
replicateM 3 $
bindResponse (claimKeyPackages suite alice2 alice) $ \resp -> do
resp.status `shouldMatchInt` 200
ks <- resp.json %. "key_packages" & asList
k <- assertOne ks
k %. "key_package_ref"

refs `shouldMatchSet` claimed

checkCount def 4
checkCount suite 0

when False $ do
MangoIV marked this conversation as resolved.
Show resolved Hide resolved
-- replenish key packages for the second ciphersuite
void $
replicateM 5 (fmap fst (generateKeyPackage alice1 suite))
>>= uploadKeyPackages alice1
>>= getBody 201

checkCount def 4
checkCount suite 5

-- replace all key packages with fresh ones
kps1 <- replicateM 2 (fmap fst (generateKeyPackage alice1 def))
kps2 <- replicateM 2 (fmap fst (generateKeyPackage alice1 suite))

void $ replaceKeyPackages alice1 [def, suite] (kps1 <> kps2) >>= getBody 201

checkCount def 2
checkCount suite 2
17 changes: 11 additions & 6 deletions libs/wire-api/src/Wire/API/MLS/CipherSuite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ module Wire.API.MLS.CipherSuite
where

import Cassandra.CQL
import Control.Applicative
import Control.Error (note)
import Control.Lens ((?~))
import Crypto.Error
Expand All @@ -51,18 +52,20 @@ import Crypto.PubKey.Ed25519 qualified as Ed25519
import Data.Aeson qualified as Aeson
import Data.Aeson.Types (FromJSON (..), FromJSONKey (..), ToJSON (..), ToJSONKey (..))
import Data.Aeson.Types qualified as Aeson
import Data.Attoparsec.ByteString.Char8 qualified as Atto
import Data.Bifunctor
import Data.ByteArray hiding (index)
import Data.ByteArray qualified as BA
import Data.ByteString.Conversion
import Data.OpenApi qualified as S
import Data.OpenApi.Internal.Schema qualified as S
import Data.Proxy
import Data.Schema
import Data.Text qualified as T
import Data.Text.Encoding qualified as T
import Data.Text.Lazy qualified as LT
import Data.Text.Lazy.Builder qualified as LT
import Data.Text.Lazy.Builder.Int qualified as LT
import Data.Text.Read qualified as T
import Data.Word
import Imports hiding (cs)
import Web.HttpApiData
Expand All @@ -85,11 +88,8 @@ instance S.ToParamSchema CipherSuite where
& S.type_ ?~ S.OpenApiNumber

instance FromHttpApiData CipherSuite where
parseUrlPiece t = do
(x, rest) <- first T.pack $ T.hexadecimal t
unless (T.null rest) $
Left "Trailing characters after ciphersuite number"
pure (CipherSuite x)
parseUrlPiece = parseHeader . T.encodeUtf8
parseHeader = first T.pack . runParser parser

instance ToHttpApiData CipherSuite where
toUrlPiece =
Expand All @@ -99,6 +99,11 @@ instance ToHttpApiData CipherSuite where
. LT.hexadecimal
. cipherSuiteNumber

instance FromByteString CipherSuite where
parser = do
void $ Atto.try (optional (Atto.string "0x"))
CipherSuite <$> Atto.hexadecimal

data CipherSuiteTag
= MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519
| MLS_128_X25519Kyber768Draft00_AES128GCM_SHA256_Ed25519
Expand Down
3 changes: 2 additions & 1 deletion libs/wire-api/src/Wire/API/MLS/KeyPackage.hs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ import Wire.API.MLS.Serialisation
import Wire.Arbitrary

data KeyPackageUpload = KeyPackageUpload
{keyPackages :: [RawMLS KeyPackage]}
{ keyPackages :: [RawMLS KeyPackage]
}
deriving (FromJSON, ToJSON, S.ToSchema) via Schema KeyPackageUpload

instance ToSchema KeyPackageUpload where
Expand Down
23 changes: 23 additions & 0 deletions libs/wire-api/src/Wire/API/Routes/Public/Brig.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1186,6 +1186,15 @@ type CipherSuiteParam =
"ciphersuite"
CipherSuite

type MultipleCipherSuitesParam =
QueryParam'
[ Optional,
Strict,
Description "Comma-separated list of ciphersuites in hex format (e.g. 0xf031) - default is 0x0001"
]
"ciphersuites"
(CommaSeparatedList CipherSuite)

type MLSKeyPackageAPI =
"key-packages"
:> ( Named
Expand All @@ -1201,6 +1210,20 @@ type MLSKeyPackageAPI =
:> ReqBody '[JSON] KeyPackageUpload
:> MultiVerb 'POST '[JSON, MLS] '[RespondEmpty 201 "Key packages uploaded"] ()
)
:<|> Named
"mls-key-packages-replace"
( "self"
:> Summary "Upload a fresh batch of key packages and replace the old ones"
:> From 'V5
:> Description "The request body should be a json object containing a list of base64-encoded key packages. Use this sparingly."
:> ZLocalUser
:> CanThrow 'MLSProtocolError
:> CanThrow 'MLSIdentityMismatch
:> CaptureClientId "client"
:> MultipleCipherSuitesParam
:> ReqBody '[JSON] KeyPackageUpload
:> MultiVerb 'PUT '[JSON, MLS] '[RespondEmpty 201 "Key packages replaced"] ()
)
:<|> Named
"mls-key-packages-claim"
( "claim"
Expand Down
17 changes: 12 additions & 5 deletions services/brig/src/Brig/API/MLS/CipherSuite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,22 @@
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Brig.API.MLS.CipherSuite (getCipherSuite) where
module Brig.API.MLS.CipherSuite (getCipherSuite, getCipherSuites) where

import Brig.API.Handler
import Brig.API.MLS.KeyPackages.Validation
import Imports
import Wire.API.MLS.CipherSuite

getOneCipherSuite :: CipherSuite -> Handler r CipherSuiteTag
getOneCipherSuite s =
maybe
(mlsProtocolError "Unknown ciphersuite")
pure
(cipherSuiteTag s)

getCipherSuite :: Maybe CipherSuite -> Handler r CipherSuiteTag
getCipherSuite mSuite = case mSuite of
Nothing -> pure defCipherSuite
Just x ->
maybe (mlsProtocolError "Unknown ciphersuite") pure (cipherSuiteTag x)
getCipherSuite = maybe (pure defCipherSuite) getOneCipherSuite

getCipherSuites :: Maybe [CipherSuite] -> Handler r [CipherSuiteTag]
getCipherSuites = maybe (pure [defCipherSuite]) (traverse getOneCipherSuite)
14 changes: 14 additions & 0 deletions services/brig/src/Brig/API/MLS/KeyPackages.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ module Brig.API.MLS.KeyPackages
claimLocalKeyPackages,
countKeyPackages,
deleteKeyPackages,
replaceKeyPackages,
)
where

Expand All @@ -37,6 +38,7 @@ import Brig.Federation.Client
import Brig.IO.Intra
import Control.Monad.Trans.Except
import Control.Monad.Trans.Maybe
import Data.CommaSeparatedList
import Data.Id
import Data.Qualified
import Data.Set qualified as Set
Expand Down Expand Up @@ -157,3 +159,15 @@ deleteKeyPackages lusr c mSuite (unDeleteKeyPackages -> refs) = do
assertMLSEnabled
suite <- getCipherSuite mSuite
lift $ wrapClient (Data.deleteKeyPackages (tUnqualified lusr) c suite refs)

replaceKeyPackages ::
Local UserId ->
ClientId ->
Maybe (CommaSeparatedList CipherSuite) ->
KeyPackageUpload ->
Handler r ()
replaceKeyPackages lusr c (fmap toList -> mSuites) upload = do
assertMLSEnabled
suites <- getCipherSuites mSuites
lift $ wrapClient (Data.deleteAllKeyPackages (tUnqualified lusr) c suites)
uploadKeyPackages lusr c upload
1 change: 1 addition & 0 deletions services/brig/src/Brig/API/Public.hs
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ servantSitemap =
mlsAPI :: ServerT MLSAPI (Handler r)
mlsAPI =
Named @"mls-key-packages-upload" uploadKeyPackages
:<|> Named @"mls-key-packages-replace" replaceKeyPackages
:<|> Named @"mls-key-packages-claim" claimKeyPackages
:<|> Named @"mls-key-packages-count" countKeyPackages
:<|> Named @"mls-key-packages-delete" deleteKeyPackages
Expand Down
18 changes: 18 additions & 0 deletions services/brig/src/Brig/Data/MLS/KeyPackage.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ module Brig.Data.MLS.KeyPackage
claimKeyPackage,
countKeyPackages,
deleteKeyPackages,
deleteAllKeyPackages,
)
where

Expand All @@ -37,6 +38,7 @@ import Data.Qualified
import Data.Time.Clock
import Data.Time.Clock.POSIX
import Imports
import UnliftIO.Async
import Wire.API.MLS.CipherSuite
import Wire.API.MLS.KeyPackage
import Wire.API.MLS.LeafNode
Expand Down Expand Up @@ -142,6 +144,22 @@ deleteKeyPackages u c suite refs =
deleteQuery :: PrepQuery W (UserId, ClientId, CipherSuiteTag, [KeyPackageRef]) ()
deleteQuery = "DELETE FROM mls_key_packages WHERE user = ? AND client = ? AND cipher_suite = ? AND ref in ?"

deleteAllKeyPackages ::
(MonadClient m, MonadUnliftIO m) =>
UserId ->
ClientId ->
[CipherSuiteTag] ->
m ()
deleteAllKeyPackages u c suites =
pooledForConcurrentlyN_ 16 suites $ \suite ->
MangoIV marked this conversation as resolved.
Show resolved Hide resolved
retry x5 $
write
deleteQuery
(params LocalQuorum (u, c, suite))
where
deleteQuery :: PrepQuery W (UserId, ClientId, CipherSuiteTag) ()
deleteQuery = "DELETE FROM mls_key_packages WHERE user = ? AND client = ? AND cipher_suite = ?"

--------------------------------------------------------------------------------
-- Utilities

Expand Down
Loading