Skip to content

Commit

Permalink
Add support for mempack
Browse files Browse the repository at this point in the history
  • Loading branch information
lehins committed Oct 25, 2024
1 parent 61b686c commit 0f1cec6
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 52 deletions.
3 changes: 2 additions & 1 deletion cardano-crypto-class/cardano-crypto-class.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,10 @@ library
, heapwords
, io-classes >= 1.4.1
, memory
, mempack
, mtl
, nothunks
, primitive
, primitive >= 0.8
, serialise
, template-haskell
, th-compat
Expand Down
4 changes: 4 additions & 0 deletions cardano-crypto-class/src/Cardano/Crypto/Hash/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
Expand Down Expand Up @@ -66,6 +67,7 @@ import qualified Data.ByteString as BS
import qualified Data.ByteString.Base16 as Base16
import qualified Data.ByteString.Char8 as BSC
import Data.ByteString.Short (ShortByteString)
import Data.MemPack (StateT(StateT), FailT(FailT), MemPack, Unpack(Unpack))
import Data.Word (Word8)
import Numeric.Natural (Natural)

Expand Down Expand Up @@ -110,6 +112,8 @@ sizeHash _ = fromInteger (natVal (Proxy @(SizeHash h)))
newtype Hash h a = UnsafeHashRep (PackedBytes (SizeHash h))
deriving (Eq, Ord, Generic, NoThunks, NFData)

deriving instance HashAlgorithm h => MemPack (Hash h a)

-- | This instance is meant to be used with @TemplateHaskell@
--
-- >>> import Cardano.Crypto.Hash.Class (Hash)
Expand Down
134 changes: 83 additions & 51 deletions cardano-crypto-class/src/Cardano/Crypto/PackedBytes.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,22 @@ import Codec.Serialise.Encoding (encodeBytes)
import Control.DeepSeq
import Control.Monad (guard)
import Control.Monad.Primitive
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Bits
import Data.ByteString
import Data.ByteString.Internal as BS (accursedUnutterablePerformIO,
fromForeignPtr, toForeignPtr)
import Data.ByteString.Short.Internal as SBS
import Data.MemPack (guardAdvanceUnpack, st_, MemPack(..), Pack(Pack))
import Data.MemPack.Buffer (Buffer(buffer), byteArrayToShortByteString, pinnedByteArrayToForeignPtr)
import Data.Primitive.ByteArray
import Data.Primitive.PrimArray (PrimArray(..), imapPrimArray, indexPrimArray)
import Data.Typeable
import Foreign.ForeignPtr
import Foreign.Ptr (castPtr)
import Foreign.Storable (peekByteOff)
import GHC.Exts
import GHC.ForeignPtr (ForeignPtr(ForeignPtr), ForeignPtrContents(PlainPtr))
#if MIN_VERSION_base(4,15,0)
import GHC.ForeignPtr (unsafeWithForeignPtr)
#endif
Expand Down Expand Up @@ -92,7 +95,38 @@ instance NFData (PackedBytes n) where
rnf PackedBytes32 {} = ()
rnf PackedBytes# {} = ()

