Skip to content

Commit

Permalink
Use Word64 to represent a ClientId (#3713)
Browse files Browse the repository at this point in the history
* Use Word64 to represent a ClientId

* Rename client to clientToText

* Regenerate nix package

* Add openapi documentation for ClientId

* Fix golden tests

* Fix ClientId instances

* Preserve previous ClientId generation

* Add CHANGELOG entry

* Fix bound check in ClientId parser

* Document client ID generation
  • Loading branch information
pcapriotti authored Nov 16, 2023
1 parent acf8cc6 commit 0ea69ca
Show file tree
Hide file tree
Showing 66 changed files with 524 additions and 493 deletions.
1 change: 1 addition & 0 deletions changelog.d/5-internal/client-id-as-uint
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Represent client IDs as Word64 internally
63 changes: 44 additions & 19 deletions libs/types-common/src/Data/Id.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ module Data.Id

-- * Client IDs
ClientId (..),
newClientId,
clientToText,

-- * Other IDs
ConnId (..),
Expand All @@ -65,6 +65,7 @@ import Data.Attoparsec.ByteString ((<?>))
import Data.Attoparsec.ByteString.Char8 qualified as Atto
import Data.Bifunctor (first)
import Data.Binary
import Data.Binary.Builder qualified as Builder
import Data.ByteString.Builder (byteString)
import Data.ByteString.Char8 qualified as B8
import Data.ByteString.Conversion
Expand Down Expand Up @@ -308,39 +309,63 @@ instance Arbitrary ConnId where
-- only together with a 'UserId', stored in C*, and used as a handle for end-to-end encryption. It
-- lives as long as the device is registered. See also: 'ConnId'.
newtype ClientId = ClientId
{ client :: Text
{ clientToWord64 :: Word64
}
deriving (Eq, Ord, Show, ToByteString, Hashable, NFData, A.ToJSONKey, Generic)
deriving newtype (ToParamSchema, FromHttpApiData, ToHttpApiData, Binary)
deriving (Eq, Ord, Show)
deriving (FromJSON, ToJSON, S.ToSchema) via Schema ClientId

instance ToSchema ClientId where
schema = client .= parsedText "ClientId" clientIdFromByteString
instance ToParamSchema ClientId where
toParamSchema _ = toParamSchema (Proxy @Text)

instance FromHttpApiData ClientId where
parseUrlPiece = first T.pack . runParser parser . encodeUtf8

instance ToHttpApiData ClientId where
toUrlPiece = clientToText

clientToText :: ClientId -> Text
clientToText = toStrict . toLazyText . hexadecimal . clientToWord64

newClientId :: Word64 -> ClientId
newClientId = ClientId . toStrict . toLazyText . hexadecimal
instance ToSchema ClientId where
schema = withParser s parseClientId
where
s :: ValueSchemaP NamedSwaggerDoc ClientId Text
s =
clientToText .= schema
& doc . S.description
?~ "A 64-bit unsigned integer, represented as a hexadecimal numeral. \
\Any valid hexadecimal numeral is accepted, but the backend will only \
\produce representations with lowercase digits and no leading zeros"

clientIdFromByteString :: Text -> Either String ClientId
clientIdFromByteString txt =
if T.length txt <= 20 && T.all isHexDigit txt
then Right $ ClientId txt
else Left "Invalid ClientId"
parseClientId :: Text -> A.Parser ClientId
parseClientId = either fail pure . runParser parser . encodeUtf8

instance FromByteString ClientId where
parser = do
bs <- Atto.takeByteString
either fail pure $ clientIdFromByteString (cs bs)
num :: Integer <- Atto.hexadecimal
guard $ num <= fromIntegral (maxBound :: Word64)
pure (ClientId (fromIntegral num))

instance ToByteString ClientId where
builder = Builder.fromByteString . encodeUtf8 . clientToText

instance A.FromJSONKey ClientId where
fromJSONKey = A.FromJSONKeyTextParser $ either fail pure . clientIdFromByteString
fromJSONKey = A.FromJSONKeyTextParser parseClientId

instance A.ToJSONKey ClientId where
toJSONKey = A.toJSONKeyText clientToText

deriving instance Cql ClientId
instance Cql ClientId where
ctype = Tagged TextColumn
toCql = CqlText . clientToText
fromCql (CqlText t) = runParser parser (encodeUtf8 t)
fromCql _ = Left "ClientId: expected CqlText"

instance Arbitrary ClientId where
arbitrary = newClientId <$> arbitrary
arbitrary = ClientId <$> arbitrary

instance EncodeWire ClientId where
encodeWire t = encodeWire t . client
encodeWire t = encodeWire t . clientToText

instance DecodeWire ClientId where
decodeWire (DelimitedField _ x) = either fail pure (runParser parser x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ missing =
[ ( Domain "golden.example.com",
fromList
[ ( Id (fromJust (UUID.fromString "00000000-0000-0000-0000-000100000002")),
fromList [ClientId {client = "0"}, ClientId {client = "1"}]
fromList [ClientId 0, ClientId 1]
),
( Id (fromJust (UUID.fromString "00000000-0000-0000-0000-000200000000")),
fromList [ClientId {client = "0"}]
fromList [ClientId 0]
)
]
)
Expand All @@ -54,10 +54,10 @@ redundant =
[ ( Domain "golden.example.com",
fromList
[ ( Id (fromJust (UUID.fromString "00000000-0000-0000-0000-000100000003")),
fromList [ClientId {client = "0"}, ClientId {client = "1"}]
fromList [ClientId 0, ClientId 1]
),
( Id (fromJust (UUID.fromString "00000000-0000-0000-0000-000200000004")),
fromList [ClientId {client = "0"}]
fromList [ClientId 0]
)
]
)
Expand All @@ -72,10 +72,10 @@ deleted =
[ ( Domain "golden.example.com",
fromList
[ ( Id (fromJust (UUID.fromString "00000000-0000-0000-0000-000100000005")),
fromList [ClientId {client = "0"}, ClientId {client = "1"}]
fromList [ClientId 0, ClientId 1]
),
( Id (fromJust (UUID.fromString "00000000-0000-0000-0000-000200000006")),
fromList [ClientId {client = "0"}]
fromList [ClientId 0]
)
]
)
Expand All @@ -90,10 +90,10 @@ failed =
[ ( Domain "golden.example.com",
fromList
[ ( Id (fromJust (UUID.fromString "00000000-0000-0000-0000-000100000007")),
fromList [ClientId {client = "0"}, ClientId {client = "1"}]
fromList [ClientId 0, ClientId 1]
),
( Id (fromJust (UUID.fromString "00000000-0000-0000-0000-000200000008")),
fromList [ClientId {client = "0"}]
fromList [ClientId 0]
)
]
)
Expand All @@ -108,10 +108,10 @@ failedToConfirm =
[ ( Domain "golden.example.com",
fromList
[ ( Id (fromJust (UUID.fromString "00000000-0000-0000-0000-000100000009")),
fromList [ClientId {client = "0"}, ClientId {client = "1"}]
fromList [ClientId 0, ClientId 1]
),
( Id (fromJust (UUID.fromString "00000000-0000-0000-0000-000200000010")),
fromList [ClientId {client = "0"}]
fromList [ClientId 0]
)
]
)
Expand Down
8 changes: 4 additions & 4 deletions libs/wire-api/src/Wire/API/MLS/Credential.hs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ instance Show ClientIdentity where
show (ClientIdentity dom u c) =
show u
<> ":"
<> T.unpack (client c)
<> T.unpack (clientToText c)
<> "@"
<> T.unpack (domainText dom)

Expand Down Expand Up @@ -129,7 +129,7 @@ instance ParseMLS ClientIdentity where
uid <-
maybe (fail "Invalid UUID") (pure . Id) . fromASCIIBytes =<< getByteString 36
char ':'
cid <- newClientId <$> hexadecimal
cid <- ClientId <$> hexadecimal
char '@'
dom <-
either fail pure . (mkDomain . T.pack) =<< many' anyChar
Expand All @@ -141,7 +141,7 @@ parseX509ClientIdentity = do
uidBytes <- either fail pure $ B64URL.decodeUnpadded b64uuid
uid <- maybe (fail "Invalid UUID") (pure . Id) $ fromByteString (L.fromStrict uidBytes)
char '/'
cid <- newClientId <$> hexadecimal
cid <- ClientId <$> hexadecimal
char '@'
dom <-
either fail pure . (mkDomain . T.pack) =<< many' anyChar
Expand All @@ -151,7 +151,7 @@ instance SerialiseMLS ClientIdentity where
serialiseMLS cid = do
putByteString $ toASCIIBytes (toUUID (ciUser cid))
putCharUtf8 ':'
putStringUtf8 $ T.unpack (client (ciClient cid))
putStringUtf8 $ T.unpack (clientToText (ciClient cid))
putCharUtf8 '@'
putStringUtf8 $ T.unpack (domainText (ciDomain cid))

Expand Down
4 changes: 2 additions & 2 deletions libs/wire-api/src/Wire/API/Message.hs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ protolensToQualifiedNewOtrMessage protoMsg = do
}

protolensToClientId :: Proto.Otr.ClientId -> ClientId
protolensToClientId = newClientId . view Proto.Otr.client
protolensToClientId = ClientId . view Proto.Otr.client

qualifiedNewOtrMessageToProto :: QualifiedNewOtrMessage -> Proto.Otr.QualifiedNewOtrMessage
qualifiedNewOtrMessageToProto msg =
Expand Down Expand Up @@ -276,7 +276,7 @@ mkQualifiedOtrPayload sender entries dat strat =
clientIdToProtolens :: ClientId -> Proto.Otr.ClientId
clientIdToProtolens cid =
ProtoLens.defMessage
& Proto.Otr.client .~ (either error fst . Reader.hexadecimal $ client cid)
& Proto.Otr.client .~ (either error fst . Reader.hexadecimal $ clientToText cid)

--------------------------------------------------------------------------------
-- Priority
Expand Down
10 changes: 2 additions & 8 deletions libs/wire-api/src/Wire/API/Message/Proto.hs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ where

import Data.Id qualified as Id
import Data.ProtocolBuffers
import Data.Text.Lazy qualified as Text
import Data.Text.Lazy.Read (hexadecimal)
import Imports

--------------------------------------------------------------------------------
Expand Down Expand Up @@ -92,14 +90,10 @@ clientId :: Functor f => (Word64 -> f Word64) -> ClientId -> f ClientId
clientId f c = (\x -> c {_client = x}) <$> field f (_client c)

toClientId :: ClientId -> Id.ClientId
toClientId c = Id.newClientId $ getField (_client c)
toClientId c = Id.ClientId $ getField (_client c)

fromClientId :: Id.ClientId -> ClientId
fromClientId c =
either
(error "Invalid client ID")
(newClientId . fst)
(hexadecimal (Text.fromStrict $ Id.client c))
fromClientId = newClientId . Id.clientToWord64

--------------------------------------------------------------------------------
-- ClientEntry
Expand Down
6 changes: 3 additions & 3 deletions libs/wire-api/src/Wire/API/User/Client.hs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ instance ToSchema UserClientPrekeyMap where
( Map.singleton
(generateExample @UserId)
( Map.singleton
(newClientId 4940483633899001999)
(ClientId 4940483633899001999)
(Just (Prekey (PrekeyId 1) "pQABAQECoQBYIOjl7hw0D8YRNq..."))
)
)
Expand Down Expand Up @@ -415,8 +415,8 @@ instance ToSchema UserClients where
& Swagger.schema . Swagger.example
?~ toJSON
( Map.fromList
[ (generateExample @UserId, [newClientId 1684636986166846496, newClientId 4940483633899001999]),
(generateExample @UserId, [newClientId 6987438498444556166, newClientId 7940473633839002939])
[ (generateExample @UserId, [ClientId 1684636986166846496, ClientId 4940483633899001999]),
(generateExample @UserId, [ClientId 6987438498444556166, ClientId 7940473633839002939])
]
)

Expand Down
15 changes: 9 additions & 6 deletions libs/wire-api/src/Wire/API/User/Client/Prekey.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,13 @@ where

import Crypto.Hash (SHA256, hash)
import Data.Aeson (FromJSON (..), ToJSON (..))
import Data.Bits
import Data.ByteArray (convert)
import Data.ByteString qualified as BS
import Data.ByteString.Conversion (toByteString')
import Data.Id
import Data.OpenApi qualified as S
import Data.Schema
import Data.Text.Ascii (encodeBase16)
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Data.Text.Encoding (encodeUtf8)
import Imports
import Wire.Arbitrary (Arbitrary (arbitrary), GenericUniform (..))

Expand All @@ -67,12 +66,16 @@ instance ToSchema Prekey where
<$> prekeyId .= field "id" schema
<*> prekeyKey .= field "key" schema

-- | Construct a new client ID from a prekey.
--
-- This works by taking the SHA256 hash of the prekey, truncating it to its
-- first 8 bytes, and interpreting the resulting bytestring as a big endian
-- Word64.
clientIdFromPrekey :: Prekey -> ClientId
clientIdFromPrekey =
ClientId
. decodeUtf8
. toByteString'
. encodeBase16
. foldl' (\w d -> (w `shiftL` 8) .|. fromIntegral d) 0
. BS.unpack
. BS.take 8
. convert
. hash @ByteString @SHA256
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
module Test.Wire.API.Golden.Generated.AddBotResponse_user where

import Data.Domain
import Data.Id (BotId (BotId), ClientId (ClientId, client), Id (Id))
import Data.Id
import Data.Qualified
import Data.UUID qualified as UUID (fromString)
import Imports (Maybe (Just, Nothing), fromJust, read, (.))
Expand All @@ -35,7 +35,7 @@ testObject_AddBotResponse_user_1 :: AddBotResponse
testObject_AddBotResponse_user_1 =
AddBotResponse
{ rsAddBotId = (BotId . Id) (fromJust (UUID.fromString "00000003-0000-0004-0000-000300000001")),
rsAddBotClient = ClientId {client = "e"},
rsAddBotClient = ClientId 0xe,
rsAddBotName =
Name
{ fromName =
Expand Down Expand Up @@ -63,7 +63,7 @@ testObject_AddBotResponse_user_2 :: AddBotResponse
testObject_AddBotResponse_user_2 =
AddBotResponse
{ rsAddBotId = (BotId . Id) (fromJust (UUID.fromString "00000001-0000-0003-0000-000200000004")),
rsAddBotClient = ClientId {client = "e"},
rsAddBotClient = ClientId 0xe,
rsAddBotName =
Name
{ fromName =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

module Test.Wire.API.Golden.Generated.ClientMismatch_user where

import Data.Id (ClientId (ClientId, client), Id (Id))
import Data.Id
import Data.Json.Util (toUTCTimeMillis)
import Data.UUID qualified as UUID (fromString)
import GHC.Exts (IsList (fromList))
Expand All @@ -34,7 +34,7 @@ testObject_ClientMismatch_user_1 =
{ userClients =
fromList
[ ( Id (fromJust (UUID.fromString "00000000-0000-0000-0000-000100000002")),
fromList [ClientId {client = "0"}, ClientId {client = "1"}]
fromList [ClientId 0, ClientId 1]
),
( Id (fromJust (UUID.fromString "00000000-0000-0000-0000-000200000000")),
fromList []
Expand All @@ -47,10 +47,10 @@ testObject_ClientMismatch_user_1 =
{ userClients =
fromList
[ ( Id (fromJust (UUID.fromString "00000004-0000-0004-0000-000700000000")),
fromList [ClientId {client = "0"}, ClientId {client = "1"}]
fromList [ClientId 0, ClientId 1]
),
( Id (fromJust (UUID.fromString "00000005-0000-0000-0000-000600000008")),
fromList [ClientId {client = "0"}, ClientId {client = "1"}]
fromList [ClientId 0, ClientId 1]
)
]
}
Expand Down
Loading

0 comments on commit 0ea69ca

Please sign in to comment.