diff --git a/changelog.d/1-api-changes/mls-replace-kps b/changelog.d/1-api-changes/mls-replace-kps new file mode 100644 index 00000000000..671b7d4b8c0 --- /dev/null +++ b/changelog.d/1-api-changes/mls-replace-kps @@ -0,0 +1 @@ +New endpoint for replacing MLS key packages in bulk: `PUT /mls/key-packages/self/:client`. It replaces all existing key packages that match the given ciphersuites with the new key packages provided in the body. diff --git a/integration/test/API/Brig.hs b/integration/test/API/Brig.hs index 455ce04e9d9..8e1c390862a 100644 --- a/integration/test/API/Brig.hs +++ b/integration/test/API/Brig.hs @@ -302,6 +302,16 @@ deleteKeyPackages cid kps = do req <- baseRequest cid Brig Versioned ("/mls/key-packages/self/" <> cid.client) submit "DELETE" $ req & addJSONObject ["key_packages" .= kps] +replaceKeyPackages :: ClientIdentity -> [Ciphersuite] -> [ByteString] -> App Response +replaceKeyPackages cid suites kps = do + req <- + baseRequest cid Brig Versioned $ + "/mls/key-packages/self/" <> cid.client + submit "PUT" $ + req + & addQueryParams [("ciphersuites", intercalate "," (map (.code) suites))] + & addJSONObject ["key_packages" .= map (T.decodeUtf8 . Base64.encode) kps] + getSelf :: HasCallStack => String -> String -> App Response getSelf domain uid = do let user = object ["domain" .= domain, "id" .= uid] diff --git a/integration/test/MLS/Util.hs b/integration/test/MLS/Util.hs index d6bef3fc8a0..94c2d520f21 100644 --- a/integration/test/MLS/Util.hs +++ b/integration/test/MLS/Util.hs @@ -174,7 +174,8 @@ createMLSClient opts u = do -- | create and upload to backend uploadNewKeyPackage :: HasCallStack => ClientIdentity -> App String uploadNewKeyPackage cid = do - (kp, ref) <- generateKeyPackage cid + mls <- getMLSState + (kp, ref) <- generateKeyPackage cid mls.ciphersuite -- upload key package bindResponse (uploadKeyPackages cid [kp]) $ \resp -> @@ -182,10 +183,9 @@ uploadNewKeyPackage cid = do pure ref -generateKeyPackage :: HasCallStack => ClientIdentity -> App (ByteString, String) -generateKeyPackage cid = do - mls <- getMLSState - kp <- mlscli cid ["key-package", "create", "--ciphersuite", mls.ciphersuite.code] Nothing +generateKeyPackage :: HasCallStack => ClientIdentity -> Ciphersuite -> App (ByteString, String) +generateKeyPackage cid suite = do + kp <- mlscli cid ["key-package", "create", "--ciphersuite", suite.code] Nothing ref <- B8.unpack . Base64.encode <$> mlscli cid ["key-package", "ref", "-"] (Just kp) fp <- keyPackageFile cid ref liftIO $ BS.writeFile fp kp diff --git a/integration/test/Test/MLS/KeyPackage.hs b/integration/test/Test/MLS/KeyPackage.hs index 8f6cf9d20d3..78c7e87e0b5 100644 --- a/integration/test/Test/MLS/KeyPackage.hs +++ b/integration/test/Test/MLS/KeyPackage.hs @@ -56,10 +56,8 @@ testKeyPackageCount cs = do resp.status `shouldMatchInt` 200 resp.json %. "count" `shouldMatchInt` 0 - setMLSCiphersuite cs - let count = 10 - kps <- map fst <$> replicateM count (generateKeyPackage alice1) + kps <- map fst <$> replicateM count (generateKeyPackage alice1 cs) void $ uploadKeyPackages alice1 kps >>= getBody 201 bindResponse (countKeyPackages cs alice1) $ \resp -> do @@ -68,10 +66,79 @@ testKeyPackageCount cs = do testUnsupportedCiphersuite :: HasCallStack => App () testUnsupportedCiphersuite = do - setMLSCiphersuite (Ciphersuite "0x0002") + let suite = Ciphersuite "0x0002" + setMLSCiphersuite suite bob <- randomUser OwnDomain def bob1 <- createMLSClient def bob - (kp, _) <- generateKeyPackage bob1 + (kp, _) <- generateKeyPackage bob1 suite bindResponse (uploadKeyPackages bob1 [kp]) $ \resp -> do resp.status `shouldMatchInt` 400 resp.json %. "label" `shouldMatch` "mls-protocol-error" + +testReplaceKeyPackages :: HasCallStack => App () +testReplaceKeyPackages = do + alice <- randomUser OwnDomain def + [alice1, alice2] <- replicateM 2 $ createMLSClient def alice + let suite = Ciphersuite "0xf031" + + let checkCount cs n = + bindResponse (countKeyPackages cs alice1) $ \resp -> do + resp.status `shouldMatchInt` 200 + resp.json %. "count" `shouldMatchInt` n + + -- setup: upload a batch of key packages for each ciphersuite + void $ + replicateM 4 (fmap fst (generateKeyPackage alice1 def)) + >>= uploadKeyPackages alice1 + >>= getBody 201 + setMLSCiphersuite suite + void $ + replicateM 5 (fmap fst (generateKeyPackage alice1 suite)) + >>= uploadKeyPackages alice1 + >>= getBody 201 + + checkCount def 4 + checkCount suite 5 + + do + -- generate a new batch of key packages + (kps, refs) <- unzip <$> replicateM 3 (generateKeyPackage alice1 suite) + + -- replace old key packages with new + void $ replaceKeyPackages alice1 [suite] kps >>= getBody 201 + + checkCount def 4 + checkCount suite 3 + + -- claim all key packages one by one + claimed <- + replicateM 3 $ + bindResponse (claimKeyPackages suite alice2 alice) $ \resp -> do + resp.status `shouldMatchInt` 200 + ks <- resp.json %. "key_packages" & asList + k <- assertOne ks + k %. "key_package_ref" + + refs `shouldMatchSet` claimed + + checkCount def 4 + checkCount suite 0 + + do + -- replenish key packages for the second ciphersuite + void $ + replicateM 5 (fmap fst (generateKeyPackage alice1 suite)) + >>= uploadKeyPackages alice1 + >>= getBody 201 + + checkCount def 4 + checkCount suite 5 + + -- replace all key packages with fresh ones + kps1 <- replicateM 2 (fmap fst (generateKeyPackage alice1 def)) + kps2 <- replicateM 2 (fmap fst (generateKeyPackage alice1 suite)) + + void $ replaceKeyPackages alice1 [def, suite] (kps1 <> kps2) >>= getBody 201 + + checkCount def 2 + checkCount suite 2 diff --git a/libs/wire-api/src/Wire/API/MLS/CipherSuite.hs b/libs/wire-api/src/Wire/API/MLS/CipherSuite.hs index 1f358b58e61..7c51932a439 100644 --- a/libs/wire-api/src/Wire/API/MLS/CipherSuite.hs +++ b/libs/wire-api/src/Wire/API/MLS/CipherSuite.hs @@ -42,6 +42,7 @@ module Wire.API.MLS.CipherSuite where import Cassandra.CQL +import Control.Applicative import Control.Error (note) import Control.Lens ((?~)) import Crypto.Error @@ -51,18 +52,20 @@ import Crypto.PubKey.Ed25519 qualified as Ed25519 import Data.Aeson qualified as Aeson import Data.Aeson.Types (FromJSON (..), FromJSONKey (..), ToJSON (..), ToJSONKey (..)) import Data.Aeson.Types qualified as Aeson +import Data.Attoparsec.ByteString.Char8 qualified as Atto import Data.Bifunctor import Data.ByteArray hiding (index) import Data.ByteArray qualified as BA +import Data.ByteString.Conversion import Data.OpenApi qualified as S import Data.OpenApi.Internal.Schema qualified as S import Data.Proxy import Data.Schema import Data.Text qualified as T +import Data.Text.Encoding qualified as T import Data.Text.Lazy qualified as LT import Data.Text.Lazy.Builder qualified as LT import Data.Text.Lazy.Builder.Int qualified as LT -import Data.Text.Read qualified as T import Data.Word import Imports hiding (cs) import Web.HttpApiData @@ -85,11 +88,8 @@ instance S.ToParamSchema CipherSuite where & S.type_ ?~ S.OpenApiNumber instance FromHttpApiData CipherSuite where - parseUrlPiece t = do - (x, rest) <- first T.pack $ T.hexadecimal t - unless (T.null rest) $ - Left "Trailing characters after ciphersuite number" - pure (CipherSuite x) + parseUrlPiece = parseHeader . T.encodeUtf8 + parseHeader = first T.pack . runParser parser instance ToHttpApiData CipherSuite where toUrlPiece = @@ -99,6 +99,11 @@ instance ToHttpApiData CipherSuite where . LT.hexadecimal . cipherSuiteNumber +instance FromByteString CipherSuite where + parser = do + void $ Atto.try (optional (Atto.string "0x")) + CipherSuite <$> Atto.hexadecimal + data CipherSuiteTag = MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 | MLS_128_X25519Kyber768Draft00_AES128GCM_SHA256_Ed25519 diff --git a/libs/wire-api/src/Wire/API/MLS/KeyPackage.hs b/libs/wire-api/src/Wire/API/MLS/KeyPackage.hs index 1402ff17b9a..da2855013e8 100644 --- a/libs/wire-api/src/Wire/API/MLS/KeyPackage.hs +++ b/libs/wire-api/src/Wire/API/MLS/KeyPackage.hs @@ -62,7 +62,8 @@ import Wire.API.MLS.Serialisation import Wire.Arbitrary data KeyPackageUpload = KeyPackageUpload - {keyPackages :: [RawMLS KeyPackage]} + { keyPackages :: [RawMLS KeyPackage] + } deriving (FromJSON, ToJSON, S.ToSchema) via Schema KeyPackageUpload instance ToSchema KeyPackageUpload where diff --git a/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs b/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs index 6e780914b7c..8a09807bc5c 100644 --- a/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs +++ b/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs @@ -1186,6 +1186,15 @@ type CipherSuiteParam = "ciphersuite" CipherSuite +type MultipleCipherSuitesParam = + QueryParam' + [ Optional, + Strict, + Description "Comma-separated list of ciphersuites in hex format (e.g. 0xf031) - default is 0x0001" + ] + "ciphersuites" + (CommaSeparatedList CipherSuite) + type MLSKeyPackageAPI = "key-packages" :> ( Named @@ -1201,6 +1210,20 @@ type MLSKeyPackageAPI = :> ReqBody '[JSON] KeyPackageUpload :> MultiVerb 'POST '[JSON, MLS] '[RespondEmpty 201 "Key packages uploaded"] () ) + :<|> Named + "mls-key-packages-replace" + ( "self" + :> Summary "Upload a fresh batch of key packages and replace the old ones" + :> From 'V5 + :> Description "The request body should be a json object containing a list of base64-encoded key packages. Use this sparingly." + :> ZLocalUser + :> CanThrow 'MLSProtocolError + :> CanThrow 'MLSIdentityMismatch + :> CaptureClientId "client" + :> MultipleCipherSuitesParam + :> ReqBody '[JSON] KeyPackageUpload + :> MultiVerb 'PUT '[JSON, MLS] '[RespondEmpty 201 "Key packages replaced"] () + ) :<|> Named "mls-key-packages-claim" ( "claim" diff --git a/services/brig/src/Brig/API/MLS/CipherSuite.hs b/services/brig/src/Brig/API/MLS/CipherSuite.hs index ec6b9756787..da8182c0a41 100644 --- a/services/brig/src/Brig/API/MLS/CipherSuite.hs +++ b/services/brig/src/Brig/API/MLS/CipherSuite.hs @@ -15,15 +15,22 @@ -- You should have received a copy of the GNU Affero General Public License along -- with this program. If not, see . -module Brig.API.MLS.CipherSuite (getCipherSuite) where +module Brig.API.MLS.CipherSuite (getCipherSuite, getCipherSuites) where import Brig.API.Handler import Brig.API.MLS.KeyPackages.Validation import Imports import Wire.API.MLS.CipherSuite +getOneCipherSuite :: CipherSuite -> Handler r CipherSuiteTag +getOneCipherSuite s = + maybe + (mlsProtocolError "Unknown ciphersuite") + pure + (cipherSuiteTag s) + getCipherSuite :: Maybe CipherSuite -> Handler r CipherSuiteTag -getCipherSuite mSuite = case mSuite of - Nothing -> pure defCipherSuite - Just x -> - maybe (mlsProtocolError "Unknown ciphersuite") pure (cipherSuiteTag x) +getCipherSuite = maybe (pure defCipherSuite) getOneCipherSuite + +getCipherSuites :: Maybe [CipherSuite] -> Handler r [CipherSuiteTag] +getCipherSuites = maybe (pure [defCipherSuite]) (traverse getOneCipherSuite) diff --git a/services/brig/src/Brig/API/MLS/KeyPackages.hs b/services/brig/src/Brig/API/MLS/KeyPackages.hs index 4a3c244b356..35d1edba025 100644 --- a/services/brig/src/Brig/API/MLS/KeyPackages.hs +++ b/services/brig/src/Brig/API/MLS/KeyPackages.hs @@ -21,6 +21,7 @@ module Brig.API.MLS.KeyPackages claimLocalKeyPackages, countKeyPackages, deleteKeyPackages, + replaceKeyPackages, ) where @@ -37,6 +38,7 @@ import Brig.Federation.Client import Brig.IO.Intra import Control.Monad.Trans.Except import Control.Monad.Trans.Maybe +import Data.CommaSeparatedList import Data.Id import Data.Qualified import Data.Set qualified as Set @@ -157,3 +159,15 @@ deleteKeyPackages lusr c mSuite (unDeleteKeyPackages -> refs) = do assertMLSEnabled suite <- getCipherSuite mSuite lift $ wrapClient (Data.deleteKeyPackages (tUnqualified lusr) c suite refs) + +replaceKeyPackages :: + Local UserId -> + ClientId -> + Maybe (CommaSeparatedList CipherSuite) -> + KeyPackageUpload -> + Handler r () +replaceKeyPackages lusr c (fmap toList -> mSuites) upload = do + assertMLSEnabled + suites <- getCipherSuites mSuites + lift $ wrapClient (Data.deleteAllKeyPackages (tUnqualified lusr) c suites) + uploadKeyPackages lusr c upload diff --git a/services/brig/src/Brig/API/Public.hs b/services/brig/src/Brig/API/Public.hs index 2ce4307aecc..623792268ce 100644 --- a/services/brig/src/Brig/API/Public.hs +++ b/services/brig/src/Brig/API/Public.hs @@ -373,6 +373,7 @@ servantSitemap = mlsAPI :: ServerT MLSAPI (Handler r) mlsAPI = Named @"mls-key-packages-upload" uploadKeyPackages + :<|> Named @"mls-key-packages-replace" replaceKeyPackages :<|> Named @"mls-key-packages-claim" claimKeyPackages :<|> Named @"mls-key-packages-count" countKeyPackages :<|> Named @"mls-key-packages-delete" deleteKeyPackages diff --git a/services/brig/src/Brig/Data/MLS/KeyPackage.hs b/services/brig/src/Brig/Data/MLS/KeyPackage.hs index a03192f32e6..7aaccc4d16c 100644 --- a/services/brig/src/Brig/Data/MLS/KeyPackage.hs +++ b/services/brig/src/Brig/Data/MLS/KeyPackage.hs @@ -20,6 +20,7 @@ module Brig.Data.MLS.KeyPackage claimKeyPackage, countKeyPackages, deleteKeyPackages, + deleteAllKeyPackages, ) where @@ -37,6 +38,7 @@ import Data.Qualified import Data.Time.Clock import Data.Time.Clock.POSIX import Imports +import UnliftIO.Async import Wire.API.MLS.CipherSuite import Wire.API.MLS.KeyPackage import Wire.API.MLS.LeafNode @@ -142,6 +144,22 @@ deleteKeyPackages u c suite refs = deleteQuery :: PrepQuery W (UserId, ClientId, CipherSuiteTag, [KeyPackageRef]) () deleteQuery = "DELETE FROM mls_key_packages WHERE user = ? AND client = ? AND cipher_suite = ? AND ref in ?" +deleteAllKeyPackages :: + (MonadClient m, MonadUnliftIO m) => + UserId -> + ClientId -> + [CipherSuiteTag] -> + m () +deleteAllKeyPackages u c suites = + pooledForConcurrentlyN_ 16 suites $ \suite -> + retry x5 $ + write + deleteQuery + (params LocalQuorum (u, c, suite)) + where + deleteQuery :: PrepQuery W (UserId, ClientId, CipherSuiteTag) () + deleteQuery = "DELETE FROM mls_key_packages WHERE user = ? AND client = ? AND cipher_suite = ?" + -------------------------------------------------------------------------------- -- Utilities