diff --git a/dmq-node/cddl/Main.hs b/dmq-node/cddl/Main.hs index 2e5ef24..77124d0 100644 --- a/dmq-node/cddl/Main.hs +++ b/dmq-node/cddl/Main.hs @@ -11,9 +11,7 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -Wno-orphans #-} @@ -37,6 +35,7 @@ import Data.Bool (bool) import Data.ByteString.Base16.Lazy qualified as BL.Base16 import Data.ByteString.Lazy qualified as BL import Data.ByteString.Lazy.Char8 qualified as BL.Char8 +import Data.Functor ((<&>)) import Text.Printf import System.Directory (doesDirectoryExist) @@ -55,12 +54,16 @@ import DMQ.Protocol.LocalMsgNotification.Codec import DMQ.Protocol.LocalMsgNotification.Type as LocalMsgNotification import DMQ.Protocol.SigSubmission.Codec import DMQ.Protocol.SigSubmission.Type +import DMQ.Protocol.SigSubmissionV2.Type (SigSubmissionV2) +import DMQ.Protocol.SigSubmissionV2.Type qualified as SigSubmissionV2 +import DMQ.Protocol.SigSubmissionV2.Test qualified as SigSubmissionV2.Test +import DMQ.Protocol.SigSubmissionV2.Codec.CDDL -- import Test.QuickCheck hiding (Result (..)) import Test.QuickCheck.Instances.ByteString () import Test.Tasty (TestTree, adjustOption, defaultMain, testGroup) import Test.Tasty.HUnit -import Test.Tasty.QuickCheck (QuickCheckMaxSize (..)) +import Test.Tasty.QuickCheck main :: IO () main = do @@ -69,7 +72,8 @@ main = do tests :: CDDLSpecs -> TestTree tests CDDLSpecs { cddlSig, - cddlLocalMsgNotification + cddlLocalMsgNotification, + cddlSigSubmissionV2 } = adjustOption (const $ QuickCheckMaxSize 10) $ testGroup "cddl" @@ -77,8 +81,11 @@ tests CDDLSpecs { cddlSig, -- validate decoder by generating messages from the specification [ testCase "Sig" (unit_decodeSig cddlSig) , testCase "LocalMsgNotification" (unit_decodeLocalMsgNotification cddlLocalMsgNotification) + , testCase "SigSubmissionV2" (unit_decodeSigSubmissionV2 cddlSigSubmissionV2) ] - -- TODO: validate `LocalMsgNotification` encoder + -- TODO: validate `LocalMsgNotification` encoder (this should be done in + -- `DMQ.Protocol.LocalMsgNotification.Test` module, see + -- `DMQ.Protocol.SigSubmissionV2.Test.prop_encoding` for an example) ] newtype CDDLSpec ps = CDDLSpec BL.ByteString @@ -87,13 +94,14 @@ type AnnSigRawWithSignedBytes = BL.ByteString -> SigRawWithSignedBytes StandardC data CDDLSpecs = CDDLSpecs { cddlSig :: CDDLSpec AnnSigRawWithSignedBytes, - cddlLocalMsgNotification :: CDDLSpec (LocalMsgNotification (Sig StandardCrypto)) + cddlLocalMsgNotification :: CDDLSpec (LocalMsgNotification (Sig StandardCrypto)), + cddlSigSubmissionV2 :: CDDLSpec (SigSubmissionV2 SigSubmissionV2.Test.SigId SigSubmissionV2.Test.Sig) } unit_decodeSig :: CDDLSpec AnnSigRawWithSignedBytes -> Assertion -unit_decodeSig spec = validateDecoder spec decodeSig 100 +unit_decodeSig spec = validateDecoder' spec decodeSig 100 unit_decodeLocalMsgNotification :: CDDLSpec (LocalMsgNotification (Sig StandardCrypto)) -> Assertion @@ -121,6 +129,21 @@ unit_decodeLocalMsgNotification spec = term -> term +unit_decodeSigSubmissionV2 + :: CDDLSpec (SigSubmissionV2 SigSubmissionV2.Test.SigId SigSubmissionV2.Test.Sig) + -> Assertion +unit_decodeSigSubmissionV2 spec = + validateDecoder + (Just indefiniteListFix) + spec + sigSubmissionV2Codec + [ SomeAgency $ SigSubmissionV2.SingSigIds SigSubmissionV2.SingBlocking + , SomeAgency $ SigSubmissionV2.SingSigIds SigSubmissionV2.SingNonBlocking + , SomeAgency SigSubmissionV2.SingSigs + , SomeAgency SigSubmissionV2.SingIdle + ] + 100 + -- -- utils -- @@ -133,9 +156,9 @@ unit_decodeLocalMsgNotification spec = -- The `CDDL_INCLUDE_PATH` environment variable must be set. cddlc :: FilePath -> IO BL.ByteString cddlc path = do - (exitCode, cddl, _) <- readProcessWithExitCode "cddlc" ["-u", "-2", "-t", "cddl", path] mempty + (exitCode, cddl, stderr) <- readProcessWithExitCode "cddlc" ["-u", "-2", "-t", "cddl", path] mempty unless (exitCode == ExitSuccess) $ - die $ printf "cddlc failed on \"%s\" with %s " path (show exitCode) + die $ printf "cddlc failed on \"%s\" with %s\n%s " path (show exitCode) (BL.Char8.unpack stderr) return cddl @@ -148,18 +171,72 @@ readCDDLSpecs = do sigSpec <- cddlc (dir "sig.cddl") localMessageNotificationSpec <- cddlc (dir "local-msg-notification.cddl") + sigSubmissionV2Spec <- cddlc (dir "sig-submission-v2.cddl") return CDDLSpecs { cddlSig = CDDLSpec sigSpec, - cddlLocalMsgNotification = CDDLSpec localMessageNotificationSpec + cddlLocalMsgNotification = CDDLSpec localMessageNotificationSpec, + cddlSigSubmissionV2 = CDDLSpec sigSubmissionV2Spec } +validateDecoder :: Maybe (CBOR.Term -> CBOR.Term) + -- ^ transform a generated term + -> CDDLSpec ps + -> Codec ps CBOR.DeserialiseFailure IO BL.ByteString + -> [SomeAgency ps] + -> Int + -> Assertion +validateDecoder transform (CDDLSpec spec) codec stoks rounds = do + eterms <- runExceptT $ generateCBORFromSpec spec rounds + case eterms of + Left err -> assertFailure err + Right terms -> + forM_ terms $ \(generated_term, encoded_term) -> do + let encoded_term' = case transform of + Nothing -> encoded_term + Just tr -> case CBOR.deserialiseFromBytes CBOR.decodeTerm encoded_term of + Right (rest, term) | BL.null rest + -> CBOR.toLazyByteString (CBOR.encodeTerm (tr term)) + Right _ -> error "validateDecoder: trailing bytes" + Left err -> error $ "validateDecoder: decoding error: " + ++ show err + Right (_, decoded_term) = + CBOR.deserialiseFromBytes CBOR.decodeTerm encoded_term' + res <- decodeMsg encoded_term' + case res of + Just errs -> assertFailure $ concat + [ "decoding failures:\n" + , unlines (map show errs) + , "while decoding:\n" + , show decoded_term + , "\n" + , "generated term:\n" + , BL.Char8.unpack generated_term + ] + Nothing -> return () + where + -- | Try decode at all given agencies. If one succeeds return + -- 'Nothing' otherwise return all 'DeserialiseFailure's. + -- + decodeMsg :: BL.ByteString + -> IO (Maybe [CBOR.DeserialiseFailure]) + decodeMsg bs = + -- sequence [Nothing, ...] = Nothing + fmap (sequence :: [Maybe CBOR.DeserialiseFailure] -> Maybe [CBOR.DeserialiseFailure]) $ + forM stoks $ \(SomeAgency (stok :: StateToken st)) -> do + decoder <- (decode codec stok :: IO (DecodeStep BL.ByteString CBOR.DeserialiseFailure IO (SomeMessage st))) + res <- runDecoder [bs] decoder + return $ case res of + Left err -> Just err + Right {} -> Nothing + + -validateDecoder :: CDDLSpec a +validateDecoder' :: CDDLSpec a -> (forall s. CBOR.Decoder s a) -> Int -> Assertion -validateDecoder (CDDLSpec spec) decoder rounds = do +validateDecoder' (CDDLSpec spec) decoder rounds = do eterms <- runExceptT $ generateCBORFromSpec spec rounds case eterms of Left err -> assertFailure err @@ -215,8 +292,8 @@ validateAnnotatedDecoder transform (CDDLSpec spec) codec stoks rounds = do Just tr -> case CBOR.deserialiseFromBytes CBOR.decodeTerm encoded_term of Right (rest, term) | BL.null rest -> CBOR.toLazyByteString (CBOR.encodeTerm (tr term)) - Right _ -> error "validateDecoder: trailing bytes" - Left err -> error $ "validateDecoder: decoding error: " + Right _ -> error "validateAnnotatedDecoder: trailing bytes" + Left err -> error $ "validateAnnotatedDecoder: decoding error: " ++ show err Right (_, decoded_term) = @@ -243,7 +320,7 @@ validateAnnotatedDecoder transform (CDDLSpec spec) codec stoks rounds = do decodeMsg bs = -- sequence [Nothing, ...] = Nothing fmap sequence $ - forM stoks $ \(a@(SomeAgency (stok :: StateToken st))) -> do + forM stoks $ \a@(SomeAgency (stok :: StateToken st)) -> do decoder <- decode codec stok res <- runDecoder [bs] decoder return $ case res of @@ -270,17 +347,28 @@ generateCBORFromSpec spec rounds = do . readProcessWithExitCode "diag2cbor.rb" ["-"] - unpackResult :: IO (ExitCode, BL.ByteString, BL.ByteString) - -> IO (Either String BL.ByteString) - unpackResult r = r >>= \case - (ExitFailure _, _, err) -> return (Left $ BL.Char8.unpack err) - (ExitSuccess, bytes, _) -> return (Right bytes) - - - withTemporaryFile :: BL.ByteString - -> (FilePath -> IO a) -> IO a - withTemporaryFile bs k = - withTempFile "." "tmp" $ - \fileName h -> BL.hPut h bs - >> hClose h - >> k fileName +-- | The cddl spec cannot differentiate between fix-length list encoding and +-- infinite-length encoding. The cddl tool always generates fix-length +-- encoding but tx-submission and object-diffusion codecs are accepting only +-- indefinite-length encoding. +-- +indefiniteListFix :: CBOR.Term -> CBOR.Term +indefiniteListFix term = + case term of + TList [TInt tag, TList l] -> TList [TInt tag, TListI l] + _ -> term + + +unpackResult :: IO (ExitCode, BL.ByteString, BL.ByteString) + -> IO (Either String BL.ByteString) +unpackResult r = r <&> \case + (ExitFailure _, _, err) -> (Left $ BL.Char8.unpack err) + (ExitSuccess, bytes, _) -> (Right bytes) + + +withTemporaryFile :: BL.ByteString -> (FilePath -> IO a) -> IO a +withTemporaryFile bs k = + withTempFile "." "tmp" $ + \fileName h -> BL.hPut h bs + >> hClose h + >> k fileName diff --git a/dmq-node/cddl/specs/sig-submission-v2.cddl b/dmq-node/cddl/specs/sig-submission-v2.cddl new file mode 100644 index 0000000..49cae7d --- /dev/null +++ b/dmq-node/cddl/specs/sig-submission-v2.cddl @@ -0,0 +1,31 @@ +; +; SigSubmission v2 mini-protocol +; + +; reference implementation of the codec in: +; dmq-node/src/DMQ/Protocol/SigSubmission/V2/Codec.hs + +sigSubmissionV2Message + = + ; corresponds to either MsgRequestSigIdsBlocking or + ; MsgRequestSigIdsNonBlocking in the spec + msgRequestSigIds + / msgReplySigIds + / msgReplyNoSigIds + / msgRequestSigs + / msgReplySigs + / msgDone + + +msgRequestSigIds = [1, blocking, sigCount, sigCount] +msgReplySigIds = [2, [*messageTuple] ] +msgReplyNoSigIds = [3] +msgRequestSigs = [4, [*sig.messageId] ] +msgReplySigs = [5, [*sig.message] ] +msgDone = [6] + +blocking = false / true +sigCount = sig.word16 +messageTuple = [sig.messageId, sig.messageSize] + +;# import sig as sig diff --git a/dmq-node/cddl/specs/sig.cddl b/dmq-node/cddl/specs/sig.cddl index 10740d1..a0b68b2 100644 --- a/dmq-node/cddl/specs/sig.cddl +++ b/dmq-node/cddl/specs/sig.cddl @@ -13,11 +13,13 @@ messagePayload = [ messageId = bstr messageBody = bstr +messageSize = word32 kesSignature = bstr .size 448 kesPeriod = word64 operationalCertificate = [ bstr .size 32, word64, word64, bstr .size 64 ] coldVerificationKey = bstr .size 32 expiresAt = word32 +word16 = uint .size 2; 2 bytes word32 = uint .size 4; 4 bytes word64 = uint .size 8; 8 bytes diff --git a/dmq-node/dmq-node.cabal b/dmq-node/dmq-node.cabal index bf79b45..58448e7 100644 --- a/dmq-node/dmq-node.cabal +++ b/dmq-node/dmq-node.cabal @@ -83,6 +83,12 @@ library DMQ.Protocol.SigSubmission.Codec DMQ.Protocol.SigSubmission.Type DMQ.Protocol.SigSubmission.Validate + DMQ.Protocol.SigSubmissionV2.Codec + DMQ.Protocol.SigSubmissionV2.Inbound + DMQ.Protocol.SigSubmissionV2.Outbound + DMQ.Protocol.SigSubmissionV2.Type + DMQ.SigSubmissionV2.Inbound + DMQ.SigSubmissionV2.Outbound DMQ.Tracer build-depends: @@ -101,6 +107,7 @@ library cardano-ledger-core, cardano-ledger-shelley, cardano-slotting, + cardano-strict-containers, cborg >=0.2.1 && <0.3, containers >=0.5 && <0.8, contra-tracer >=0.1 && <0.3, @@ -114,11 +121,13 @@ library mtl, network ^>=3.2.7, network-mux ^>=0.9.1, + nothunks, optparse-applicative >=0.18 && <0.20, ouroboros-consensus, ouroboros-consensus-cardano, ouroboros-consensus-diffusion, ouroboros-network:{ouroboros-network, api, framework, orphan-instances, protocols} ^>=0.23, + quiet, random ^>=1.2, singletons, text >=1.2.4 && <2.2, @@ -172,8 +181,13 @@ test-suite dmq-tests DMQ.Protocol.LocalMsgNotification.Test DMQ.Protocol.LocalMsgSubmission.Test DMQ.Protocol.SigSubmission.Test + DMQ.Protocol.SigSubmissionV2.Codec.CDDL + DMQ.Protocol.SigSubmissionV2.Direct + DMQ.Protocol.SigSubmissionV2.Test Test.DMQ.NodeToClient Test.DMQ.NodeToNode + Test.DMQ.SigSubmission.App + Test.DMQ.SigSubmission.Types type: exitcode-stdio-1.0 hs-source-dirs: test @@ -188,13 +202,17 @@ test-suite dmq-tests cborg, containers, contra-tracer, + deepseq, dmq-node, - io-classes:{io-classes, si-timers}, + hashable, + io-classes:{io-classes, si-timers, strict-mvar, strict-stm}, io-sim, kes-agent-crypto, + nothunks, ouroboros-consensus-cardano, - ouroboros-network:{api, framework, protocols, protocols-tests-lib, tests-lib}, + ouroboros-network:{ouroboros-network, api, framework, ouroboros-network-tests-lib, protocols, protocols-tests-lib, tests-lib}, quickcheck-instances, + random, serialise, tasty, tasty-quickcheck, @@ -219,8 +237,14 @@ test-suite dmq-cddl extensions type: exitcode-stdio-1.0 - hs-source-dirs: cddl + hs-source-dirs: + cddl + test + main-is: Main.hs + other-modules: + DMQ.Protocol.SigSubmissionV2.Codec.CDDL + DMQ.Protocol.SigSubmissionV2.Test if flag(cddl) buildable: True @@ -229,6 +253,7 @@ test-suite dmq-cddl default-language: Haskell2010 build-depends: + QuickCheck, base >=4.14 && <4.23, base16-bytestring, bytestring, @@ -236,8 +261,10 @@ test-suite dmq-cddl directory, dmq-node, filepath, + io-classes, kes-agent-crypto, mtl, + ouroboros-network:{api, protocols-tests-lib, tests-lib}, process-extras, quickcheck-instances, serialise, @@ -245,7 +272,7 @@ test-suite dmq-cddl tasty-hunit, tasty-quickcheck, temporary, - typed-protocols, + typed-protocols:{typed-protocols, codec-properties}, ghc-options: -threaded diff --git a/dmq-node/src/DMQ/Configuration.hs b/dmq-node/src/DMQ/Configuration.hs index fcc6706..f816390 100644 --- a/dmq-node/src/DMQ/Configuration.hs +++ b/dmq-node/src/DMQ/Configuration.hs @@ -73,9 +73,9 @@ import Ouroboros.Network.PeerSelection.LedgerPeers.Type import Ouroboros.Network.PeerSelection.PeerSharing (PeerSharing (..)) import Ouroboros.Network.Server.RateLimiting (AcceptedConnectionsLimit (..)) import Ouroboros.Network.Snocket (LocalAddress (..), RemoteAddress) -import Ouroboros.Network.TxSubmission.Inbound.V2 (TxDecisionPolicy (..)) import DMQ.Configuration.Topology (NoExtraConfig (..), NoExtraFlags (..)) +import Ouroboros.Network.TxSubmission.Inbound.V2 (TxDecisionPolicy(..)) -- | Configuration comes in two flavours paramemtrised by `f` functor: -- `PartialConfig` is using `Last` and `Configuration` is using an identity diff --git a/dmq-node/src/DMQ/Diffusion/NodeKernel.hs b/dmq-node/src/DMQ/Diffusion/NodeKernel.hs index fc71b48..68f3fee 100644 --- a/dmq-node/src/DMQ/Diffusion/NodeKernel.hs +++ b/dmq-node/src/DMQ/Diffusion/NodeKernel.hs @@ -52,7 +52,7 @@ import Ouroboros.Network.TxSubmission.Mempool.Simple (Mempool (..), MempoolSeq (..), WithIndex (..)) import Ouroboros.Network.TxSubmission.Mempool.Simple qualified as Mempool -import DMQ.Configuration +import DMQ.Configuration as Conf import DMQ.Protocol.SigSubmission.Type (Sig (sigExpiresAt, sigId), SigId) import DMQ.Tracer @@ -206,7 +206,7 @@ withNodeKernel tracer then WithEventType "SigSubmission.Logic" >$< tracer else nullTracer) nullTracer - defaultSigDecisionPolicy + Conf.defaultSigDecisionPolicy sigChannelVar sigSharedTxStateVar) $ \sigLogicThread -> diff --git a/dmq-node/src/DMQ/NodeToClient/LocalMsgSubmission.hs b/dmq-node/src/DMQ/NodeToClient/LocalMsgSubmission.hs index 8552946..eff285e 100644 --- a/dmq-node/src/DMQ/NodeToClient/LocalMsgSubmission.hs +++ b/dmq-node/src/DMQ/NodeToClient/LocalMsgSubmission.hs @@ -49,7 +49,9 @@ data TraceLocalMsgSubmission msg msgid = TraceReceivedMsg msgid -- ^ A signature was received. | TraceSubmitFailure msgid SigValidationError + -- ^ A signature was rejected with the given validation failure. | TraceSubmitAccept msgid + -- ^ A signature was validated and accepted into the mempool. deriving instance (Show msg, Show msgid) diff --git a/dmq-node/src/DMQ/NodeToNode.hs b/dmq-node/src/DMQ/NodeToNode.hs index a489f08..aed77af 100644 --- a/dmq-node/src/DMQ/NodeToNode.hs +++ b/dmq-node/src/DMQ/NodeToNode.hs @@ -57,9 +57,13 @@ import Cardano.KESAgent.KES.Crypto (Crypto (..)) import DMQ.Configuration (Configuration, Configuration' (..), I (..)) import DMQ.Diffusion.NodeKernel (NodeKernel (..)) import DMQ.NodeToNode.Version -import DMQ.Protocol.SigSubmission.Codec -import DMQ.Protocol.SigSubmission.Type import DMQ.Protocol.SigSubmission.Validate (SigValidationError) +import DMQ.Protocol.SigSubmissionV2.Codec +import DMQ.Protocol.SigSubmissionV2.Inbound (sigSubmissionV2InboundPeerPipelined) +import DMQ.Protocol.SigSubmissionV2.Outbound (sigSubmissionV2OutboundPeer) +import DMQ.Protocol.SigSubmissionV2.Type +import DMQ.SigSubmissionV2.Inbound (sigSubmissionInbound) +import DMQ.SigSubmissionV2.Outbound (sigSubmissionOutbound) import DMQ.Tracer import Ouroboros.Network.BlockFetch.ClientRegistry (bracketKeepAliveClient) @@ -85,9 +89,8 @@ import Ouroboros.Network.PeerSelection (PeerSharing (..)) import Ouroboros.Network.PeerSharing (bracketPeerSharingClient, peerSharingClient, peerSharingServer) import Ouroboros.Network.Snocket (RemoteAddress) -import Ouroboros.Network.TxSubmission.Inbound.V2 as SigSubmission +import Ouroboros.Network.TxSubmission.Inbound.V2 import Ouroboros.Network.TxSubmission.Mempool.Reader -import Ouroboros.Network.TxSubmission.Outbound import Ouroboros.Network.OrphanInstances () @@ -106,9 +109,7 @@ import Ouroboros.Network.Protocol.PeerSharing.Codec (byteLimitsPeerSharing, codecPeerSharing, timeLimitsPeerSharing) import Ouroboros.Network.Protocol.PeerSharing.Server (peerSharingServerPeer) import Ouroboros.Network.Protocol.PeerSharing.Type qualified as Protocol -import Ouroboros.Network.Protocol.TxSubmission2.Client (txSubmissionClientPeer) -import Ouroboros.Network.Protocol.TxSubmission2.Server - (txSubmissionServerPeerPipelined) + -- TODO: if we add `versionNumber` to `ctx` we could use `RunMiniProtocolCb`. -- This makes sense, since `ctx` already contains `versionData`. @@ -228,37 +229,12 @@ ntnApps -> ExpandedInitiatorContext addr m -> Channel m BL.ByteString -> m ((), Maybe BL.ByteString) - aSigSubmissionClient version + aSigSubmissionClient _version ExpandedInitiatorContext { eicConnectionId = connId, eicControlMessage = controlMessage } channel = - runAnnotatedPeerWithLimits - (if sigSubmissionClientProtocolTracer - then WithEventType "SigSubmission.Protocol.Client" . Mx.WithBearer connId >$< tracer - else nullTracer) - sigSubmissionCodec - sigSubmissionSizeLimits - sigSubmissionTimeLimits - channel - $ txSubmissionClientPeer - $ txSubmissionOutbound - (if sigSubmissionOutboundTracer - then WithEventType "SigSubmission.Outbound" . Mx.WithBearer connId >$< tracer - else nullTracer) - _MAX_SIGS_TO_ACK - mempoolReader - version - controlMessage - - - aSigSubmissionServer - :: NodeToNodeVersion - -> ResponderContext addr - -> Channel m BL.ByteString - -> m ((), Maybe BL.ByteString) - aSigSubmissionServer _version ResponderContext { rcConnectionId = connId } channel = - SigSubmission.withPeer + withPeer (if sigSubmissionLogicTracer then WithEventType "SigSubmission.Logic" . Mx.WithBearer connId >$< tracer else nullTracer) @@ -279,14 +255,39 @@ ntnApps sigSubmissionSizeLimits sigSubmissionTimeLimits channel - $ txSubmissionServerPeerPipelined - $ txSubmissionInboundV2 + $ sigSubmissionV2InboundPeerPipelined + $ sigSubmissionInbound (if sigSubmissionInboundTracer then WithEventType "SigSubmission.Inbound" . Mx.WithBearer connId >$< tracer else nullTracer) _SIG_SUBMISSION_INIT_DELAY mempoolWriter peerSigAPI + controlMessage + + + aSigSubmissionServer + :: NodeToNodeVersion + -> ResponderContext addr + -> Channel m BL.ByteString + -> m ((), Maybe BL.ByteString) + aSigSubmissionServer version ResponderContext { rcConnectionId = connId } channel = + runAnnotatedPeerWithLimits + (if sigSubmissionClientProtocolTracer + then WithEventType "SigSubmission.Protocol.Client" . Mx.WithBearer connId >$< tracer + else nullTracer) + sigSubmissionCodec + sigSubmissionSizeLimits + sigSubmissionTimeLimits + channel + $ sigSubmissionV2OutboundPeer + $ sigSubmissionOutbound + (if sigSubmissionOutboundTracer + then WithEventType "SigSubmission.Outbound" . Mx.WithBearer connId >$< tracer + else nullTracer) + _MAX_SIGS_TO_ACK + mempoolReader + version aKeepAliveClient @@ -535,7 +536,7 @@ initiatorAndResponderProtocols limitsAndTimeouts data Codecs crypto addr m = Codecs { - sigSubmissionCodec :: AnnotatedCodec (SigSubmission crypto) + sigSubmissionCodec :: AnnotatedCodec (SigSubmissionV2 SigId (Sig crypto)) CBOR.DeserialiseFailure m BL.ByteString , keepAliveCodec :: Codec KeepAlive CBOR.DeserialiseFailure m BL.ByteString @@ -543,15 +544,15 @@ data Codecs crypto addr m = CBOR.DeserialiseFailure m BL.ByteString } -dmqCodecs :: ( Crypto crypto - , MonadST m +dmqCodecs :: ( MonadST m + , Crypto crypto ) => (addr -> CBOR.Encoding) -> (forall s. CBOR.Decoder s addr) -> Codecs crypto addr m dmqCodecs encodeAddr decodeAddr = Codecs { - sigSubmissionCodec = codecSigSubmission + sigSubmissionCodec = anncodecSigSubmissionV2' , keepAliveCodec = codecKeepAlive_v2 , peerSharingCodec = codecPeerSharing encodeAddr decodeAddr } @@ -562,9 +563,9 @@ data LimitsAndTimeouts crypto addr = sigSubmissionLimits :: MiniProtocolLimits , sigSubmissionSizeLimits - :: ProtocolSizeLimits (SigSubmission crypto) BL.ByteString + :: ProtocolSizeLimits (SigSubmissionV2 SigId (Sig crypto)) BL.ByteString , sigSubmissionTimeLimits - :: ProtocolTimeLimits (SigSubmission crypto) + :: ProtocolTimeLimits (SigSubmissionV2 SigId (Sig crypto)) -- keep-alive , keepAliveLimits @@ -591,8 +592,8 @@ dmqLimitsAndTimeouts = -- TODO maximumIngressQueue = maxBound } - , sigSubmissionTimeLimits = timeLimitsSigSubmission - , sigSubmissionSizeLimits = byteLimitsSigSubmission size + , sigSubmissionTimeLimits = timeLimitsSigSubmissionV2 + , sigSubmissionSizeLimits = byteLimitsSigSubmissionV2 size , keepAliveLimits = MiniProtocolLimits { @@ -655,7 +656,7 @@ stdVersionDataNTN networkMagic diffusionMode peerSharing = } -- TODO: choose wisely, is a protocol parameter. -_MAX_SIGS_TO_ACK :: NumTxIdsToAck +_MAX_SIGS_TO_ACK :: NumIdsAck _MAX_SIGS_TO_ACK = 20 _SIG_SUBMISSION_INIT_DELAY :: TxSubmissionInitDelay diff --git a/dmq-node/src/DMQ/NodeToNode/Version.hs b/dmq-node/src/DMQ/NodeToNode/Version.hs index 752c97a..adfee08 100644 --- a/dmq-node/src/DMQ/NodeToNode/Version.hs +++ b/dmq-node/src/DMQ/NodeToNode/Version.hs @@ -31,20 +31,24 @@ import Ouroboros.Network.Protocol.Handshake (Accept (..)) import Ouroboros.Network.OrphanInstances () -data NodeToNodeVersion = - NodeToNodeV_1 +data NodeToNodeVersion + = NodeToNodeV_1 + | NodeToNodeV_2 deriving (Eq, Ord, Enum, Bounded, Show, Generic, NFData) instance Aeson.ToJSON NodeToNodeVersion where toJSON NodeToNodeV_1 = Aeson.toJSON (1 :: Int) + toJSON NodeToNodeV_2 = Aeson.toJSON (2 :: Int) instance Aeson.ToJSONKey NodeToNodeVersion where nodeToNodeVersionCodec :: CodecCBORTerm (Text, Maybe Int) NodeToNodeVersion nodeToNodeVersionCodec = CodecCBORTerm { encodeTerm, decodeTerm } where encodeTerm NodeToNodeV_1 = CBOR.TInt 1 + encodeTerm NodeToNodeV_2 = CBOR.TInt 2 decodeTerm (CBOR.TInt 1) = Right NodeToNodeV_1 + decodeTerm (CBOR.TInt 2) = Right NodeToNodeV_2 decodeTerm (CBOR.TInt n) = Left ( T.pack "decode NodeToNodeVersion: unknown tag: " <> T.pack (show n) , Just n @@ -113,6 +117,7 @@ nodeToNodeCodecCBORTerm :: NodeToNodeVersion nodeToNodeCodecCBORTerm = \case NodeToNodeV_1 -> v1 + NodeToNodeV_2 -> v1 where v1 = CodecCBORTerm { encodeTerm = encodeTerm1, decodeTerm = decodeTerm1 } diff --git a/dmq-node/src/DMQ/Protocol/SigSubmission/Codec.hs b/dmq-node/src/DMQ/Protocol/SigSubmission/Codec.hs index aa09d11..50878dc 100644 --- a/dmq-node/src/DMQ/Protocol/SigSubmission/Codec.hs +++ b/dmq-node/src/DMQ/Protocol/SigSubmission/Codec.hs @@ -38,11 +38,11 @@ import Cardano.Crypto.KES.Class (decodeSigKES, decodeVerKeyKES, encodeVerKeyKES) import Cardano.KESAgent.KES.Crypto (Crypto (..)) import Cardano.KESAgent.KES.OCert (OCert (..)) -import DMQ.Protocol.SigSubmission.Type import Ouroboros.Network.Protocol.Codec.Utils qualified as Utils import Ouroboros.Network.Protocol.Limits import Ouroboros.Network.Protocol.TxSubmission2.Codec qualified as TX +import DMQ.Protocol.SigSubmission.Type -- | 'SigSubmission' time limits. @@ -135,11 +135,10 @@ codecSigSubmission ) => AnnotatedCodec (SigSubmission crypto) CBOR.DeserialiseFailure m ByteString codecSigSubmission = - TX.anncodecTxSubmission2' - SigWithBytes - encodeSigId decodeSigId - encodeSig decodeSig - + TX.anncodecTxSubmission2' + SigWithBytes + encodeSigId decodeSigId + encodeSig decodeSig encodeSig :: Sig crypto -> CBOR.Encoding encodeSig = Utils.encodeBytes . sigRawBytes diff --git a/dmq-node/src/DMQ/Protocol/SigSubmissionV2/Codec.hs b/dmq-node/src/DMQ/Protocol/SigSubmissionV2/Codec.hs new file mode 100644 index 0000000..3c7a7c9 --- /dev/null +++ b/dmq-node/src/DMQ/Protocol/SigSubmissionV2/Codec.hs @@ -0,0 +1,452 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} + +module DMQ.Protocol.SigSubmissionV2.Codec + ( codecSigSubmissionV2 + , codecSigSubmissionV2Id + , byteLimitsSigSubmissionV2 + , timeLimitsSigSubmissionV2 + , encodeSigSubmissionV2 + , anncodecSigSubmissionV2 + , anncodecSigSubmissionV2' + ) where + +import Control.Monad.Class.MonadST +import Control.Monad.Class.MonadTime.SI +import Data.ByteString.Lazy (ByteString) +import Data.Kind (Type) +import Data.List.NonEmpty qualified as NonEmpty +import Text.Printf + +import Codec.CBOR.Decoding qualified as CBOR +import Codec.CBOR.Encoding qualified as CBOR +import Codec.CBOR.Read qualified as CBOR + +import Network.TypedProtocol.Codec.CBOR + +import Ouroboros.Network.Protocol.Codec.Utils (WithByteSpan (..)) +import Ouroboros.Network.Protocol.Codec.Utils qualified as Utils +import Ouroboros.Network.Protocol.Limits + +import Cardano.KESAgent.KES.Crypto (Crypto (..)) + +import DMQ.Protocol.SigSubmissionV2.Type +import DMQ.Protocol.SigSubmission.Codec qualified as V1 + +-- | Byte Limits. +byteLimitsSigSubmissionV2 + :: forall bytes sigId sig. + (bytes -> Word) + -> ProtocolSizeLimits (SigSubmissionV2 sigId sig) bytes +byteLimitsSigSubmissionV2 = ProtocolSizeLimits stateToLimit + where + stateToLimit + :: forall (st :: SigSubmissionV2 sigId sig). + ActiveState st + => StateToken st + -> Word + stateToLimit (SingSigIds SingBlocking) = largeByteLimit + stateToLimit (SingSigIds SingNonBlocking) = largeByteLimit + stateToLimit SingSigs = largeByteLimit + stateToLimit SingIdle = smallByteLimit + stateToLimit a@SingDone = notActiveState a + +-- | 'SigSubmissionV2' time limits. +-- +-- +---------------------------------+---------------+ +-- | 'SigSubmissionV2' state | timeout (s) | +-- +=================================+===============+ +-- | `StIdle` | `waitForever` | +-- +---------------------------------+---------------+ +-- | @'StSigIds' 'StBlocking'@ | `Just 20` | +-- +---------------------------------+---------------+ +-- | @'StOSigIds' 'StNonBlocking'@ | `shortWait` | +-- +---------------------------------+---------------+ +-- | `StObjects` | `shortWait` | +-- +---------------------------------+---------------+ +timeLimitsSigSubmissionV2 + :: forall (sigId :: Type) (sig :: Type). + ProtocolTimeLimits (SigSubmissionV2 sigId sig) +timeLimitsSigSubmissionV2 = ProtocolTimeLimits stateToLimit + where + stateToLimit + :: forall (st :: SigSubmissionV2 sigId sig). + ActiveState st + => StateToken st + -> Maybe DiffTime + stateToLimit (SingSigIds SingBlocking) = Just 20 + stateToLimit (SingSigIds SingNonBlocking) = shortWait + stateToLimit SingSigs = shortWait + stateToLimit SingIdle = waitForever + stateToLimit a@SingDone = notActiveState a + + +codecSigSubmissionV2 + :: forall (sigId :: Type) (sig :: Type) m. + MonadST m + => (sigId -> CBOR.Encoding) -- ^ encode `sigId` + -> (forall s. CBOR.Decoder s sigId) -- ^ decode `sigId` + -> (sig -> CBOR.Encoding) -- ^ encode `sig` + -> (forall s. CBOR.Decoder s sig) -- ^ decode `sig` + -> Codec (SigSubmissionV2 sigId sig) CBOR.DeserialiseFailure m ByteString +codecSigSubmissionV2 + encodeSigId decodeSigId + encodeSig decodeSig + = + mkCodecCborLazyBS + (encodeSigSubmissionV2 encodeSigId encodeSig) + decode + where + decode + :: forall (st :: SigSubmissionV2 sigId sig). + ActiveState st + => StateToken st + -> forall s. CBOR.Decoder s (SomeMessage st) + decode stok = do + len <- CBOR.decodeListLen + key <- CBOR.decodeWord + decodeSigSubmissionV2 decodeSigId decodeSig stok len key + + +encodeSigSubmissionV2 + :: forall (sigId :: Type) (sig :: Type) + (st :: SigSubmissionV2 sigId sig) + (st' :: SigSubmissionV2 sigId sig). + (sigId -> CBOR.Encoding) -- ^ encode 'sigId' + -> (sig -> CBOR.Encoding) -- ^ encode 'sig' + -> Message (SigSubmissionV2 sigId sig) st st' + -> CBOR.Encoding +encodeSigSubmissionV2 encodeObjectId encodeObject = encode + where + encode + :: forall st0 st1. + Message (SigSubmissionV2 sigId sig) st0 st1 + -> CBOR.Encoding + encode (MsgRequestSigIds blocking (NumIdsAck ackNo) (NumIdsReq reqNo)) = + CBOR.encodeListLen 4 + <> CBOR.encodeWord 1 + <> CBOR.encodeBool + ( case blocking of + SingBlocking -> True + SingNonBlocking -> False + ) + <> CBOR.encodeWord16 ackNo + <> CBOR.encodeWord16 reqNo + + encode (MsgReplySigIds objIds) = + CBOR.encodeListLen 2 + <> CBOR.encodeWord 2 + <> CBOR.encodeListLenIndef + <> foldr (\(sigid, SizeInBytes sz) r -> + CBOR.encodeListLen 2 + <> encodeObjectId sigid + <> CBOR.encodeWord32 sz + <> r) CBOR.encodeBreak objIds + + encode MsgReplyNoSigIds = + CBOR.encodeListLen 1 + <> CBOR.encodeWord 3 + + encode (MsgRequestSigs objIds) = + CBOR.encodeListLen 2 + <> CBOR.encodeWord 4 + <> CBOR.encodeListLenIndef + <> foldMap encodeObjectId objIds + <> CBOR.encodeBreak + + encode (MsgReplySigs objects) = + CBOR.encodeListLen 2 + <> CBOR.encodeWord 5 + <> CBOR.encodeListLenIndef + <> foldMap encodeObject objects + <> CBOR.encodeBreak + + encode MsgDone = + CBOR.encodeListLen 1 + <> CBOR.encodeWord 6 + + +decodeSigSubmissionV2 + :: forall (sigId :: Type) (sig :: Type) + (st :: SigSubmissionV2 sigId sig) s. + ActiveState st + => (forall s'. CBOR.Decoder s' sigId) -- ^ decode 'sigId' + -> (forall s'. CBOR.Decoder s' sig) -- ^ decode sig + -> StateToken st + -> Int + -> Word + -> CBOR.Decoder s (SomeMessage st) +decodeSigSubmissionV2 decodeSigId decodeSig = decode + where + decode + :: forall (st' :: SigSubmissionV2 sigId sig). + ActiveState st' + => StateToken st' + -> Int + -> Word + -> CBOR.Decoder s (SomeMessage st') + decode stok len key = do + case (stok, len, key) of + (SingIdle, 4, 1) -> do + blocking <- CBOR.decodeBool + ackNo <- NumIdsAck <$> CBOR.decodeWord16 + reqNo <- NumIdsReq <$> CBOR.decodeWord16 + return $! if blocking + then SomeMessage $ MsgRequestSigIds SingBlocking ackNo reqNo + else SomeMessage $ MsgRequestSigIds SingNonBlocking ackNo reqNo + + (SingSigIds b, 2, 2) -> do + CBOR.decodeListLenIndef + sigIds <- CBOR.decodeSequenceLenIndef + (flip (:)) + [] + reverse + (do CBOR.decodeListLenOf 2 + sigid <- decodeSigId + sz <- CBOR.decodeWord32 + return (sigid, SizeInBytes sz)) + case (b, sigIds) of + (SingBlocking, t : ts) -> + return + $ SomeMessage + $ MsgReplySigIds (BlockingReply (t NonEmpty.:| ts)) + + (SingNonBlocking, ts) -> + return + $ SomeMessage + $ MsgReplySigIds (NonBlockingReply ts) + + (SingBlocking, []) -> + fail "codecSigSubmissionV2: MsgReplySigIds: empty list not permitted" + + (SingSigIds SingBlocking, 1, 3) -> + return (SomeMessage MsgReplyNoSigIds) + + (SingIdle, 2, 4) -> do + CBOR.decodeListLenIndef + sigIds <- CBOR.decodeSequenceLenIndef + (flip (:)) + [] + reverse + decodeSigId + return $ SomeMessage $ MsgRequestSigs sigIds + + (SingSigs, 2, 5) -> do + CBOR.decodeListLenIndef + sigs <- CBOR.decodeSequenceLenIndef + (flip (:)) + [] + reverse + decodeSig + return $ SomeMessage $ MsgReplySigs sigs + + (SingIdle, 1, 6) -> + return $ SomeMessage MsgDone + + (SingDone, _, _) -> notActiveState stok + + -- failures + (_, _, _) -> + fail $ printf "codecSigSubmissionV2 (%s) unexpected key %d, length %d" (show stok) key len + + +codecSigSubmissionV2Id + :: forall sigId sig m. + Monad m + => Codec + (SigSubmissionV2 sigId sig) + CodecFailure + m + (AnyMessage (SigSubmissionV2 sigId sig)) +codecSigSubmissionV2Id = Codec {encode, decode} + where + encode + :: forall st st'. + ( ActiveState st + , StateTokenI st + ) + => Message (SigSubmissionV2 sigId sig) st st' + -> AnyMessage (SigSubmissionV2 sigId sig) + encode = AnyMessage + + decode + :: forall (st :: SigSubmissionV2 sigId sig). + ActiveState st + => StateToken st + -> m (DecodeStep + (AnyMessage (SigSubmissionV2 sigId sig)) + CodecFailure + m + (SomeMessage st) + ) + decode stok = return $ DecodePartial $ \bytes -> + return $ case (stok, bytes) of + (SingIdle, Just (AnyMessage msg@(MsgRequestSigIds SingBlocking _ _))) -> + DecodeDone (SomeMessage msg) Nothing + (SingIdle, Just (AnyMessage msg@(MsgRequestSigIds SingNonBlocking _ _))) -> + DecodeDone (SomeMessage msg) Nothing + (SingIdle, Just (AnyMessage msg@(MsgRequestSigs {}))) -> + DecodeDone (SomeMessage msg) Nothing + (SingSigs, Just (AnyMessage msg@(MsgReplySigs {}))) -> + DecodeDone (SomeMessage msg) Nothing + (SingSigIds b, Just (AnyMessage msg)) -> case (b, msg) of + (SingBlocking, MsgReplySigIds (BlockingReply {})) -> + DecodeDone (SomeMessage msg) Nothing + (SingBlocking, MsgReplyNoSigIds) -> + DecodeDone (SomeMessage msg) Nothing + (SingNonBlocking, MsgReplySigIds (NonBlockingReply {})) -> + DecodeDone (SomeMessage msg) Nothing + (_, _) -> + DecodeFail $ CodecFailure "codecSigSubmissionV2Id: no matching message" + (SingIdle, Just (AnyMessage msg@MsgDone)) -> + DecodeDone (SomeMessage msg) Nothing + (SingDone, _) -> + notActiveState stok + (_, _) -> + DecodeFail $ CodecFailure "codecSigSubmissionV2Id: no matching message" + + +-- | An 'AnnotatedCodec' with a custom `sigWithBytes` wrapper of `sig`, +-- e.g. `sigWithBytes ~ WithBytes sig`. +-- +anncodecSigSubmissionV2 + :: forall (sigId :: Type) (sig :: Type) (sigWithBytes :: Type) m. + MonadST m + => (ByteString -> sig -> sigWithBytes) + -- ^ `withBytes` constructor + -> (sigId -> CBOR.Encoding) + -- ^ encode 'sigid' + -> (forall s . CBOR.Decoder s sigId) + -- ^ decode 'sigid' + -> (sigWithBytes -> CBOR.Encoding) + -- ^ encode `sig` + -> (forall s . CBOR.Decoder s (ByteString -> sig)) + -- ^ decode signature + -> AnnotatedCodec (SigSubmissionV2 sigId sigWithBytes) CBOR.DeserialiseFailure m ByteString +anncodecSigSubmissionV2 mkWithBytes encodeSigId decodeSigId + encodeSig decodeSig = + mkCodecCborLazyBS + (encodeSigSubmissionV2 encodeSigId encodeSig) + decode + where + decode :: forall (st :: SigSubmissionV2 sigId sigWithBytes). + ActiveState st + => StateToken st + -> forall s. CBOR.Decoder s (Annotator ByteString st) + decode = + decodeSigSubmissionV2' @sig + @sigWithBytes + @WithByteSpan + @ByteString + mkWithBytes' + decodeSigId + (Utils.decodeWithByteSpan decodeSig) + + mkWithBytes' :: ByteString + -> WithByteSpan (ByteString -> sig) + -> sigWithBytes + mkWithBytes' bytes (WithByteSpan (fn, start, end)) = + mkWithBytes (Utils.bytesBetweenOffsets start end bytes) -- bytes of the transaction + (fn bytes) -- note: fn expects full bytes + + +decodeSigSubmissionV2' + :: forall (sig :: Type) + (sigWithBytes :: Type) + (withByteSpan :: Type -> Type) + (bytes :: Type) + (sigId :: Type) + (st :: SigSubmissionV2 sigId sigWithBytes) + s. + ActiveState st + => (bytes -> withByteSpan (bytes -> sig) -> sigWithBytes) + -> (forall s'. CBOR.Decoder s' sigId) -- ^ decode 'sigId' + -> (forall s'. CBOR.Decoder s' (withByteSpan (bytes -> sig))) + -> StateToken st + -> CBOR.Decoder s (Annotator bytes st) +decodeSigSubmissionV2' mkWithBytes decodeSigId decodeSig sok = do + len <- CBOR.decodeListLen + key <- CBOR.decodeWord + decode sok len key + where + decode stok len key = do + case (stok, len, key) of + (SingIdle, 4, 1) -> do + blocking <- CBOR.decodeBool + ackNo <- NumIdsAck <$> CBOR.decodeWord16 + reqNo <- NumIdsReq <$> CBOR.decodeWord16 + return $! if blocking + then Annotator $ \_ -> SomeMessage $ MsgRequestSigIds SingBlocking ackNo reqNo + else Annotator $ \_ -> SomeMessage $ MsgRequestSigIds SingNonBlocking ackNo reqNo + + (SingSigIds b, 2, 2) -> do + CBOR.decodeListLenIndef + sigIds <- CBOR.decodeSequenceLenIndef + (flip (:)) + [] + reverse + (do CBOR.decodeListLenOf 2 + sigid <- decodeSigId + sz <- CBOR.decodeWord32 + return (sigid, SizeInBytes sz)) + case (b, sigIds) of + (SingBlocking, t : ts) -> + return + $ Annotator $ \_ -> SomeMessage $ MsgReplySigIds (BlockingReply (t NonEmpty.:| ts)) + + (SingNonBlocking, ts) -> + return + $ Annotator $ \_ -> SomeMessage $ MsgReplySigIds (NonBlockingReply ts) + + (SingBlocking, []) -> + fail "codecSigSubmissionV2: MsgReplySigIds: empty list not permitted" + + (SingSigIds SingBlocking, 1, 3) -> + return (Annotator $ \_ -> SomeMessage MsgReplyNoSigIds) + + (SingIdle, 2, 4) -> do + CBOR.decodeListLenIndef + sigIds <- CBOR.decodeSequenceLenIndef + (flip (:)) + [] + reverse + decodeSigId + return $ Annotator $ \_ -> SomeMessage $ MsgRequestSigs sigIds + + (SingSigs, 2, 5) -> do + CBOR.decodeListLenIndef + sigs <- CBOR.decodeSequenceLenIndef + (flip (:)) + [] + reverse + decodeSig + return (Annotator $ \bytes -> SomeMessage (MsgReplySigs $ mkWithBytes bytes <$> sigs)) + + (SingIdle, 1, 6) -> + return $ Annotator $ \_ -> SomeMessage MsgDone + + (SingDone, _, _) -> notActiveState stok + + -- failures + (_, _, _) -> + fail $ printf "codecSigSubmissionV2 (%s) unexpected key %d, length %d" (show stok) key len + + +anncodecSigSubmissionV2' + :: forall crypto m. + ( Crypto crypto + , MonadST m + ) + => AnnotatedCodec (SigSubmissionV2 SigId (Sig crypto)) CBOR.DeserialiseFailure m ByteString +anncodecSigSubmissionV2' = + anncodecSigSubmissionV2 + SigWithBytes + V1.encodeSigId V1.decodeSigId + V1.encodeSig V1.decodeSig diff --git a/dmq-node/src/DMQ/Protocol/SigSubmissionV2/Inbound.hs b/dmq-node/src/DMQ/Protocol/SigSubmissionV2/Inbound.hs new file mode 100644 index 0000000..e1f1db6 --- /dev/null +++ b/dmq-node/src/DMQ/Protocol/SigSubmissionV2/Inbound.hs @@ -0,0 +1,126 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} + +-- | A view of the sig submission protocol from the point of view of the +-- inbound/client peer. +-- +-- This provides a view that uses less complex types and should be easier to use +-- than the underlying typed protocol itself. +-- +-- For execution, a conversion into the typed protocol is provided. +module DMQ.Protocol.SigSubmissionV2.Inbound + ( -- * Protocol type for the inbound + SigSubmissionInboundPipelined (..) + , InboundStIdle (..) + , Collect (..) + -- * Execution as a typed protocol + , sigSubmissionV2InboundPeerPipelined + ) where + +import Data.List.NonEmpty qualified as NonEmpty +import Data.Map.Strict (Map) +import Data.Map.Strict qualified as Map +import Network.TypedProtocol.Core +import Network.TypedProtocol.Peer (Peer, PeerPipelined (..)) +import Network.TypedProtocol.Peer.Client +import DMQ.Protocol.SigSubmissionV2.Type + +data SigSubmissionInboundPipelined sigId sig m a where + SigSubmissionInboundPipelined + :: m (InboundStIdle Z sigId sig m a) + -> SigSubmissionInboundPipelined sigId sig m a + +-- | This is the type of the pipelined results, collected by 'CollectPipelined'. +-- This protocol can pipeline requests for identifiers and signatures, so we use +-- a sum of either for collecting the responses. +-- +data Collect sigId sig + = -- | The result of 'SendMsgRequestSigIdsPipelined'. It also carries + -- the number of sigIds originally requested. + CollectSigIds NumIdsReq [(sigId, SizeInBytes)] + + | -- | The result of 'SendMsgRequestSigsPipelined'. The actual reply only + -- contains the signatures sent, but this pairs them up with the + -- requested identifiers. This is for the peer to determine whether some + -- signatures are no longer needed. + CollectSigs (Map sigId SizeInBytes) [sig] + + +data InboundStIdle (n :: N) sigId sig m a where + SendMsgRequestSigIdsBlocking + :: NumIdsAck -- ^ number of sigIds to acknowledge + -> NumIdsReq -- ^ number of sigIds to request + -> ([(sigId, SizeInBytes)] -> m (InboundStIdle Z sigId sig m a)) + -> InboundStIdle Z sigId sig m a + + SendMsgRequestSigIdsPipelined + :: NumIdsAck + -> NumIdsReq + -> m (InboundStIdle (S n) sigId sig m a) + -> InboundStIdle n sigId sig m a + + SendMsgRequestSigsPipelined + :: Map sigId SizeInBytes + -> m (InboundStIdle (S n) sigId sig m a) + -> InboundStIdle n sigId sig m a + + CollectPipelined + :: Maybe (InboundStIdle (S n) sigId sig m a) + -> (Collect sigId sig -> m (InboundStIdle n sigId sig m a)) + -> InboundStIdle (S n) sigId sig m a + + SendMsgDone + :: m a + -> InboundStIdle Z sigId sig m a + + +-- | Transform a 'SigSubmissionInboundPipelined' into a 'PeerPipelined'. +-- +sigSubmissionV2InboundPeerPipelined + :: forall sigId sig m a. + (Functor m) + => SigSubmissionInboundPipelined sigId sig m a + -> PeerPipelined (SigSubmissionV2 sigId sig) AsClient StIdle m a +sigSubmissionV2InboundPeerPipelined (SigSubmissionInboundPipelined inboundSt) = + PeerPipelined $ Effect (run <$> inboundSt) + where + run :: InboundStIdle n sigId sig m a + -> Peer (SigSubmissionV2 sigId sig) AsClient (Pipelined n (Collect sigId sig)) StIdle m a + + run (SendMsgRequestSigIdsBlocking ackNo reqNo k) = + Yield (MsgRequestSigIds SingBlocking ackNo reqNo) $ + Await \case + MsgReplySigIds (BlockingReply sigIds) -> + Effect $ run <$> k (NonEmpty.toList sigIds) + + MsgReplyNoSigIds -> + Effect $ run <$> k [] + + run (SendMsgRequestSigIdsPipelined ackNo reqNo k) = + YieldPipelined + (MsgRequestSigIds SingNonBlocking ackNo reqNo) + (ReceiverAwait + $ \(MsgReplySigIds (NonBlockingReply sigIds)) -> + ReceiverDone (CollectSigIds reqNo sigIds) + ) + (Effect $ run <$> k) + + run (SendMsgRequestSigsPipelined sigIds k) = + YieldPipelined + (MsgRequestSigs $ Map.keys sigIds) + (ReceiverAwait + $ \(MsgReplySigs sigs) -> + ReceiverDone (CollectSigs sigIds sigs) + ) + (Effect $ run <$> k) + + run (CollectPipelined none collect) = + Collect + (run <$> none) + (Effect . fmap run . collect) + + run (SendMsgDone done) = + Effect $ Yield MsgDone . Done <$> done diff --git a/dmq-node/src/DMQ/Protocol/SigSubmissionV2/Outbound.hs b/dmq-node/src/DMQ/Protocol/SigSubmissionV2/Outbound.hs new file mode 100644 index 0000000..f76b04a --- /dev/null +++ b/dmq-node/src/DMQ/Protocol/SigSubmissionV2/Outbound.hs @@ -0,0 +1,111 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} + +-- | A view of the sig diffusion protocol from the point of view of +-- the outbound/server peer. +-- +-- This provides a view that uses less complex types and should be easier to +-- use than the underlying typed protocol itself. +-- +-- For execution, 'sigSubmissionOutboundPeer' is provided for conversion +-- into the typed protocol. +module DMQ.Protocol.SigSubmissionV2.Outbound + ( -- * Protocol type for the outbound + SigSubmissionOutbound (..) + , OutboundStIdle (..) + , OutboundStSigIds (..) + , OutboundStSigs (..) + -- * Execution as a typed protocol + , sigSubmissionV2OutboundPeer + ) where + +import Data.Functor ((<&>)) +import Data.Singletons (SingI) +import Network.TypedProtocol.Core +import Network.TypedProtocol.Peer (Peer) +import Network.TypedProtocol.Peer.Server + +import DMQ.Protocol.SigSubmissionV2.Type + +-- | The outbound side of the sig diffusion protocol. +-- +-- The peer in the outbound/server role submits sigs to the peer in the +-- inbound/client role. +newtype SigSubmissionOutbound sigId sig m a = SigSubmissionOutbound { + runSigSubmissionOutbound :: m (OutboundStIdle sigId sig m a) + } + +-- | In the 'StIdle' protocol state, the outbound does not have agency. Instead +-- it is waiting for: +-- +-- * a request for sig ids (blocking or non-blocking) +-- * a request for a given list of sigs +-- * a termination message +-- +-- It must be prepared to handle any of these. +data OutboundStIdle sigId sig m a = OutboundStIdle { + recvMsgRequestSigIds :: forall blocking. + SingBlockingStyle blocking + -> NumIdsAck + -> NumIdsReq + -> m (OutboundStSigIds blocking sigId sig m a), + + recvMsgRequestSigs :: [sigId] + -> m (OutboundStSigs sigId sig m a) + } + +data OutboundStSigIds blocking sigId sig m a where + SendMsgReplySigIds + :: SingI blocking + => BlockingReplyList blocking (sigId, SizeInBytes) + -> OutboundStIdle sigId sig m a + -> OutboundStSigIds blocking sigId sig m a + + SendMsgReplyNoSigIds + :: OutboundStIdle sigId sig m a + -> OutboundStSigIds StBlocking sigId sig m a + +data OutboundStSigs sigId sig m a where + SendMsgReplySigs + :: [sig] + -> OutboundStIdle sigId sig m a + -> OutboundStSigs sigId sig m a + + +-- | A non-pipelined 'Peer' representing the 'SigSubmissionOutbound'. +sigSubmissionV2OutboundPeer + :: forall sigId sig m a. + Monad m + => SigSubmissionOutbound sigId sig m a + -> Peer (SigSubmissionV2 sigId sig) AsServer NonPipelined StIdle m a +sigSubmissionV2OutboundPeer (SigSubmissionOutbound outboundSt) = + Effect (run <$> outboundSt) + where + run :: OutboundStIdle sigId sig m a + -> Peer (SigSubmissionV2 sigId sig) AsServer NonPipelined StIdle m a + run OutboundStIdle {recvMsgRequestSigIds, recvMsgRequestSigs} = + Await $ \case + MsgRequestSigIds blocking ackNo reqNo -> Effect $ do + recvMsgRequestSigIds blocking ackNo reqNo <&> \case + + SendMsgReplySigIds sigIds k -> + Yield + (MsgReplySigIds sigIds) + (run k) + + SendMsgReplyNoSigIds k -> + Yield + MsgReplyNoSigIds + (run k) + + MsgRequestSigs sigIds -> Effect $ do + recvMsgRequestSigs sigIds <&> \case + SendMsgReplySigs sigs k -> + Yield + (MsgReplySigs sigs) + (run k) diff --git a/dmq-node/src/DMQ/Protocol/SigSubmissionV2/Type.hs b/dmq-node/src/DMQ/Protocol/SigSubmissionV2/Type.hs new file mode 100644 index 0000000..368f133 --- /dev/null +++ b/dmq-node/src/DMQ/Protocol/SigSubmissionV2/Type.hs @@ -0,0 +1,337 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeFamilies #-} + +-- | The type of the signature diffusion protocol. +-- +-- This is used to diffuse generic signatures between nodes. +-- +-- It is based on `Ouroboros.Network.Protocol.ObjectDiffusion` mini-protocol +-- originally designed for Peras. +-- +module DMQ.Protocol.SigSubmissionV2.Type + ( SigSubmissionV2 (..) + , Message (..) + , SingSigSubmissionV2 (..) + , NumIdsAck (..) + , NumIdsReq (..) + , NumReq (..) + , NumUnacknowledged (..) + -- Signature types + , module SigSubmission + -- re-exports + , BlockingReplyList (..) + , SingBlockingStyle (..) + , SizeInBytes (..) + , StBlockingStyle (..) + ) where + +import Control.DeepSeq (NFData (..)) +import Data.Aeson (ToJSON (toJSON), Value (String), KeyValue ((.=)), object) +import Data.Kind (Type) +import Data.Monoid (Sum (..)) +import Data.Singletons +import Data.Text (pack) +import Data.Word (Word16) +import GHC.Generics (Generic) +import NoThunks.Class (NoThunks (..)) +import Quiet (Quiet (..)) + +import Network.TypedProtocol.Codec (AnyMessage(AnyMessageAndAgency)) +import Network.TypedProtocol.Core + +import DMQ.Protocol.SigSubmission.Type as SigSubmission (SigId (..), + SigBody (..), SigKESSignature (..), SigOpCertificate (..), + SigColdKey (..), SigRaw (..), SigRawWithSignedBytes (..), Sig (..)) +import Ouroboros.Network.Protocol.TxSubmission2.Type (BlockingReplyList (..), + SingBlockingStyle (..), StBlockingStyle (..)) +import Ouroboros.Network.SizeInBytes (SizeInBytes (..)) +import Ouroboros.Network.Util.ShowProxy (ShowProxy (..)) + +-- | The kind of the object diffusion protocol, and the types of the states in +-- the protocol state machine. +-- +-- We describe this protocol using indiscriminately the labels \"inbound\"/\"client\" +-- for the peer that is receiving objects, and \"outbound\"/\"server\" for the one +-- sending them. +type SigSubmissionV2 :: Type -> Type -> Type +data SigSubmissionV2 sigId sig where + -- | The inbound node has agency; it can either terminate, ask for object + -- identifiers or ask for objects. + -- + -- There is no timeout in this state. + StIdle :: SigSubmissionV2 sigId sig + + -- | The outbound node has agency; it must reply with a list of object + -- identifiers that it wishes to submit. + -- + -- There are two sub-states for this, for blocking and non-blocking cases. + StSigIds :: StBlockingStyle -> SigSubmissionV2 sigId sig + + -- | The outbound node has agency; it must reply with the list of + -- objects. + StSigs :: SigSubmissionV2 sigId sig + + -- | Nobody has agency; termination state. + StDone :: SigSubmissionV2 sigId sig + +instance ( ShowProxy sigId + , ShowProxy sig + ) + => ShowProxy (SigSubmissionV2 sigId sig) where + showProxy _ = + concat + [ "SigSubmissionV2 ", + showProxy (Proxy :: Proxy sigId), + " ", + showProxy (Proxy :: Proxy sig) + ] + +instance ShowProxy (StIdle :: SigSubmissionV2 sigId sig) where + showProxy _ = "StIdle" +instance (Show sigId, Show sig) + => ToJSON (AnyMessage (SigSubmissionV2 sigId sig)) where + toJSON (AnyMessageAndAgency stok MsgRequestSigIds{}) = + object + [ "kind" .= String "MsgRequestSigIds" + , "agency" .= String (pack $ show stok) + ] + toJSON (AnyMessageAndAgency stok (MsgReplySigIds ids)) = + object + [ "kind" .= String "MsgReplySigIds" + , "agency" .= String (pack $ show stok) + , "ids" .= String (pack $ show ids) + ] + toJSON (AnyMessageAndAgency stok MsgReplyNoSigIds) = + object + [ "kind" .= String "MsgReplyNoSigIds" + , "agency" .= String (pack $ show stok) + ] + toJSON (AnyMessageAndAgency stok (MsgRequestSigs{})) = + object + [ "kind" .= String "MsgRequestSigs" + , "agency" .= String (pack $ show stok) + ] + toJSON (AnyMessageAndAgency stok (MsgReplySigs sigs)) = + object + [ "kind" .= String "MsgReplySigs" + , "agency" .= String (pack $ show stok) + , "sigs" .= String (pack $ show sigs) + ] + toJSON (AnyMessageAndAgency stok MsgDone) = + object + [ "kind" .= String "MsgDone" + , "agency" .= String (pack $ show stok) + ] + + +type SingSigSubmissionV2 + :: SigSubmissionV2 sigId sig + -> Type +data SingSigSubmissionV2 k where + SingIdle :: SingSigSubmissionV2 StIdle + SingSigIds :: SingBlockingStyle stBlocking + -> SingSigSubmissionV2 (StSigIds stBlocking) + SingSigs :: SingSigSubmissionV2 StSigs + SingDone :: SingSigSubmissionV2 StDone + +deriving instance Show (SingSigSubmissionV2 st) + +instance StateTokenI StIdle where stateToken = SingIdle +instance SingI stBlocking + => StateTokenI (StSigIds stBlocking) where stateToken = SingSigIds sing +instance StateTokenI StSigs where stateToken = SingSigs +instance StateTokenI StDone where stateToken = SingDone + + +newtype NumIdsAck = NumIdsAck {getNumIdsAck :: Word16} + deriving (Eq, Ord, NFData, Generic) + deriving newtype (Num, Enum, Real, Integral, Bounded, NoThunks) + deriving Semigroup via (Sum Word16) + deriving Monoid via (Sum Word16) + deriving Show via (Quiet NumIdsAck) + +newtype NumIdsReq = NumIdsReq {getNumIdsReq :: Word16} + deriving (Eq, Ord, NFData, Generic) + deriving newtype (Num, Enum, Real, Integral, Bounded, NoThunks) + deriving Semigroup via (Sum Word16) + deriving Monoid via (Sum Word16) + deriving Show via (Quiet NumIdsReq) + +newtype NumReq = NumReq {getNumReq :: Word16} + deriving (Eq, Ord, NFData, Generic) + deriving newtype (Num, Enum, Real, Integral, Bounded, NoThunks) + deriving Semigroup via (Sum Word16) + deriving Monoid via (Sum Word16) + deriving Show via (Quiet NumReq) + +newtype NumUnacknowledged = NumUnacknowledged {getNumUnacknowledged :: Word16} + deriving (Eq, Ord, NFData, Generic) + deriving newtype (Num, Enum, Real, Integral, Bounded, NoThunks) + deriving Semigroup via (Sum Word16) + deriving Monoid via (Sum Word16) + deriving Show via (Quiet NumUnacknowledged) + + +-- | There are some constraints of the protocol that are not captured in the +-- types of the messages, but are documented with the messages. Violation +-- of these constraints is also a protocol error. The constraints are intended +-- to ensure that implementations are able to work in bounded space. +instance Protocol (SigSubmissionV2 sigId sig) where + -- | The messages in the object diffusion protocol. + -- + -- In this protocol the consumer (inbound side, client role) always + -- initiates and the producer (outbound side, server role) replies. + -- This makes it a pull based protocol where the receiver manages the + -- control flow. + -- + -- The protocol involves asking for object identifiers, and then + -- asking for objects corresponding to the identifiers of interest. + -- + -- There are two ways to ask for object identifiers, blocking and + -- non-blocking. They otherwise have the same semantics. + -- + -- The protocol maintains a notional FIFO of "outstanding" object + -- identifiers that have been provided but not yet acknowledged. Only + -- objects that are outstanding can be requested: they can be + -- requested in any order, but at most once. Object identifiers are + -- acknowledged in the same FIFO order they were provided in. The + -- acknowledgement is included in the same messages used to ask for more + -- object identifiers. + data Message (SigSubmissionV2 sigId sig) from to where + + -- | Request a list of identifiers from the server, and confirm a + -- number of outstanding identifiers. + -- + -- With 'TokBlocking' this is a blocking operation but it's not guaranteed + -- that the server will respond with signatures. The server might block for + -- only a limited time waiting for signaures, if it times out it will reply + -- with `MsgReplyNoSigs` to let the client regain control of the protocol. + -- + -- With 'TokNonBlocking' this is a non-blocking operation: the response may + -- be an empty list and this does expect a prompt response. This covers high + -- throughput use cases where we wish to pipeline, by interleaving requests + -- for additional identifiers with requests for signatures, which + -- requires these requests not block. + -- + -- The request gives the maximum number of identifiers that can be + -- accepted in the response. This must be greater than zero in the + -- 'TokBlocking' case. In the 'TokNonBlocking' case either the numbers + -- acknowledged or the number requested __MUST__ be non-zero. In either + -- case, the number requested __MUST__ not put the total outstanding over + -- the fixed protocol limit. + -- + -- The request also gives the number of outstanding identifiers that + -- can now be acknowledged. The actual signatures to acknowledge are known + -- to the server based on the FIFO order in which they were provided. + -- + -- There is no choice about when to use the blocking case versus the + -- non-blocking case, it depends on whether there are any remaining + -- unacknowledged signatures (after taking into account the ones + -- acknowledged in this message): + -- + -- * The blocking case __MUST__ be used when there are zero remaining + -- unacknowledged signatures. + -- + -- * The non-blocking case __MUST__ be used when there are non-zero + -- remaining unacknowledged signatures. + -- + MsgRequestSigIds + :: forall (blocking :: StBlockingStyle) sigId sig. + SingBlockingStyle blocking + -> NumIdsAck -- ^ Acknowledge this number of outstanding signatures + -> NumIdsReq -- ^ Request up to this number of identifiers + -> Message (SigSubmissionV2 sigId sig) StIdle (StSigIds blocking) + + -- | Reply with a list of object identifiers for available objects, along + -- with the size of each object. + -- + -- The list must not be longer than the maximum number requested. + -- + -- In the 'StSigIds' 'Blocking' state the list must be non-empty while in + -- the 'StSigIds' 'NonBlocking' state the list may be empty. + -- + -- These objects are added to the notional FIFO of outstanding object + -- identifiers for the protocol. + -- + -- The order in which these object identifiers are returned must be the + -- order in which they are submitted to the mempool, to preserve dependent + -- objects. + -- + MsgReplySigIds + :: BlockingReplyList blocking (sigId, SizeInBytes) + -> Message (SigSubmissionV2 sigId sig) (StSigIds blocking) StIdle + + -- | The blocking request `MsgRequestSigIds` can be replied with no + -- signatures to let the client regain the control of the protocol. + -- + MsgReplyNoSigIds + :: Message (SigSubmissionV2 sidId sig) (StSigIds StBlocking) StIdle + + -- | Request one or more objects corresponding to the given identifiers. + -- + -- While it is the responsibility of the server to keep within + -- pipelining in-flight limits, the client must also cooperate by keeping + -- the total requested across all in-flight requests within the limits. + -- + -- It is an error to ask for identifiers that were not + -- previously announced (via 'MsgReplySigIds'). + -- + -- It is an error to ask for identifiers that are not + -- outstanding or that were already asked for. + -- + MsgRequestSigs + :: [sigId] + -> Message (SigSubmissionV2 sigId sig) StIdle StSigs + + -- | Reply with the requested signatures, or implicitly discard. + -- + -- Signatures can become invalid between the time the identifier was + -- sent and the signatures being requested. Invalid (including committed) + -- signatures do not need to be sent. + -- + -- Any identifiers requested but not provided in this reply + -- should be considered as if this peer had never announced them. (Note + -- that this is no guarantee that the signature is invalid, it may still be + -- valid and available from another peer). + -- + MsgReplySigs + :: [sig] + -> Message (SigSubmissionV2 sigId sig) StSigs StIdle + + -- | Termination message, initiated by the client side when idle. + MsgDone + :: Message (SigSubmissionV2 sigId sig) StIdle StDone + + type StateAgency StIdle = ClientAgency + type StateAgency (StSigIds b) = ServerAgency + type StateAgency StSigs = ServerAgency + type StateAgency StDone = NobodyAgency + + type StateToken = SingSigSubmissionV2 + +instance ( NFData sigId + , NFData sig + ) + => NFData (Message (SigSubmissionV2 sigId sig) from to) where + rnf (MsgRequestSigIds tkbs w1 w2) = rnf tkbs `seq` rnf w1 `seq` rnf w2 + rnf (MsgReplySigIds brl) = rnf brl + rnf MsgReplyNoSigIds = () + rnf (MsgRequestSigs sigIds) = rnf sigIds + rnf (MsgReplySigs sigs) = rnf sigs + rnf MsgDone = () + +deriving instance (Eq sigId, Eq sig) + => Eq (Message (SigSubmissionV2 sigId sig) from to) + +deriving instance (Show sigId, Show sig) + => Show (Message (SigSubmissionV2 sigId sig) from to) diff --git a/dmq-node/src/DMQ/SigSubmissionV2/Inbound.hs b/dmq-node/src/DMQ/SigSubmissionV2/Inbound.hs new file mode 100644 index 0000000..fd61359 --- /dev/null +++ b/dmq-node/src/DMQ/SigSubmissionV2/Inbound.hs @@ -0,0 +1,218 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module DMQ.SigSubmissionV2.Inbound + ( -- * SigSubmision Inbound client + sigSubmissionInbound + ) where + +import Data.Map.Strict qualified as Map +import Data.Sequence.Strict qualified as StrictSeq +import Data.Set qualified as Set + +import Control.Exception (assert) +import Control.Monad (unless, when) +import Control.Monad.Class.MonadAsync (MonadAsync (..)) +import Control.Monad.Class.MonadThrow +import Control.Monad.Class.MonadTimer.SI +import Control.Tracer (Tracer, traceWith) +import Network.TypedProtocol + +import Ouroboros.Network.ControlMessage (ControlMessageSTM, + timeoutWithControlMessage) + +import Ouroboros.Network.TxSubmission.Inbound.V2.Types ( + TxSubmissionMempoolWriter (..)) + +import DMQ.Protocol.SigSubmissionV2.Inbound + +import Ouroboros.Network.TxSubmission.Inbound.V2 (TraceTxSubmissionInbound (..), + PeerTxAPI (..), TxDecision (..), TxsToMempool (..), + TxSubmissionProtocolError (..), TxSubmissionInitDelay (..)) +import Ouroboros.Network.Protocol.TxSubmission2.Type (NumTxIdsToReq(..), + NumTxIdsToAck (..)) +import DMQ.Protocol.SigSubmissionV2.Type (NumIdsReq(..), NumIdsAck (NumIdsAck)) + +-- | A sig-submission inbound side (client, sic!). +-- +-- The client blocks on receiving `SigDecision` from the decision logic. If +-- there are sig's to download it pipelines two requests: first for sig's second +-- for sigid's. If there are no sig's to download, it either sends a blocking or +-- non-blocking request for sigid's. +-- +sigSubmissionInbound + :: forall sigid sig idx m failure. + ( MonadDelay m + , MonadThrow m + , MonadAsync m + , Ord sigid + ) + => Tracer m (TraceTxSubmissionInbound sigid sig) + -> TxSubmissionInitDelay + -> TxSubmissionMempoolWriter sigid sig idx m failure + -> PeerTxAPI m sigid sig + -> ControlMessageSTM m + -> SigSubmissionInboundPipelined sigid sig m () +sigSubmissionInbound + tracer + initDelay + TxSubmissionMempoolWriter { txId } + PeerTxAPI { + readTxDecision, + handleReceivedTxIds, + handleReceivedTxs, + submitTxToMempool + } + controlMessageSTM + = + SigSubmissionInboundPipelined $ do + case initDelay of + TxSubmissionInitDelay delay -> threadDelay delay + NoTxSubmissionInitDelay -> return () + inboundIdle + where + inboundIdle + :: m (InboundStIdle Z sigid sig m ()) + inboundIdle = do + -- TODO + -- readSigDecision is blocking on next decision because takeMVar and ControlMessageSTM is blocking + sigDecision <- async readTxDecision + msigd <- timeoutWithControlMessage controlMessageSTM (waitSTM sigDecision) + case msigd of + Nothing -> pure (SendMsgDone $ return ()) + Just sigd@TxDecision + { txdTxsToRequest = sigsToRequest + , txdTxsToMempool = TxsToMempool { listOfTxsToMempool } + } -> do + traceWith tracer (TraceTxInboundDecision sigd) + + let !collected = length listOfTxsToMempool + + -- Only attempt to add sigs if we have some work to do + when (collected > 0) $ do + -- submitTxToMempool traces: + -- * `TraceTxSubmissionProcessed`, + -- * `TraceTxInboundAddedToMempool`, and + -- * `TraceTxInboundRejectedFromMempool` + -- events. + mapM_ (uncurry $ submitTxToMempool tracer) listOfTxsToMempool + + -- TODO: + -- We can update the state so that other `sig-submission` servers will + -- not try to add these sigs to the mempool. + if Map.null sigsToRequest + then serverReqSigIds Zero sigd + else serverReqSigs sigd + + + -- Pipelined request of sigs + serverReqSigs :: TxDecision sigid sig + -> m (InboundStIdle Z sigid sig m ()) + serverReqSigs sigd@TxDecision { txdTxsToRequest = sigdSigsToRequest } = + pure $ SendMsgRequestSigsPipelined sigdSigsToRequest + (serverReqSigIds (Succ Zero) sigd) + + serverReqSigIds :: forall (n :: N). + Nat n + -> TxDecision sigid sig + -> m (InboundStIdle n sigid sig m ()) + serverReqSigIds + n TxDecision { txdTxIdsToRequest = 0 } + = + case n of + Zero -> inboundIdle + Succ _ -> handleReplies n + + serverReqSigIds + -- if there are no unacknowledged sigids, the protocol requires sending + -- a blocking `MsgRequestSigIds` request. This is important, as otherwise + -- the client side wouldn't have a chance to terminate the + -- mini-protocol. + Zero TxDecision { txdTxIdsToAcknowledge = sigIdsToAck, + txdPipelineTxIds = False, + txdTxIdsToRequest = sigIdsToReq + } + = + pure $ SendMsgRequestSigIdsBlocking + (NumIdsAck . getNumTxIdsToAck $ sigIdsToAck) + (NumIdsReq . getNumTxIdsToReq $ sigIdsToReq) + (\sigids -> do + let sigidsSeq = StrictSeq.fromList $ fst <$> sigids + sigidsMap = Map.fromList sigids + unless (StrictSeq.length sigidsSeq <= fromIntegral sigIdsToReq) $ + throwIO ProtocolErrorTxIdsNotRequested + handleReceivedTxIds sigIdsToReq sigidsSeq sigidsMap + inboundIdle + ) + + serverReqSigIds + n@Zero TxDecision { txdTxIdsToAcknowledge = sigIdsToAck, + txdPipelineTxIds = True, + txdTxIdsToRequest = sigIdsToReq + } + = + pure $ SendMsgRequestSigIdsPipelined + (NumIdsAck . getNumTxIdsToAck $ sigIdsToAck) + (NumIdsReq . getNumTxIdsToReq $ sigIdsToReq) + (handleReplies (Succ n)) + + serverReqSigIds + n@Succ{} TxDecision { txdTxIdsToAcknowledge = sigIdsToAck, + txdPipelineTxIds, + txdTxIdsToRequest = sigIdsToReq + } + = + -- it is impossible that we have had `sig`'s to request (Succ{} - is an + -- evidence for that), but no unacknowledged `sigid`s. + assert txdPipelineTxIds $ + pure $ SendMsgRequestSigIdsPipelined + (NumIdsAck . getNumTxIdsToAck $ sigIdsToAck) + (NumIdsReq . getNumTxIdsToReq $ sigIdsToReq) + (handleReplies (Succ n)) + + + handleReplies :: forall (n :: N). + Nat (S n) + -> m (InboundStIdle (S n) sigid sig m ()) + handleReplies (Succ n'@Succ{}) = + pure $ CollectPipelined + Nothing + (handleReply (handleReplies n')) + + handleReplies (Succ Zero) = + pure $ CollectPipelined + Nothing + (handleReply inboundIdle) + + handleReply :: forall (n :: N). + m (InboundStIdle n sigid sig m ()) + -- continuation + -> Collect sigid sig + -> m (InboundStIdle n sigid sig m ()) + handleReply k = \case + CollectSigIds sigIdsToReq sigids -> do + let sigidsSeq = StrictSeq.fromList $ fst <$> sigids + sigidsMap = Map.fromList sigids + unless (StrictSeq.length sigidsSeq <= fromIntegral sigIdsToReq) $ + throwIO ProtocolErrorTxIdsNotRequested + handleReceivedTxIds (NumTxIdsToReq . getNumIdsReq $ sigIdsToReq) sigidsSeq sigidsMap + k + CollectSigs sigids sigs -> do + let requested = Map.keysSet sigids + received = Map.fromList [ (txId sig, sig) | sig <- sigs ] + + unless (Map.keysSet received `Set.isSubsetOf` requested) $ + throwIO ProtocolErrorTxNotRequested + + mbe <- handleReceivedTxs sigids received + traceWith tracer $ TraceTxSubmissionCollected (txId `map` sigs) + case mbe of + -- one of `sig`s had a wrong size + Just e -> traceWith tracer (TraceTxInboundError e) + >> throwIO e + Nothing -> k diff --git a/dmq-node/src/DMQ/SigSubmissionV2/Outbound.hs b/dmq-node/src/DMQ/SigSubmissionV2/Outbound.hs new file mode 100644 index 0000000..f09c13b --- /dev/null +++ b/dmq-node/src/DMQ/SigSubmissionV2/Outbound.hs @@ -0,0 +1,202 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module DMQ.SigSubmissionV2.Outbound + ( sigSubmissionOutbound + , TraceSigSubmissionOutbound (..) + , SigSubmissionProtocolError (..) + ) where + +import Data.Aeson (ToJSON (toJSON), Value (String), object, KeyValue ((.=))) +import Data.Foldable (find) +import Data.List.NonEmpty qualified as NonEmpty +import Data.Maybe (catMaybes, isNothing, mapMaybe) +import Data.Sequence.Strict (StrictSeq) +import Data.Sequence.Strict qualified as Seq +import Data.Word (Word16) + +import Control.Exception (assert) +import Control.Monad (unless, when) +import Control.Monad.Class.MonadSTM +import Control.Monad.Class.MonadThrow +import Control.Tracer (Tracer (..), traceWith) + +import Ouroboros.Network.TxSubmission.Mempool.Reader (MempoolSnapshot (..), + TxSubmissionMempoolReader (..)) + +import DMQ.Protocol.SigSubmissionV2.Outbound +import DMQ.Protocol.SigSubmissionV2.Type + + +data TraceSigSubmissionOutbound sigId sig + = TraceSigSubmissionOutboundRecvMsgRequestSigs + [sigId] + -- ^ The IDs of the signatures requested. + | TraceSigSubmissionOutboundSendMsgReplySigs + [sig] + -- ^ The sigs to be sent in the response. + deriving Show + +instance (ToJSON sigId, ToJSON sig) + => ToJSON (TraceSigSubmissionOutbound sigId sig) where + toJSON (TraceSigSubmissionOutboundRecvMsgRequestSigs sigIds) = + object + [ "kind" .= String "SigSubmissionOutboundRecvMsgRequestSigs" + , "sigIds" .= sigIds + ] + toJSON (TraceSigSubmissionOutboundSendMsgReplySigs sigs) = + object + [ "kind" .= String "SigSubmissionOutboundSendMsgReplySigs" + , "sigs" .= sigs + ] + +data SigSubmissionProtocolError = + ProtocolErrorAckedTooManySigIds + | ProtocolErrorRequestedNothing + | ProtocolErrorRequestedTooManySigIds NumIdsReq Word16 NumIdsAck + | ProtocolErrorRequestBlocking + | ProtocolErrorRequestNonBlocking + | ProtocolErrorRequestedUnavailableSig + deriving Show + +instance Exception SigSubmissionProtocolError where + displayException ProtocolErrorAckedTooManySigIds = + "The peer tried to acknowledged more sigIds than are available to do so." + + displayException (ProtocolErrorRequestedTooManySigIds reqNo unackedNo maxUnacked) = + "The peer requested " ++ show reqNo ++ " sigIds which would put the " + ++ "total in flight over the limit of " ++ show maxUnacked ++ "." + ++ " Number of unacked sigIds " ++ show unackedNo + + displayException ProtocolErrorRequestedNothing = + "The peer requested zero sigIds." + + displayException ProtocolErrorRequestBlocking = + "The peer made a blocking request for more sigIds when there are still " + ++ "unacknowledged sigIds. It should have used a non-blocking request." + + displayException ProtocolErrorRequestNonBlocking = + "The peer made a non-blocking request for more sigIds when there are " + ++ "no unacknowledged sigIds. It should have used a blocking request." + + displayException ProtocolErrorRequestedUnavailableSig = + "The peer requested a signature which is not available, either " + ++ "because it was never available or because it was previously requested." + + +sigSubmissionOutbound + :: forall version sigId sig idx m. + (Ord sigId, Ord idx, MonadSTM m, MonadThrow m) + => Tracer m (TraceSigSubmissionOutbound sigId sig) + -> NumIdsAck -- ^ Maximum number of unacknowledged sigIds allowed + -> TxSubmissionMempoolReader sigId sig idx m + -> version + -> SigSubmissionOutbound sigId sig m () +sigSubmissionOutbound tracer maxUnacked TxSubmissionMempoolReader{..} _version = + SigSubmissionOutbound (pure (client Seq.empty mempoolZeroIdx)) + where + client :: StrictSeq (sigId, idx) -> idx -> OutboundStIdle sigId sig m () + client !unackedSeq !lastIdx = + OutboundStIdle { recvMsgRequestSigIds, recvMsgRequestSigs } + where + recvMsgRequestSigIds :: forall blocking. + SingBlockingStyle blocking + -> NumIdsAck + -> NumIdsReq + -> m (OutboundStSigIds blocking sigId sig m ()) + recvMsgRequestSigIds blocking ackNo reqNo = do + when (getNumIdsAck ackNo > fromIntegral (Seq.length unackedSeq)) $ + throwIO ProtocolErrorAckedTooManySigIds + + let unackedNo = fromIntegral (Seq.length unackedSeq) + when ( unackedNo + - getNumIdsAck ackNo + + getNumIdsReq reqNo + > getNumIdsAck maxUnacked) $ + throwIO (ProtocolErrorRequestedTooManySigIds reqNo unackedNo maxUnacked) + + -- Update our tracking state to remove the number of sigIds that the + -- peer has acknowledged. + let !unackedSeq' = Seq.drop (fromIntegral ackNo) unackedSeq + + -- Update our tracking state with any extra sigs available. + let update sigs = + -- These sigs should all be fresh + assert (all (\(_, idx, _) -> idx > lastIdx) sigs) $ + let !unackedSeq'' = unackedSeq' <> Seq.fromList + [ (sigId, idx) | (sigId, idx, _) <- sigs ] + !lastIdx' + | null sigs = lastIdx + | otherwise = idx where (_, idx, _) = last sigs + sigs' :: [(sigId, SizeInBytes)] + sigs' = [ (sigId, size) | (sigId, _, size) <- sigs ] + client' = client unackedSeq'' lastIdx' + in (sigs', client') + + -- Grab info about any new sigs after the last sig idx we've seen, + -- up to the number that the peer has requested. + case blocking of + SingBlocking -> do + when (reqNo == 0) $ + throwIO ProtocolErrorRequestedNothing + unless (Seq.null unackedSeq') $ + throwIO ProtocolErrorRequestBlocking + + sigs <- atomically $ + do + MempoolSnapshot{mempoolTxIdsAfter} <- mempoolGetSnapshot + let sigs = mempoolTxIdsAfter lastIdx + check (not $ null sigs) + pure (take (fromIntegral reqNo) sigs) + + let !(sigs', client') = update sigs + sigs'' = case NonEmpty.nonEmpty sigs' of + Just x -> x + -- Assert sigs is non-empty: we blocked until sigs was non-null, + -- and we know reqNo > 0, hence `take reqNo sigs` is non-null. + Nothing -> error "sigSubmissionOutbound: empty signature list" + pure (SendMsgReplySigIds (BlockingReply sigs'') client') + + SingNonBlocking -> do + when (reqNo == 0 && ackNo == 0) $ + throwIO ProtocolErrorRequestedNothing + when (Seq.null unackedSeq') $ + throwIO ProtocolErrorRequestNonBlocking + + sigs <- atomically $ do + MempoolSnapshot{mempoolTxIdsAfter} <- mempoolGetSnapshot + let sigs = mempoolTxIdsAfter lastIdx + return (take (fromIntegral reqNo) sigs) + + let !(sigs', client') = update sigs + pure (SendMsgReplySigIds (NonBlockingReply sigs') client') + + recvMsgRequestSigs :: [sigId] + -> m (OutboundStSigs sigId sig m ()) + recvMsgRequestSigs sigIds = do + -- Trace the IDs of the signatures requested. + traceWith tracer (TraceSigSubmissionOutboundRecvMsgRequestSigs sigIds) + + MempoolSnapshot{mempoolLookupTx} <- atomically mempoolGetSnapshot + + -- The window size is expected to be small (currently 10) so the find is acceptable. + let sigIdxs = [ find (\(t,_) -> t == sigId) unackedSeq | sigId <- sigIds ] + sigIdxs' = map snd $ catMaybes sigIdxs + + when (any isNothing sigIdxs) $ + throwIO ProtocolErrorRequestedUnavailableSig + + -- The 'mempoolLookupTx' will return nothing if the signature is no + -- longer in the mempool. This is good. Neither the sending nor + -- receiving side wants to forward sigs that are no longer of interest. + let sigs = mapMaybe mempoolLookupTx sigIdxs' + client' = client unackedSeq lastIdx + + -- Trace the sigs to be sent in the response. + traceWith tracer (TraceSigSubmissionOutboundSendMsgReplySigs sigs) + + return $ SendMsgReplySigs sigs client' diff --git a/dmq-node/test/DMQ/Protocol/LocalMsgNotification/Test.hs b/dmq-node/test/DMQ/Protocol/LocalMsgNotification/Test.hs index 7260704..b8467a7 100644 --- a/dmq-node/test/DMQ/Protocol/LocalMsgNotification/Test.hs +++ b/dmq-node/test/DMQ/Protocol/LocalMsgNotification/Test.hs @@ -221,7 +221,6 @@ codec :: MonadST m => LocalMsgNotificationCodec m MsgWithBytes codec = codecLocalMsgNotification' Utils.runWithByteSpan encodeMsg decodeMsg - instance Arbitrary HasMore where arbitrary = elements [HasMore, DoesNotHaveMore] diff --git a/dmq-node/test/DMQ/Protocol/SigSubmission/Test.hs b/dmq-node/test/DMQ/Protocol/SigSubmission/Test.hs index 260fefd..1a74873 100644 --- a/dmq-node/test/DMQ/Protocol/SigSubmission/Test.hs +++ b/dmq-node/test/DMQ/Protocol/SigSubmission/Test.hs @@ -26,6 +26,7 @@ module DMQ.Protocol.SigSubmission.Test (tests) where import Codec.CBOR.Encoding qualified as CBOR import Codec.CBOR.Read qualified as CBOR import Codec.CBOR.Write qualified as CBOR +import Codec.CBOR.FlatTerm qualified as CBOR import Control.Monad (zipWithM, (>=>)) import Control.Monad.Class.MonadTime.SI import Control.Monad.ST (runST) @@ -85,6 +86,7 @@ tests = [ testGroup "MockCrypto" [ testProperty "OCert" prop_codec_ocert_mockcrypto , testProperty "Sig" prop_codec_sig_mockcrypto + , testProperty "Sig.encoding" prop_codec_sig_encoding_mockcrypto , testProperty "codec" prop_codec_mockcrypto , testProperty "codec id" prop_codec_id_mockcrypto , testProperty "codec 2-splits" $ withMaxSize 20 @@ -102,6 +104,7 @@ tests = , testGroup "StandardCrypto" [ testProperty "OCert" prop_codec_ocert_standardcrypto , testProperty "Sig" prop_codec_sig_standardcrypto + , testProperty "Sig.encoding" prop_codec_sig_encoding_standardcrypto , testProperty "codec" prop_codec_standardcrypto , testProperty "codec id" prop_codec_id_standardcrypto , testProperty "codec 2-splits" $ withMaxSize 20 @@ -715,6 +718,27 @@ prop_codec_sig_standardcrypto prop_codec_sig_standardcrypto = prop_codec_sig . getBlind +prop_codec_sig_encoding + :: forall crypto + . WithConstrKES (SeedSizeKES (KES crypto)) (KES crypto) (Sig crypto) + -> Property +prop_codec_sig_encoding constr = ioProperty $ do + sig <- runWithConstr constr + let encoding = encodeSig sig + return . counterexample (show sig) + $ CBOR.validFlatTerm (CBOR.toFlatTerm encoding) + +prop_codec_sig_encoding_mockcrypto + :: Blind (WithConstrKES (SeedSizeKES (KES MockCrypto)) (KES MockCrypto) (Sig MockCrypto)) + -> Property +prop_codec_sig_encoding_mockcrypto = prop_codec_sig_encoding . getBlind + +prop_codec_sig_encoding_standardcrypto + :: Blind (WithConstrKES (SeedSizeKES (KES StandardCrypto)) (KES StandardCrypto) (Sig StandardCrypto)) + -> Property +prop_codec_sig_encoding_standardcrypto = prop_codec_sig_encoding . getBlind + + type AnySigMessage crypto = WithConstrKESList (SeedSizeKES (KES crypto)) (KES crypto) (AnyMessage (SigSubmission crypto)) diff --git a/dmq-node/test/DMQ/Protocol/SigSubmissionV2/Codec/CDDL.hs b/dmq-node/test/DMQ/Protocol/SigSubmissionV2/Codec/CDDL.hs new file mode 100644 index 0000000..b995cd4 --- /dev/null +++ b/dmq-node/test/DMQ/Protocol/SigSubmissionV2/Codec/CDDL.hs @@ -0,0 +1,21 @@ +module DMQ.Protocol.SigSubmissionV2.Codec.CDDL where + +import Codec.CBOR.Read qualified as CBOR +import Codec.Serialise.Class qualified as Serialise +import Data.ByteString.Lazy qualified as BL + +import Network.TypedProtocol.Codec + +import DMQ.Protocol.SigSubmissionV2.Codec +import DMQ.Protocol.SigSubmissionV2.Test (Sig, SigId) +import DMQ.Protocol.SigSubmissionV2.Type hiding (Sig, SigId) + + +sigSubmissionV2Codec :: Codec (SigSubmissionV2 SigId Sig) + CBOR.DeserialiseFailure IO BL.ByteString +sigSubmissionV2Codec = + codecSigSubmissionV2 + Serialise.encode + Serialise.decode + Serialise.encode + Serialise.decode diff --git a/dmq-node/test/DMQ/Protocol/SigSubmissionV2/Direct.hs b/dmq-node/test/DMQ/Protocol/SigSubmissionV2/Direct.hs new file mode 100644 index 0000000..c84ff49 --- /dev/null +++ b/dmq-node/test/DMQ/Protocol/SigSubmissionV2/Direct.hs @@ -0,0 +1,71 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module DMQ.Protocol.SigSubmissionV2.Direct (directPipelined) where + +import Data.Map.Strict qualified as Map +import Data.List.NonEmpty qualified as NonEmpty + +import Network.TypedProtocol.Core +import Network.TypedProtocol.Proofs (Queue (..), enqueue) + +import DMQ.Protocol.SigSubmissionV2.Inbound +import DMQ.Protocol.SigSubmissionV2.Outbound +import DMQ.Protocol.SigSubmissionV2.Type (BlockingReplyList (..), + SingBlockingStyle (..)) + + +directPipelined + :: forall sigId sig m a. + Monad m + => SigSubmissionOutbound sigId sig m a + -> SigSubmissionInboundPipelined sigId sig m a + -> m a +directPipelined (SigSubmissionOutbound mOutbound) + (SigSubmissionInboundPipelined mInbound) = do + outbound <- mOutbound + inbound <- mInbound + directSender EmptyQ inbound outbound + where + directSender :: forall (n :: N). + Queue n (Collect sigId sig) + -> InboundStIdle n sigId sig m a + -> OutboundStIdle sigId sig m a + -> m a + directSender q (SendMsgRequestSigIdsBlocking ackNo reqNo inboundNext) + OutboundStIdle{recvMsgRequestSigIds} = do + reply <- recvMsgRequestSigIds SingBlocking ackNo reqNo + case reply of + SendMsgReplySigIds (BlockingReply sigIds) outbound' -> do + inbound' <- inboundNext (NonEmpty.toList sigIds) + directSender q inbound' outbound' + + SendMsgReplyNoSigIds outbound' -> do + inbound' <- inboundNext [] + directSender q inbound' outbound' + + directSender q (SendMsgRequestSigIdsPipelined ackNo reqNo inboundNext) + OutboundStIdle{recvMsgRequestSigIds} = do + reply <- recvMsgRequestSigIds SingNonBlocking ackNo reqNo + case reply of + SendMsgReplySigIds (NonBlockingReply sigIds) outbound' -> do + inbound' <- inboundNext + directSender (enqueue (CollectSigIds reqNo sigIds) q) inbound' outbound' + + directSender q (SendMsgRequestSigsPipelined sigIds inboundNext) + OutboundStIdle{recvMsgRequestSigs} = do + SendMsgReplySigs sigs outbound' <- recvMsgRequestSigs $ Map.keys sigIds + inbound' <- inboundNext + directSender (enqueue (CollectSigs sigIds sigs) q) inbound' outbound' + + directSender q (CollectPipelined (Just noWaitInbound') _inboundNext) outbound = do + directSender q noWaitInbound' outbound + + directSender (ConsQ c q) (CollectPipelined _maybeNoWaitInbound' inboundNext) outbound = do + inbound' <- inboundNext c + directSender q inbound' outbound + + directSender EmptyQ (SendMsgDone v) _outbound = v diff --git a/dmq-node/test/DMQ/Protocol/SigSubmissionV2/Test.hs b/dmq-node/test/DMQ/Protocol/SigSubmissionV2/Test.hs new file mode 100644 index 0000000..75102d9 --- /dev/null +++ b/dmq-node/test/DMQ/Protocol/SigSubmissionV2/Test.hs @@ -0,0 +1,249 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} + +{-# OPTIONS_GHC -Wno-orphans #-} +module DMQ.Protocol.SigSubmissionV2.Test + ( tests + , SigId (..) + , Sig (..) + ) where + +import Control.Monad.Class.MonadST (MonadST) +import Control.Monad.ST (runST) +import Codec.CBOR.FlatTerm qualified as CBOR +import Data.ByteString.Lazy (ByteString) +import Data.List.NonEmpty qualified as NonEmpty +import GHC.Generics + +import Codec.Serialise (DeserialiseFailure, Serialise) +import Codec.Serialise qualified as Serialise (decode, encode) + +import Network.TypedProtocol.Codec +import Network.TypedProtocol.Codec.Properties (prop_codecM, prop_codec_splitsM) + +import Ouroboros.Network.Util.ShowProxy + +import DMQ.Protocol.SigSubmissionV2.Codec +import DMQ.Protocol.SigSubmissionV2.Type hiding (Sig, SigId) + +import Test.Data.CDDL (Any (..)) +import Test.Ouroboros.Network.Protocol.Utils (prop_codec_cborM, + prop_codec_valid_cbor_encoding, splits2, splits3) +import Test.Ouroboros.Network.Utils (renderRanges) + +import Test.QuickCheck as QC +import Test.QuickCheck.Instances.ByteString () +import Test.Tasty (TestTree, testGroup) +import Test.Tasty.QuickCheck (testProperty) + + +-- +-- Test cases +-- + + +tests :: TestTree +tests = + testGroup "DMQ.Protocol" + [ testGroup "SigSubmissionV2" + [ testProperty "codec" prop_codec + , testProperty "encoding" prop_encoding + , testProperty "codec id" prop_codec_id + , testProperty "codec 2-splits" $ withMaxSize 50 + prop_codec_splits2 + , testProperty "codec 3-splits" $ withMaxSize 10 + prop_codec_splits3 + , testProperty "codec cbor" prop_codec_cbor + , testProperty "codec valid cbor" prop_codec_valid_cbor + ] + ] + +-- +-- Common types & clients and servers used in the tests in this module. +-- + +newtype Sig = Sig SigId + deriving (Eq, Show, Arbitrary, Serialise, Generic) + +instance ShowProxy Sig where + showProxy _ = "Sig" + +-- | We use any `CBOR.Term`. This allows us to use `any` in cddl specs. +-- +newtype SigId = SigId Any + deriving (Eq, Ord, Show, Arbitrary, Serialise, Generic) + +instance ShowProxy SigId where + showProxy _ = "SigId" + +deriving newtype instance Arbitrary SizeInBytes + +deriving newtype instance Arbitrary NumIdsAck +deriving newtype instance Arbitrary NumIdsReq + +instance Arbitrary (AnyMessage (SigSubmissionV2 SigId Sig)) where + arbitrary = oneof + [ AnyMessage + <$> ( MsgRequestSigIds SingBlocking + <$> arbitrary + <*> arbitrary + ) + + , AnyMessage + <$> ( MsgRequestSigIds SingNonBlocking + <$> arbitrary + <*> arbitrary + ) + + , AnyMessage + <$> MsgReplySigIds + <$> ( BlockingReply + . NonEmpty.fromList + . QC.getNonEmpty + ) + <$> arbitrary + + , AnyMessage + <$> MsgReplySigIds + <$> NonBlockingReply + <$> arbitrary + + , AnyMessage + <$> pure MsgReplyNoSigIds + + , AnyMessage + <$> MsgRequestSigs + <$> arbitrary + + , AnyMessage + <$> MsgReplySigs + <$> arbitrary + + , AnyMessage + <$> pure MsgDone + ] + +instance (Eq sigId + , Eq sig + ) + => Eq (AnyMessage (SigSubmissionV2 sigId sig)) where + + (==) (AnyMessage (MsgRequestSigIds SingBlocking ackNo reqNo)) + (AnyMessage (MsgRequestSigIds SingBlocking ackNo' reqNo')) = + (ackNo, reqNo) == (ackNo', reqNo') + + (==) (AnyMessage (MsgRequestSigIds SingNonBlocking ackNo reqNo)) + (AnyMessage (MsgRequestSigIds SingNonBlocking ackNo' reqNo')) = + (ackNo, reqNo) == (ackNo', reqNo') + + (==) (AnyMessage (MsgReplySigIds (BlockingReply sigIds))) + (AnyMessage (MsgReplySigIds (BlockingReply sigIds'))) = + sigIds == sigIds' + + (==) (AnyMessage (MsgReplySigIds (NonBlockingReply sigIds))) + (AnyMessage (MsgReplySigIds (NonBlockingReply sigIds'))) = + sigIds == sigIds' + + (==) (AnyMessage MsgReplyNoSigIds) + (AnyMessage MsgReplyNoSigIds) = True + + (==) (AnyMessage (MsgRequestSigs sigIds)) + (AnyMessage (MsgRequestSigs sigIds')) = sigIds == sigIds' + + (==) (AnyMessage (MsgReplySigs txs)) + (AnyMessage (MsgReplySigs txs')) = txs == txs' + + (==) (AnyMessage MsgDone) + (AnyMessage MsgDone) = True + + (==) _ _ = False + + +codec :: MonadST m + => Codec + (SigSubmissionV2 SigId Sig) + DeserialiseFailure + m ByteString +codec = codecSigSubmissionV2 + Serialise.encode Serialise.decode + Serialise.encode Serialise.decode + + +-- | Check the codec round trip property. +-- +prop_codec + :: AnyMessage (SigSubmissionV2 SigId Sig) + -> Property +prop_codec msg = + runST (prop_codecM codec msg) + + +prop_encoding :: AnyMessage (SigSubmissionV2 SigId Sig) + -> Property +prop_encoding msg@(AnyMessage msg') = + let enc = encodeSigSubmissionV2 Serialise.encode Serialise.encode msg' + terms = CBOR.toFlatTerm enc + in counterexample (show msg) + . counterexample ("terms: " ++ show terms) + $ CBOR.validFlatTerm terms + + +-- | Check the codec round trip property for the id codec. +-- +prop_codec_id + :: AnyMessage (SigSubmissionV2 SigId Sig) + -> Property +prop_codec_id msg = + runST (prop_codecM codecSigSubmissionV2Id msg) + +-- | Check for data chunk boundary problems in the codec using 2 chunks. +-- +prop_codec_splits2 + :: AnyMessage (SigSubmissionV2 SigId Sig) + -> Property +prop_codec_splits2 msg = + runST (prop_codec_splitsM splits2 codec msg) + +-- | Check for data chunk boundary problems in the codec using 3 chunks. +-- +prop_codec_splits3 + :: AnyMessage (SigSubmissionV2 SigId Sig) + -> Property +prop_codec_splits3 msg = + labelMsg msg $ + runST (prop_codec_splitsM splits3 codec msg) + +prop_codec_cbor + :: AnyMessage (SigSubmissionV2 SigId Sig) + -> Property +prop_codec_cbor msg = + runST (prop_codec_cborM codec msg) + +-- | Check that the encoder produces a valid CBOR. +-- +prop_codec_valid_cbor + :: AnyMessage (SigSubmissionV2 SigId Sig) + -> Property +prop_codec_valid_cbor = prop_codec_valid_cbor_encoding codec + + +labelMsg :: AnyMessage (SigSubmissionV2 sigId sig) -> Property -> Property +labelMsg (AnyMessage msg) = + label (case msg of + MsgRequestSigIds {} -> "MsgRequestSigIds" + MsgReplySigIds as -> "MsgReplySigIds " ++ renderRanges 3 (length as) + MsgReplyNoSigIds -> "MsgReplyNoSigIds" + MsgRequestSigs as -> "MsgRequestSigs " ++ renderRanges 3 (length as) + MsgReplySigs as -> "MsgReplySigs " ++ renderRanges 3 (length as) + MsgDone -> "MsgDone" + ) diff --git a/dmq-node/test/Main.hs b/dmq-node/test/Main.hs index 880fa9c..effacb1 100644 --- a/dmq-node/test/Main.hs +++ b/dmq-node/test/Main.hs @@ -6,10 +6,12 @@ import Cardano.Crypto.Libsodium import Test.DMQ.NodeToClient qualified import Test.DMQ.NodeToNode qualified +import Test.DMQ.SigSubmission.App qualified import DMQ.Protocol.LocalMsgNotification.Test qualified import DMQ.Protocol.LocalMsgSubmission.Test qualified import DMQ.Protocol.SigSubmission.Test qualified +import DMQ.Protocol.SigSubmissionV2.Test qualified import Test.Tasty @@ -25,9 +27,11 @@ tests = testGroup "decentralised-message-queue:tests" [ Test.DMQ.NodeToClient.tests , Test.DMQ.NodeToNode.tests + , Test.DMQ.SigSubmission.App.tests -- protocols , DMQ.Protocol.SigSubmission.Test.tests + , DMQ.Protocol.SigSubmissionV2.Test.tests , DMQ.Protocol.LocalMsgSubmission.Test.tests , DMQ.Protocol.LocalMsgNotification.Test.tests ] diff --git a/dmq-node/test/Test/DMQ/SigSubmission/App.hs b/dmq-node/test/Test/DMQ/SigSubmission/App.hs new file mode 100644 index 0000000..ace1728 --- /dev/null +++ b/dmq-node/test/Test/DMQ/SigSubmission/App.hs @@ -0,0 +1,409 @@ +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeOperators #-} + +{-# OPTIONS_GHC -Wno-orphans #-} + +module Test.DMQ.SigSubmission.App (tests) where + +import Prelude hiding (seq) + +import System.Random (mkStdGen) +import Control.Concurrent.Class.MonadMVar.Strict +import Control.Concurrent.Class.MonadSTM.Strict +import Control.Monad.Class.MonadAsync +import Control.Monad.Class.MonadFork +import Control.Monad.Class.MonadSay +import Control.Monad.Class.MonadST +import Control.Monad.Class.MonadThrow +import Control.Monad.Class.MonadTime.SI +import Control.Monad.Class.MonadTimer.SI +import Control.Monad.IOSim +import Control.Tracer (Tracer (..), contramap) + +import Data.ByteString.Lazy qualified as BSL +import Data.Foldable (traverse_) +import Data.Function (on) +import Data.Hashable +import Data.List (nubBy) +import Data.List qualified as List +import Data.Map.Strict (Map) +import Data.Map.Strict qualified as Map +import Data.Maybe (fromMaybe) +import Data.Set qualified as Set +import Data.Typeable (Typeable) + +import Ouroboros.Network.Channel +import Ouroboros.Network.ControlMessage (ControlMessage (..), ControlMessageSTM) +import Ouroboros.Network.Driver +import Ouroboros.Network.Protocol.TxSubmission2.Type (NumTxIdsToReq(..)) +import Ouroboros.Network.TxSubmission.Inbound.V2 +import Ouroboros.Network.Util.ShowProxy + +import DMQ.SigSubmissionV2.Outbound (sigSubmissionOutbound) +import DMQ.SigSubmissionV2.Inbound (sigSubmissionInbound) +import DMQ.Protocol.SigSubmissionV2.Type (NumIdsAck(..), SigSubmissionV2) +import DMQ.Protocol.SigSubmissionV2.Codec (byteLimitsSigSubmissionV2, + timeLimitsSigSubmissionV2) +import DMQ.Protocol.SigSubmissionV2.Outbound (sigSubmissionV2OutboundPeer) +import DMQ.Protocol.SigSubmissionV2.Inbound ( + sigSubmissionV2InboundPeerPipelined) + +import Test.DMQ.SigSubmission.Types (SigId, Sig (..), debugTracer, + verboseTracer, newMempool, emptyMempool, readMempool, + sigSubmissionCodec2, getMempoolReader, getMempoolWriter) + +import Test.Ouroboros.Network.TxSubmission.TxLogic (ArbTxDecisionPolicy (..)) +import Test.Ouroboros.Network.Utils hiding (debugTracer) + +import Test.QuickCheck +import Test.Tasty (TestTree, testGroup) +import Test.Tasty.QuickCheck (testProperty) + + +tests :: TestTree +tests = testGroup "Test.DMQ.SigSubmission.App" + [ testProperty "sigSubmission" prop_sigSubmission + ] + + +data TestVersion = TestVersion + deriving (Eq, Ord, Bounded, Enum, Show) + + +-- | Tests overall sig submission semantics. +-- This property test is the same as for tx submission v1. We need this to know +-- we didn't regress. +-- +prop_sigSubmission :: SigSubmissionState -> Property +prop_sigSubmission st@(SigSubmissionState peers _) = + let tr = runSimTrace (sigSubmissionSimulation st) + numPeersWithWronglySizedSig :: Int + numPeersWithWronglySizedSig = + foldr + (\(sigs, _, _) r -> + case List.find (\sig -> getSigSize sig /= getSigAdvSize sig) sigs of + Just {} -> r + 1 + Nothing -> r + ) 0 peers + in + label ("number of peers: " ++ renderRanges 3 (Map.size peers)) + . label ("number of sigs: " + ++ + renderRanges 10 + ( Set.size + . foldMap (Set.fromList . (\(sigs, _, _) -> getSigId <$> sigs)) + $ Map.elems peers + )) + . label ("number of peers with wrongly sized sig: " + ++ show numPeersWithWronglySizedSig) + $ case traceResult True tr of + Left e -> + counterexample (show e) + . counterexample (ppTrace tr) + $ False + Right (inmp, outmps) -> + counterexample (ppTrace tr) + $ conjoin (validate inmp `map` outmps) + where + validate :: [Sig SigId] -- the inbound mempool + -> [Sig SigId] -- one of the outbound mempools + -> Property + validate inmp outmp = + let outUniqueSigIds = nubBy (on (==) getSigId) outmp + outValidSigs = filterValidSigs outmp + in + case ( length outUniqueSigIds == length outmp + , length outValidSigs == length outmp + ) of + x@(True, True) -> + -- If we are presented with a stream of unique sigids for valid + -- signatures the inbound signatures should match the outbound + -- signatures exactly. + counterexample (show x) + . counterexample (show inmp) + . counterexample (show outmp) + $ checkMempools inmp (take (length inmp) outValidSigs) + + x@(True, False) | Nothing <- List.find (\sig -> getSigAdvSize sig /= getSigSize sig) outmp -> + -- If we are presented with a stream of unique sigids then we should have + -- fetched all valid signatures if all sigs have valid sizes. + counterexample (show x) + . counterexample (show inmp) + . counterexample (show outValidSigs) + + $ checkMempools inmp (take (length inmp) outValidSigs) + | otherwise -> + -- If there's one sig with an invalid size, we will download only + -- some of them, but we don't guarantee how many we will download. + -- + -- This is ok, the peer is cheating. + property True + + + x@(False, True) -> + -- If we are presented with a stream of valid sigids then we should have + -- fetched some version of those signatures. + counterexample (show x) + . counterexample (show inmp) + . counterexample (show outmp) + $ checkMempools (map getSigId inmp) + (take (length inmp) + (getSigId <$> filterValidSigs outUniqueSigIds)) + + (False, False) -> + -- If we are presented with a stream of valid and invalid Sigs with + -- duplicate sigids we're content with completing the protocol + -- without error. + property True + + +sigSubmissionSimulation :: forall s . SigSubmissionState + -> IOSim s ([Sig SigId], [[Sig SigId]]) + -- ^ inbound & outbound mempools +sigSubmissionSimulation (SigSubmissionState state sigDecisionPolicy) = do + state' <- traverse (\(sigs, mbOutDelay, mbInDelay) -> do + let mbOutDelayTime = getSmallDelay . getPositive <$> mbOutDelay + mbInDelayTime = getSmallDelay . getPositive <$> mbInDelay + controlMessageVar <- newTVarIO Continue + return ( sigs + , controlMessageVar + , mbOutDelayTime + , mbInDelayTime + ) + ) + state + + state'' <- traverse (\(sigs, var, mbOutDelay, mbInDelay) -> do + return ( sigs + , readTVar var + , mbOutDelay + , mbInDelay + ) + ) + state' + + let simDelayTime = Map.foldl' (\m (sigs, _, mbInDelay, mbOutDelay) -> + max m ( fromMaybe 1 (max <$> mbInDelay <*> mbOutDelay) + * realToFrac (length sigs `div` 4) + ) + ) + 0 + state'' + controlMessageVars = (\(_, x, _, _) -> x) + <$> Map.elems state' + + withAsync + (do threadDelay (simDelayTime + 1000) + atomically (traverse_ (`writeTVar` Terminate) controlMessageVars) + ) \_ -> do + let tracer :: forall a. (Show a, Typeable a) => Tracer (IOSim s) a + tracer = verboseTracer + <> debugTracer + <> Tracer traceM + runSigSubmission tracer tracer state'' sigDecisionPolicy + + +filterValidSigs :: [Sig SigId] -> [Sig SigId] +filterValidSigs + = filter getSigValid + . takeWhile (\Sig{getSigSize, getSigAdvSize} -> getSigSize == getSigAdvSize) + + +-- | Check that the inbound mempool contains all outbound `sig`s as a proper +-- subsequence. It might contain more `sig`s from other peers. +-- +checkMempools :: Eq sig + => [sig] -- inbound mempool + -> [sig] -- outbound mempool + -> Bool +checkMempools _ [] = True -- all outbound `sig` were found in the inbound + -- mempool +checkMempools [] (_:_) = False -- outbound mempool contains `sig`s which were + -- not transferred to the inbound mempool +checkMempools (i : is') os@(o : os') + | i == o + = checkMempools is' os' + + | otherwise + -- `_i` is not present in the outbound mempool, we can skip it. + = checkMempools is' os + + +newtype SigStateTrace peeraddr sigid = + SigStateTrace (SharedTxState peeraddr sigid (Sig sigid)) + +runSigSubmission + :: forall m peeraddr sigid. + ( MonadAsync m + , MonadDelay m + , MonadFork m + , MonadMask m + , MonadMVar m + , MonadSay m + , MonadST m + , MonadLabelledSTM m + , MonadTime m + , MonadTimer m + , MonadThrow (STM m) + , MonadTraceSTM m + , ShowProxy sigid + , Typeable sigid + , Show peeraddr + , Ord peeraddr + , Hashable peeraddr + , Typeable peeraddr + + , sigid ~ Int + ) + => Tracer m (String, TraceSendRecv (SigSubmissionV2 sigid (Sig sigid))) + -> Tracer m (TraceTxLogic peeraddr sigid (Sig sigid)) + -> Map peeraddr ( [Sig sigid] + , ControlMessageSTM m + , Maybe DiffTime + , Maybe DiffTime + ) + -> TxDecisionPolicy + -> m ([Sig sigid], [[Sig sigid]]) + -- ^ inbound and outbound mempools +runSigSubmission tracer tracerSigLogic st0 sigDecisionPolicy = do + st <- traverse (\(b, c, d, e) -> do + mempool <- newMempool b + (outChannel, inChannel) <- createConnectedChannels + return (mempool, c, d, e, outChannel, inChannel) + ) st0 + inboundMempool <- emptyMempool + let sigRng = mkStdGen 42 -- TODO + + sigChannelsVar <- newMVar (TxChannels Map.empty) + sigMempoolSem <- newTxMempoolSem + sharedSigStateVar <- newSharedTxStateVar sigRng + traceTVarIO sharedSigStateVar \_ -> return . TraceDynamic . SigStateTrace + labelTVarIO sharedSigStateVar "shared-sig-state" + + withAsync (decisionLogicThreads tracerSigLogic sayTracer + sigDecisionPolicy sigChannelsVar sharedSigStateVar) $ \a -> do + let servers = (\(addr, (mempool, _, _, inDelay, _, inChannel)) -> do + let server = sigSubmissionOutbound + (Tracer $ say . show) + (NumIdsAck $ getNumTxIdsToReq $ maxUnacknowledgedTxIds sigDecisionPolicy) + (getMempoolReader mempool) + (maxBound :: TestVersion) + runPeerWithLimits + (("OUTBOUND " ++ show addr,) `contramap` tracer) + sigSubmissionCodec2 + (byteLimitsSigSubmissionV2 (fromIntegral . BSL.length)) + timeLimitsSigSubmissionV2 + (maybe id delayChannel inDelay inChannel) + (sigSubmissionV2OutboundPeer server) + ) + <$> Map.assocs st + + let clients = (\(addr, (_, ctrlMsgSTM, outDelay, _, outChannel, _)) -> do + withPeer tracerSigLogic + sigChannelsVar + sigMempoolSem + sigDecisionPolicy + sharedSigStateVar + (getMempoolReader inboundMempool) + (getMempoolWriter inboundMempool) + getSigSize + addr $ \(api :: PeerTxAPI m SigId (Sig SigId))-> do + let client = sigSubmissionInbound + verboseTracer + NoTxSubmissionInitDelay + (getMempoolWriter inboundMempool) + api + ctrlMsgSTM + runPipelinedPeerWithLimits + (("INBOUND " ++ show addr,) `contramap` verboseTracer) + sigSubmissionCodec2 + (byteLimitsSigSubmissionV2 (fromIntegral . BSL.length)) + timeLimitsSigSubmissionV2 + (maybe id delayChannel outDelay outChannel) + (sigSubmissionV2InboundPeerPipelined client) + ) <$> Map.assocs st + + -- Run clients and servers + withAsyncAll (zip clients servers) $ \as -> do + _ <- waitAllClients as + -- cancel decision logic thread + cancel a + + inmp <- readMempool inboundMempool + let outmp = map (\(sigs, _, _, _) -> sigs) + $ Map.elems st0 + + return (inmp, outmp) + where + waitAllClients :: [(Async m x, Async m x)] -> m [Either SomeException x] + waitAllClients [] = return [] + waitAllClients ((client, server):as) = do + r <- waitCatch client + -- cancel server as soon as the client exits + cancel server + rs <- waitAllClients as + return (r : rs) + + withAsyncAll :: [(m a, m a)] + -> ([(Async m a, Async m a)] -> m b) + -> m b + withAsyncAll xs0 action = go [] xs0 + where + go as [] = action (reverse as) + go as ((x,y):xs) = withAsync x (\a -> withAsync y (\b -> go ((a, b):as) xs)) + + +data SigSubmissionState = + SigSubmissionState { + peerMap :: Map Int ( [Sig Int] + , Maybe (Positive SmallDelay) + , Maybe (Positive SmallDelay) + -- ^ The delay must be smaller (<) than 5s, so that overall + -- delay is less than 10s, otherwise 'smallDelay' in + -- 'timeLimitsTxSubmission2' will kick in. + ) + , decisionPolicy :: TxDecisionPolicy + } deriving (Show) + +instance Arbitrary SigSubmissionState where + arbitrary = do + ArbTxDecisionPolicy decisionPolicy <- arbitrary + peersN <- choose (1, 10) + txsN <- choose (1, 10) + -- NOTE: using sortOn would forces tx-decision logic to download txs in the + -- order of unacknowledgedTxIds. This could be useful to get better + -- properties when wrongly sized txs are present. + txs <- divvy txsN . nubBy (on (==) getSigId) {- . List.sortOn getSigId -} <$> vectorOf (peersN * txsN) arbitrary + peers <- vectorOf peersN arbitrary + peersState <- zipWith (curry (\(a, (b, c)) -> (a, b, c))) txs + <$> vectorOf peersN arbitrary + return SigSubmissionState { peerMap = Map.fromList (zip peers peersState), + decisionPolicy + } + where + -- | Split a list into sub list of at most `n` elements. + -- + divvy :: Int -> [a] -> [[a]] + divvy _ [] = [] + divvy n as = take n as : divvy n (drop n as) + + shrink SigSubmissionState { peerMap, decisionPolicy } = + SigSubmissionState <$> shrinkMap1 peerMap + <*> [ policy + | ArbTxDecisionPolicy policy <- shrink (ArbTxDecisionPolicy decisionPolicy) + ] + where + shrinkMap1 :: Ord k => Map k v -> [Map k v] + shrinkMap1 m + | Map.size m <= 1 = [m] + | otherwise = [Map.delete k m | k <- Map.keys m] ++ singletonMaps + where + singletonMaps = [Map.singleton k v | (k, v) <- Map.toList m] + diff --git a/dmq-node/test/Test/DMQ/SigSubmission/Types.hs b/dmq-node/test/Test/DMQ/SigSubmission/Types.hs new file mode 100644 index 0000000..a4f8a7f --- /dev/null +++ b/dmq-node/test/Test/DMQ/SigSubmission/Types.hs @@ -0,0 +1,230 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} + +module Test.DMQ.SigSubmission.Types + ( Sig (..) + , SigId + , Mempool + , emptyMempool + , newMempool + , readMempool + , getMempoolReader + , getMempoolWriter + , maxSigSize + , LargeNonEmptyList (..) + , SimResults (..) + , WithThreadAndTime (..) + , sigSubmissionCodec2 + , evaluateTrace + , verboseTracer + , debugTracer +) where + +import Prelude hiding (seq) + +import NoThunks.Class + +import Control.Concurrent.Class.MonadSTM +import Control.DeepSeq +import Control.Exception (SomeException (..)) +import Control.Monad.Class.MonadAsync +import Control.Monad.Class.MonadFork +import Control.Monad.Class.MonadSay +import Control.Monad.Class.MonadST +import Control.Monad.Class.MonadThrow +import Control.Monad.Class.MonadTime.SI +import Control.Monad.IOSim hiding (SimResult) +import Control.Tracer (Tracer (..), showTracing, traceWith) + +import Codec.CBOR.Decoding qualified as CBOR +import Codec.CBOR.Encoding qualified as CBOR +import Codec.CBOR.Read qualified as CBOR + +import Data.ByteString.Lazy (ByteString) +import GHC.Generics (Generic) + +import Network.TypedProtocol.Codec + +import Ouroboros.Network.Protocol.TxSubmission2.Type +import Ouroboros.Network.TxSubmission.Inbound.V1 (TxSubmissionMempoolWriter) +import Ouroboros.Network.TxSubmission.Mempool.Reader +import Ouroboros.Network.TxSubmission.Mempool.Simple (Mempool) +import Ouroboros.Network.TxSubmission.Mempool.Simple qualified as Mempool +import Ouroboros.Network.Util.ShowProxy + +import Test.QuickCheck +import Text.Printf +import DMQ.Protocol.SigSubmissionV2.Type (SigSubmissionV2) +import DMQ.Protocol.SigSubmissionV2.Codec (codecSigSubmissionV2) + + +data Sig sigid = Sig { + getSigId :: !sigid, + getSigSize :: !SizeInBytes, + getSigAdvSize :: !SizeInBytes, + -- | If false this means that when this sig will be submitted to a remote + -- mempool it will not be valid. The outbound mempool might contain + -- invalid sig's in this sense. + getSigValid :: !Bool + } + deriving (Eq, Ord, Show, Generic, NFData) + +instance NoThunks sigid => NoThunks (Sig sigid) +instance ShowProxy sigid => ShowProxy (Sig sigid) where + showProxy _ = "Sig " ++ showProxy (Proxy :: Proxy sigid) + +instance Arbitrary sigid => Arbitrary (Sig sigid) where + arbitrary = do + -- note: + -- generating small sig sizes avoids overflow error when semigroup + -- instance of `SizeInBytes` is used (summing up all inflight sig + -- sizes). + (size, advSize) <- frequency [ (99, (\a -> (a,a)) <$> chooseEnum (0, maxSigSize)) + , (1, (,) <$> chooseEnum (0, maxSigSize) <*> chooseEnum (0, maxSigSize)) + ] + Sig <$> arbitrary + <*> pure size + <*> pure advSize + <*> frequency [ (3, pure True) + , (1, pure False) + ] + +-- maximal sig size +maxSigSize :: SizeInBytes +maxSigSize = 65536 + +type SigId = Int + +emptyMempool :: MonadSTM m => m (Mempool m sigid (Sig sigid)) +emptyMempool = Mempool.empty + +newMempool :: (MonadSTM m, Ord sigid) + => [Sig sigid] -> m (Mempool m sigid (Sig sigid)) +newMempool = Mempool.new getSigId + +readMempool :: MonadSTM m => Mempool m sigid (Sig sigid) -> m [Sig sigid] +readMempool = Mempool.read + +getMempoolReader :: forall sigid m. + ( MonadSTM m + , Ord sigid + ) + => Mempool m sigid (Sig sigid) + -> TxSubmissionMempoolReader sigid (Sig sigid) Integer m +getMempoolReader = Mempool.getReader getSigId getSigAdvSize + +data InvalidSig = InvalidSig + +getMempoolWriter :: forall sigid m. + ( MonadSTM m + , MonadTime m + , Ord sigid + ) + => Mempool m sigid (Sig sigid) + -> TxSubmissionMempoolWriter sigid (Sig sigid) Integer m InvalidSig +getMempoolWriter = Mempool.getWriter InvalidSig + getSigId + (\_ sigs ->return + [ if getSigValid sig + then Right sig + else Left (getSigId sig, InvalidSig) + | sig <- sigs + ] + ) + (\_ -> return ()) + + +sigSubmissionCodec2 :: MonadST m + => Codec (SigSubmissionV2 SigId (Sig Int)) + CBOR.DeserialiseFailure m ByteString +sigSubmissionCodec2 = + codecSigSubmissionV2 CBOR.encodeInt CBOR.decodeInt + encodeSig decodeSig + where + encodeSig Sig {getSigId, getSigSize, getSigAdvSize, getSigValid} = + CBOR.encodeListLen 4 + <> CBOR.encodeInt getSigId + <> CBOR.encodeWord32 (getSizeInBytes getSigSize) + <> CBOR.encodeWord32 (getSizeInBytes getSigAdvSize) + <> CBOR.encodeBool getSigValid + + decodeSig = do + _ <- CBOR.decodeListLen + Sig <$> CBOR.decodeInt + <*> (SizeInBytes <$> CBOR.decodeWord32) + <*> (SizeInBytes <$> CBOR.decodeWord32) + <*> CBOR.decodeBool + + +newtype LargeNonEmptyList a = LargeNonEmpty { getLargeNonEmpty :: [a] } + deriving Show + +instance Arbitrary a => Arbitrary (LargeNonEmptyList a) where + arbitrary = + LargeNonEmpty <$> suchThat (resize 500 (listOf arbitrary)) ((>25) . length) + + +-- TODO: Belongs in iosim. +data SimResults a = SimReturn a [String] + | SimException SomeException [String] + | SimDeadLock [String] + +-- Traverses a list of trace events and returns the result along with all log messages. +-- Incase of a pure exception, ie an assert, all tracers evaluated so far are returned. +evaluateTrace :: SimTrace a -> IO (SimResults a) +evaluateTrace = go [] + where + go as tr = do + r <- try (evaluate tr) + case r of + Right (SimTrace _ _ _ (EventSay s) tr') -> go (s : as) tr' + Right (SimTrace _ _ _ _ tr' ) -> go as tr' + Right (SimPORTrace _ _ _ _ (EventSay s) tr') -> go (s : as) tr' + Right (SimPORTrace _ _ _ _ _ tr' ) -> go as tr' + Right (TraceMainReturn _ _ a _) -> pure $ SimReturn a (reverse as) + Right (TraceMainException _ _ e _) -> pure $ SimException e (reverse as) + Right (TraceDeadlock _ _) -> pure $ SimDeadLock (reverse as) + Right TraceLoop -> error "IOSimPOR step time limit exceeded" + Right (TraceInternalError e) -> error ("IOSim: " ++ e) + Left (SomeException e) -> pure $ SimException (SomeException e) (reverse as) + + +data WithThreadAndTime a = WithThreadAndTime { + wtatOccuredAt :: !Time + , wtatWithinThread :: !String + , wtatEvent :: !a + } + +instance (Show a) => Show (WithThreadAndTime a) where + show WithThreadAndTime {wtatOccuredAt, wtatWithinThread, wtatEvent} = + printf "%s: %s: %s" (show wtatOccuredAt) (show wtatWithinThread) (show wtatEvent) + +verboseTracer :: forall a m. + ( MonadAsync m + , MonadSay m + , MonadMonotonicTime m + , Show a + ) + => Tracer m a +verboseTracer = threadAndTimeTracer $ showTracing $ Tracer say + +debugTracer :: forall a s. Show a => Tracer (IOSim s) a +debugTracer = threadAndTimeTracer $ showTracing $ Tracer (traceM . show) + +threadAndTimeTracer :: forall a m. + ( MonadAsync m + , MonadMonotonicTime m + ) + => Tracer m (WithThreadAndTime a) -> Tracer m a +threadAndTimeTracer tr = Tracer $ \s -> do + !now <- getMonotonicTime + !tid <- myThreadId + traceWith tr $ WithThreadAndTime now (show tid) s