instance Serialise (PackedBytes n) where
instance KnownNat n => MemPack (PackedBytes n) where
packedByteCount = fromIntegral @Integer @Int . natVal
{-# INLINE packedByteCount #-}
packM pb = do
let !len@(I# len#) = packedByteCount pb
i@(I# i#) <- state $ \i -> (i, i + len)
mba@(MutableByteArray mba#) <- ask
Pack $ \_ -> lift $ case pb of
PackedBytes8 w -> writeWord64BE mba i w
PackedBytes28 w0 w1 w2 w3 -> do
writeWord64BE mba i w0
writeWord64BE mba (i + 8) w1
writeWord64BE mba (i + 16) w2
writeWord32BE mba (i + 24) w3
PackedBytes32 w0 w1 w2 w3 -> do
writeWord64BE mba i w0
writeWord64BE mba (i + 8) w1
writeWord64BE mba (i + 16) w2
writeWord64BE mba (i + 24) w3
PackedBytes# ba# ->
st_ (copyByteArray# ba# 0# mba# i# len#)
{-# INLINE packM #-}
unpackM = do
let !len = fromIntegral @Integer @Int $ natVal' (proxy# :: Proxy# n)
curPos@(I# curPos#) <- guardAdvanceUnpack len
buf <- ask
pure $! buffer buf
(\ba# -> packBytes (SBS.SBS ba#) curPos)
(\addr# -> accursedUnutterablePerformIO $ packPinnedPtr (Ptr (addr# `plusAddr#` curPos#)))
{-# INLINE unpackM #-}

instance KnownNat n => Serialise (PackedBytes n) where
encode = encodeBytes . unpackPinnedBytes
decode = packPinnedBytesN <$> decodeBytes

Expand Down Expand Up @@ -221,53 +255,60 @@ packBytesMaybe bs offset = do
Just $ packBytes bs offset


packPinnedBytes8 :: ByteString -> PackedBytes 8
packPinnedBytes8 bs = unsafeWithByteStringPtr bs (fmap PackedBytes8 . (`peekWord64BE` 0))
{-# INLINE packPinnedBytes8 #-}

packPinnedBytes28 :: ByteString -> PackedBytes 28
packPinnedBytes28 bs =
unsafeWithByteStringPtr bs $ \ptr ->
PackedBytes28
<$> peekWord64BE ptr 0
<*> peekWord64BE ptr 8
<*> peekWord64BE ptr 16
<*> peekWord32BE ptr 24
{-# INLINE packPinnedBytes28 #-}

packPinnedBytes32 :: ByteString -> PackedBytes 32
packPinnedBytes32 bs =
unsafeWithByteStringPtr bs $ \ptr -> PackedBytes32 <$> peekWord64BE ptr 0
<*> peekWord64BE ptr 8
<*> peekWord64BE ptr 16
<*> peekWord64BE ptr 24
{-# INLINE packPinnedBytes32 #-}

packPinnedBytesN :: ByteString -> PackedBytes n
packPinnedBytesN bs =
case toShort bs of
SBS ba# -> PackedBytes# ba#
{-# INLINE packPinnedBytesN #-}
packPinnedPtr8 :: Ptr a -> IO (PackedBytes 8)
packPinnedPtr8 = fmap PackedBytes8 . (`peekWord64BE` 0)
{-# INLINE packPinnedPtr8 #-}

packPinnedPtr28 :: Ptr a -> IO (PackedBytes 28)
packPinnedPtr28 ptr =
PackedBytes28
<$> peekWord64BE ptr 0
<*> peekWord64BE ptr 8
<*> peekWord64BE ptr 16
<*> peekWord32BE ptr 24
{-# INLINE packPinnedPtr28 #-}

packPinnedPtr32 :: Ptr a -> IO (PackedBytes 32)
packPinnedPtr32 ptr =
PackedBytes32 <$> peekWord64BE ptr 0
<*> peekWord64BE ptr 8
<*> peekWord64BE ptr 16
<*> peekWord64BE ptr 24
{-# INLINE packPinnedPtr32 #-}

packPinnedPtrN :: forall n a. KnownNat n => Ptr a -> IO (PackedBytes n)
packPinnedPtrN (Ptr addr#) = pure $! PackedBytes# ba#
where
!(ByteArray ba#) = withMutableByteArray len $ \(MutableByteArray mba#) ->
st_ (copyAddrToByteArray# addr# mba# 0# len#)
!len@(I# len#) = fromIntegral @Integer @Int (natVal' (proxy# :: Proxy# n))
{-# INLINE packPinnedPtrN #-}

packPinnedBytesN :: KnownNat n => ByteString -> PackedBytes n
packPinnedBytesN bs = unsafeWithByteStringPtr bs packPinnedPtrN
{-# INLINE packPinnedBytesN #-}

packPinnedBytes :: forall n . KnownNat n => ByteString -> PackedBytes n
packPinnedBytes bs =
packPinnedPtr :: forall n a. KnownNat n => Ptr a -> IO (PackedBytes n)
packPinnedPtr bs =
let px = Proxy :: Proxy n
in case sameNat px (Proxy :: Proxy 8) of
Just Refl -> packPinnedBytes8 bs
Just Refl -> packPinnedPtr8 bs
Nothing -> case sameNat px (Proxy :: Proxy 28) of
Just Refl -> packPinnedBytes28 bs
Just Refl -> packPinnedPtr28 bs
Nothing -> case sameNat px (Proxy :: Proxy 32) of
Just Refl -> packPinnedBytes32 bs
Nothing -> packPinnedBytesN bs
{-# INLINE[1] packPinnedBytes #-}

Just Refl -> packPinnedPtr32 bs
Nothing -> packPinnedPtrN bs
{-# INLINE[1] packPinnedPtr #-}
{-# RULES
"packPinnedBytes8" packPinnedBytes = packPinnedBytes8
"packPinnedBytes28" packPinnedBytes = packPinnedBytes28
"packPinnedBytes32" packPinnedBytes = packPinnedBytes32
"packPinnedPtr8" packPinnedPtr = packPinnedPtr8
"packPinnedPtr28" packPinnedPtr = packPinnedPtr28
"packPinnedPtr32" packPinnedPtr = packPinnedPtr32
#-}

packPinnedBytes :: forall n . KnownNat n => ByteString -> PackedBytes n
packPinnedBytes bs = unsafeWithByteStringPtr bs packPinnedPtr
{-# INLINE packPinnedBytes #-}


--- Primitive architecture agnostic helpers

Expand Down Expand Up @@ -358,22 +399,13 @@ writeWord32BE (MutableByteArray mba#) (I# i#) w =
#endif
{-# INLINE writeWord32BE #-}

byteArrayToShortByteString :: ByteArray -> ShortByteString
byteArrayToShortByteString (ByteArray ba#) = SBS ba#
{-# INLINE byteArrayToShortByteString #-}

byteArrayToByteString :: ByteArray -> ByteString
byteArrayToByteString ba
byteArrayToByteString ba@(ByteArray ba#)
| isByteArrayPinned ba =
BS.fromForeignPtr (pinnedByteArrayToForeignPtr ba) 0 (sizeofByteArray ba)
BS.fromForeignPtr (pinnedByteArrayToForeignPtr ba#) 0 (sizeofByteArray ba)
| otherwise = SBS.fromShort (byteArrayToShortByteString ba)
{-# INLINE byteArrayToByteString #-}

pinnedByteArrayToForeignPtr :: ByteArray -> ForeignPtr a
pinnedByteArrayToForeignPtr (ByteArray ba#) =
ForeignPtr (byteArrayContents# ba#) (PlainPtr (unsafeCoerce# ba#))
{-# INLINE pinnedByteArrayToForeignPtr #-}

-- Usage of `accursedUnutterablePerformIO` here is safe because we only use it
-- for indexing into an immutable `ByteString`, which is analogous to
-- `Data.ByteString.index`. Make sure you know what you are doing before using
Expand Down
1 change: 1 addition & 0 deletions cardano-crypto-tests/cardano-crypto-tests.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ library
, deepseq
, formatting
, io-classes >= 1.1
, mempack
, mtl
, nothunks
, pretty-show
Expand Down
7 changes: 7 additions & 0 deletions cardano-crypto-tests/src/Test/Crypto/Hash.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import qualified Data.Bits as Bits (xor)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Short as SBS
import Data.Maybe (fromJust)
import Data.MemPack
import Data.Proxy (Proxy(..))
import Data.String (fromString)
import GHC.TypeLits
Expand Down Expand Up @@ -63,9 +64,15 @@ testHashAlgorithm p =
, testProperty "hashFromStringAsHex/fromString" $ prop_hash_hashFromStringAsHex_fromString @h @Float
, testProperty "show/read" $ prop_hash_show_read @h @Float
, testProperty "NoThunks" $ prop_no_thunks @(Hash h Int)
, testProperty "MemPack RoundTrip" $ prop_MemPackRoundTrip @(Hash h Int)
]
where n = hashAlgorithmName p

prop_MemPackRoundTrip :: forall a. (MemPack a, Eq a, Show a) => a -> Property
prop_MemPackRoundTrip a =
unpackError (pack a) === a .&&.
unpackError (packByteString a) === a

testSodiumHashAlgorithm
:: forall proxy h. NaCl.SodiumHashAlgorithm h
=> Lock
Expand Down

0 comments on commit 0f1cec6

Please sign in to comment.