diff --git a/package.yaml b/package.yaml index 4534ba8..06bbd04 100644 --- a/package.yaml +++ b/package.yaml @@ -37,6 +37,7 @@ dependencies: - bytestring >= 0.10 && < 0.11 - cereal >= 0.5.8.1 && < 0.6 - containers >= 0.5 && < 0.7 + - network >= 3.1 && < 3.2 - text >= 1.2 && < 1.3 - unordered-containers >= 0.2 && < 0.3 - vector >= 0.10 && < 0.13 @@ -59,6 +60,7 @@ library: - Pinch.Protocol - Pinch.Protocol.Binary - Pinch.Protocol.Compact + - Pinch.Transport dependencies: - array >= 0.5 - deepseq >= 1.3 && < 1.5 diff --git a/pinch.cabal b/pinch.cabal index f59c91c..bcafd2a 100644 --- a/pinch.cabal +++ b/pinch.cabal @@ -56,6 +56,7 @@ library Pinch.Protocol Pinch.Protocol.Binary Pinch.Protocol.Compact + Pinch.Transport other-modules: Pinch.Internal.Bits Pinch.Internal.Pinchable.Parser @@ -74,6 +75,7 @@ library , deepseq >=1.3 && <1.5 , ghc-prim , hashable >=1.2 && <1.4 + , network >=3.1 && <3.2 , semigroups >=0.18 && <0.20 , text >=1.2 && <1.3 , unordered-containers >=0.2 && <0.3 @@ -96,6 +98,7 @@ test-suite pinch-spec Pinch.Internal.ValueSpec Pinch.Protocol.BinarySpec Pinch.Protocol.CompactSpec + Pinch.TransportSpec Paths_pinch hs-source-dirs: tests @@ -109,6 +112,7 @@ test-suite pinch-spec , cereal >=0.5.8.1 && <0.6 , containers >=0.5 && <0.7 , hspec >=2.0 + , network >=3.1 && <3.2 , pinch , semigroups >=0.18 && <0.20 , text >=1.2 && <1.3 diff --git a/src/Pinch/Internal/Builder.hs b/src/Pinch/Internal/Builder.hs index 114f448..304aae2 100644 --- a/src/Pinch/Internal/Builder.hs +++ b/src/Pinch/Internal/Builder.hs @@ -25,6 +25,8 @@ module Pinch.Internal.Builder , doubleBE , doubleLE , byteString + + , getSize ) where import Data.ByteString (ByteString) @@ -145,3 +147,7 @@ byteString (BI.PS fp off len) = withForeignPtr fp $ \src -> BI.memcpy dst (src `plusPtr` off) len {-# INLINE byteString #-} + +-- | Returns the number of bytes in the builder. +getSize :: Builder -> Int +getSize (B sz _) = sz diff --git a/src/Pinch/Transport.hs b/src/Pinch/Transport.hs new file mode 100644 index 0000000..3bb0f3d --- /dev/null +++ b/src/Pinch/Transport.hs @@ -0,0 +1,124 @@ +{-# LANGUAGE RankNTypes #-} + +module Pinch.Transport + ( Transport(..) + , framedTransport + , unframedTransport + , Connection(..) + , ReadResult(..) + )where + +import Data.IORef (newIORef, readIORef, writeIORef) +import Network.Socket (Socket) +import Network.Socket.ByteString +import System.IO (Handle) + +import qualified Data.ByteString as BS +import qualified Data.Serialize.Get as G + +import qualified Pinch.Internal.Builder as B + +class Connection c where + -- | Gets up to n bytes. Returns an empty bytestring if EOF is reached. + cGetSome :: c -> Int -> IO BS.ByteString + -- | Writes the given bytestring. + cPut :: c -> BS.ByteString -> IO () + +instance Connection Handle where + cPut = BS.hPut + cGetSome = BS.hGetSome + +instance Connection Socket where + cPut = sendAll + cGetSome s n = recv s (min n 4096) + +data ReadResult a + = RRSuccess a + | RRFailure String + | RREOF + deriving (Eq, Show) + +-- | A bidirectional transport to read/write messages from/to. +data Transport + = Transport + { writeMessage :: B.Builder -> IO () + , readMessage :: forall a . G.Get a -> IO (ReadResult a) + } + +-- | Creates a thrift framed transport. See also https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md#framed-vs-unframed-transport . +framedTransport :: Connection c => c -> IO Transport +framedTransport c = pure $ Transport writeMsg readMsg where + writeMsg msg = do + cPut c $ B.runBuilder $ B.int32BE (fromIntegral $ B.getSize msg) + cPut c $ B.runBuilder msg + + readMsg p = do + szBs <- getExactly c 4 + if BS.length szBs < 4 + then + pure $ RREOF + else do + let sz = fromIntegral <$> G.runGet G.getInt32be szBs + case sz of + Right x -> do + msgBs <- getExactly c x + if BS.length msgBs < x + then + -- less data has been returned than expected. This means we have reached EOF. + pure $ RREOF + else + pure $ either RRFailure RRSuccess $ G.runGet p msgBs + Left s -> pure $ RRFailure $ "Invalid frame size: " ++ show s + +-- | Creates a thrift unframed transport. See also https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md#framed-vs-unframed-transport . +unframedTransport :: Connection c => c -> IO Transport +unframedTransport c = do + -- As we do not know how long messages are, + -- we may read more data then the current message needs. + -- We keep the leftovers in a buffer so that we may use them + -- when reading the next message. + readBuffer <- newIORef mempty + pure $ Transport writeMsg (readMsg readBuffer) + where + writeMsg msg = cPut c $ B.runBuilder msg + + readMsg buf p = do + bs <- readIORef buf + bs' <- if BS.null bs then getSome else pure bs + (leftOvers, r) <- runGetWith getSome p bs' + writeIORef buf leftOvers + pure $ r + getSome = cGetSome c 1024 + +-- | Runs a Get parser incrementally, reading more input as necessary until a successful parse +-- has been achieved. +runGetWith :: IO BS.ByteString -> G.Get a -> BS.ByteString -> IO (BS.ByteString, ReadResult a) +runGetWith getBs p initial = go (G.runGetPartial p initial) + where + go r = case r of + G.Fail err bs -> do + pure (bs, RRFailure err) + G.Done a bs -> do + pure (bs, RRSuccess a) + G.Partial cont -> do + bs <- getBs + if BS.null bs + then + -- EOF + pure (bs, RREOF) + else + go $ cont bs + +-- | Gets exactly n bytes. If EOF is reached, an empty string is returned. +getExactly :: Connection c => c -> Int -> IO BS.ByteString +getExactly c n = B.runBuilder <$> go n mempty + where + go :: Int -> B.Builder -> IO B.Builder + go n b = do + bs <- cGetSome c n + let b' = b <> B.byteString bs + case BS.length bs of + -- EOF, return what data we might have gotten so far + 0 -> pure mempty + n' | n' < n -> go (n - n') b' + _ | otherwise -> pure b' diff --git a/stack.yaml b/stack.yaml index 75c7fbd..28c17ed 100644 --- a/stack.yaml +++ b/stack.yaml @@ -1,5 +1,68 @@ -flags: {} +# This file was automatically generated by 'stack init' +# +# Some commonly used options have been documented as comments in this file. +# For advanced use and comprehensive documentation of the format, please see: +# https://docs.haskellstack.org/en/stable/yaml_configuration/ + +# Resolver to choose a 'specific' stackage snapshot or a compiler version. +# A snapshot resolver dictates the compiler version and the set of packages +# to be used for project dependencies. For example: +# +# resolver: lts-3.5 +# resolver: nightly-2015-09-21 +# resolver: ghc-7.10.2 +# +# The location of a snapshot can be provided as a file or url. Stack assumes +# a snapshot provided as a file might change, whereas a url resource does not. +# +# resolver: ./custom-snapshot.yaml +# resolver: https://example.com/snapshots/2018-01-01.yaml +resolver: nightly-2020-12-14 + +# User packages to be built. +# Various formats can be used as shown in the example below. +# +# packages: +# - some-directory +# - https://example.com/foo/bar/baz-0.0.2.tar.gz +# subdirs: +# - auto-update +# - wai packages: -- '.' -extra-deps: [] -resolver: lts-11.13 +- examples/keyvalue +- . +- bench/pinch-bench +# Dependency packages to be pulled from upstream that are not in the resolver. +# These entries can reference officially published versions as well as +# forks / in-progress versions pinned to a git hash. For example: +# +# extra-deps: +# - acme-missiles-0.3 +# - git: https://github.com/commercialhaskell/stack.git +# commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a +# +# extra-deps: [] + +# Override default flag values for local packages and extra-deps +# flags: {} + +# Extra package databases containing global packages +# extra-package-dbs: [] + +# Control whether we use the GHC we find on the path +# system-ghc: true +# +# Require a specific version of stack, using version ranges +# require-stack-version: -any # Default +# require-stack-version: ">=2.3" +# +# Override the architecture used by stack, especially useful on Windows +# arch: i386 +# arch: x86_64 +# +# Extra directories used by stack for building +# extra-include-dirs: [/path/to/dir] +# extra-lib-dirs: [/path/to/dir] +# +# Allow a newer minor version of GHC than the snapshot specifies +# compiler-check: newer-minor diff --git a/tests/Pinch/TransportSpec.hs b/tests/Pinch/TransportSpec.hs new file mode 100644 index 0000000..80aa88e --- /dev/null +++ b/tests/Pinch/TransportSpec.hs @@ -0,0 +1,97 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +module Pinch.TransportSpec (spec) where + +import Data.ByteString (ByteString) +import Data.IORef (IORef, newIORef, readIORef, writeIORef, modifyIORef) +import Test.Hspec +import Test.Hspec.QuickCheck +import Test.QuickCheck + +import qualified Data.ByteString as BS +import qualified Data.Serialize.Get as G + +import Pinch.Arbitrary (SomeByteString(..)) +import Pinch.Transport (Transport(..), framedTransport, unframedTransport, Connection(..), ReadResult(..)) + +import qualified Pinch.Internal.Builder as B + +data MemoryConnection = MemoryConnection + { contents :: IORef ByteString + , maxChunkSize :: Int -- ^ how many bytes to maximally return for one cGetSome call + } + +newMemoryConnection :: Int -> IO MemoryConnection +newMemoryConnection ch = MemoryConnection <$> newIORef mempty <*> pure ch + +mGetContents :: MemoryConnection -> IO ByteString +mGetContents = readIORef . contents + +instance Connection MemoryConnection where + cGetSome (MemoryConnection ref ch) n = do + bytes <- readIORef ref + let (left, right) = BS.splitAt (min ch n) bytes + writeIORef ref right + return left + cPut (MemoryConnection ref _) newBytes = do + modifyIORef ref (<> newBytes) + +transportSpec :: (forall c . Connection c => c -> IO Transport) -> Spec +transportSpec t = do + prop "can roundtrip bytestrings" $ \(Positive c, SomeByteString bytes) -> + ioProperty $ do + buf <- newMemoryConnection c + transp <- t buf + writeMessage transp (B.byteString bytes) + actual <- readMessage transp (G.getBytes $ BS.length bytes) + pure $ actual === RRSuccess bytes + + it "EOF handling" $ do + buf <- newMemoryConnection 10 + transp <- t buf + r <- readMessage transp (G.getInt8) + r `shouldBe` RREOF + + +spec :: Spec +spec = do + describe "framedTransport" $ do + transportSpec framedTransport + + it "read case" $ do + let payload = BS.pack [0x01, 0x05, 0x01, 0x08, 0xFF] + buf <- newMemoryConnection 1 + transp <- framedTransport buf + cPut buf $ BS.pack [0x00, 0x00, 0x00, 0x05] + cPut buf payload + r <- readMessage transp (G.getBytes $ BS.length payload) + r `shouldBe` RRSuccess payload + + it "write case" $ do + let payload = BS.pack [0x01, 0x05, 0x01, 0x08, 0xFF] + buf <- newMemoryConnection 1 + transp <- framedTransport buf + writeMessage transp (B.byteString payload) + actual <- mGetContents buf + actual `shouldBe` (BS.pack [0x00, 0x00, 0x00, 0x05] <> payload) + + + describe "unframedTransport" $ do + transportSpec unframedTransport + + prop "read cases" $ \(SomeByteString payload) -> + ioProperty $ do + buf <- newMemoryConnection 1 + transp <- unframedTransport buf + cPut buf payload + r <- readMessage transp (G.getBytes $ BS.length payload) + pure $ r === RRSuccess payload + + prop "write cases" $ \(SomeByteString payload) -> + ioProperty $ do + buf <- newMemoryConnection 1 + transp <- unframedTransport buf + writeMessage transp (B.byteString payload) + actual <- mGetContents buf + pure $ actual === payload