Skip to content

Commit

Permalink
Add thrift transport implementation (framed + unframed)
Browse files Browse the repository at this point in the history
  • Loading branch information
phile314 committed Dec 30, 2020
1 parent a10fee9 commit c867cf0
Show file tree
Hide file tree
Showing 6 changed files with 300 additions and 4 deletions.
2 changes: 2 additions & 0 deletions package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies:
- bytestring >= 0.10 && < 0.11
- cereal >= 0.5.8.1 && < 0.6
- containers >= 0.5 && < 0.7
- network >= 3.1 && < 3.2
- text >= 1.2 && < 1.3
- unordered-containers >= 0.2 && < 0.3
- vector >= 0.10 && < 0.13
Expand All @@ -59,6 +60,7 @@ library:
- Pinch.Protocol
- Pinch.Protocol.Binary
- Pinch.Protocol.Compact
- Pinch.Transport
dependencies:
- array >= 0.5
- deepseq >= 1.3 && < 1.5
Expand Down
4 changes: 4 additions & 0 deletions pinch.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ library
Pinch.Protocol
Pinch.Protocol.Binary
Pinch.Protocol.Compact
Pinch.Transport
other-modules:
Pinch.Internal.Bits
Pinch.Internal.Pinchable.Parser
Expand All @@ -74,6 +75,7 @@ library
, deepseq >=1.3 && <1.5
, ghc-prim
, hashable >=1.2 && <1.4
, network >=3.1 && <3.2
, semigroups >=0.18 && <0.20
, text >=1.2 && <1.3
, unordered-containers >=0.2 && <0.3
Expand All @@ -96,6 +98,7 @@ test-suite pinch-spec
Pinch.Internal.ValueSpec
Pinch.Protocol.BinarySpec
Pinch.Protocol.CompactSpec
Pinch.TransportSpec
Paths_pinch
hs-source-dirs:
tests
Expand All @@ -109,6 +112,7 @@ test-suite pinch-spec
, cereal >=0.5.8.1 && <0.6
, containers >=0.5 && <0.7
, hspec >=2.0
, network >=3.1 && <3.2
, pinch
, semigroups >=0.18 && <0.20
, text >=1.2 && <1.3
Expand Down
6 changes: 6 additions & 0 deletions src/Pinch/Internal/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ module Pinch.Internal.Builder
, doubleBE
, doubleLE
, byteString

, getSize
) where

import Data.ByteString (ByteString)
Expand Down Expand Up @@ -145,3 +147,7 @@ byteString (BI.PS fp off len) =
withForeignPtr fp $ \src ->
BI.memcpy dst (src `plusPtr` off) len
{-# INLINE byteString #-}

-- | Returns the number of bytes in the builder.
getSize :: Builder -> Int
getSize (B sz _) = sz
124 changes: 124 additions & 0 deletions src/Pinch/Transport.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
{-# LANGUAGE RankNTypes #-}

module Pinch.Transport
( Transport(..)
, framedTransport
, unframedTransport
, Connection(..)
, ReadResult(..)
)where

import Data.IORef (newIORef, readIORef, writeIORef)
import Network.Socket (Socket)
import Network.Socket.ByteString
import System.IO (Handle)

import qualified Data.ByteString as BS
import qualified Data.Serialize.Get as G

import qualified Pinch.Internal.Builder as B

class Connection c where
-- | Gets up to n bytes. Returns an empty bytestring if EOF is reached.
cGetSome :: c -> Int -> IO BS.ByteString
-- | Writes the given bytestring.
cPut :: c -> BS.ByteString -> IO ()

instance Connection Handle where
cPut = BS.hPut
cGetSome = BS.hGetSome

instance Connection Socket where
cPut = sendAll
cGetSome s n = recv s (min n 4096)

data ReadResult a
= RRSuccess a
| RRFailure String
| RREOF
deriving (Eq, Show)

-- | A bidirectional transport to read/write messages from/to.
data Transport
= Transport
{ writeMessage :: B.Builder -> IO ()
, readMessage :: forall a . G.Get a -> IO (ReadResult a)
}

-- | Creates a thrift framed transport. See also https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md#framed-vs-unframed-transport .
framedTransport :: Connection c => c -> IO Transport
framedTransport c = pure $ Transport writeMsg readMsg where
writeMsg msg = do
cPut c $ B.runBuilder $ B.int32BE (fromIntegral $ B.getSize msg)
cPut c $ B.runBuilder msg

readMsg p = do
szBs <- getExactly c 4
if BS.length szBs < 4
then
pure $ RREOF
else do
let sz = fromIntegral <$> G.runGet G.getInt32be szBs
case sz of
Right x -> do
msgBs <- getExactly c x
if BS.length msgBs < x
then
-- less data has been returned than expected. This means we have reached EOF.
pure $ RREOF
else
pure $ either RRFailure RRSuccess $ G.runGet p msgBs
Left s -> pure $ RRFailure $ "Invalid frame size: " ++ show s

-- | Creates a thrift unframed transport. See also https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md#framed-vs-unframed-transport .
unframedTransport :: Connection c => c -> IO Transport
unframedTransport c = do
-- As we do not know how long messages are,
-- we may read more data then the current message needs.
-- We keep the leftovers in a buffer so that we may use them
-- when reading the next message.
readBuffer <- newIORef mempty
pure $ Transport writeMsg (readMsg readBuffer)
where
writeMsg msg = cPut c $ B.runBuilder msg

readMsg buf p = do
bs <- readIORef buf
bs' <- if BS.null bs then getSome else pure bs
(leftOvers, r) <- runGetWith getSome p bs'
writeIORef buf leftOvers
pure $ r
getSome = cGetSome c 1024

-- | Runs a Get parser incrementally, reading more input as necessary until a successful parse
-- has been achieved.
runGetWith :: IO BS.ByteString -> G.Get a -> BS.ByteString -> IO (BS.ByteString, ReadResult a)
runGetWith getBs p initial = go (G.runGetPartial p initial)
where
go r = case r of
G.Fail err bs -> do
pure (bs, RRFailure err)
G.Done a bs -> do
pure (bs, RRSuccess a)
G.Partial cont -> do
bs <- getBs
if BS.null bs
then
-- EOF
pure (bs, RREOF)
else
go $ cont bs

-- | Gets exactly n bytes. If EOF is reached, an empty string is returned.
getExactly :: Connection c => c -> Int -> IO BS.ByteString
getExactly c n = B.runBuilder <$> go n mempty
where
go :: Int -> B.Builder -> IO B.Builder
go n b = do
bs <- cGetSome c n
let b' = b <> B.byteString bs
case BS.length bs of
-- EOF, return what data we might have gotten so far
0 -> pure mempty
n' | n' < n -> go (n - n') b'
_ | otherwise -> pure b'
71 changes: 67 additions & 4 deletions stack.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,68 @@
flags: {}
# This file was automatically generated by 'stack init'
#
# Some commonly used options have been documented as comments in this file.
# For advanced use and comprehensive documentation of the format, please see:
# https://docs.haskellstack.org/en/stable/yaml_configuration/

# Resolver to choose a 'specific' stackage snapshot or a compiler version.
# A snapshot resolver dictates the compiler version and the set of packages
# to be used for project dependencies. For example:
#
# resolver: lts-3.5
# resolver: nightly-2015-09-21
# resolver: ghc-7.10.2
#
# The location of a snapshot can be provided as a file or url. Stack assumes
# a snapshot provided as a file might change, whereas a url resource does not.
#
# resolver: ./custom-snapshot.yaml
# resolver: https://example.com/snapshots/2018-01-01.yaml
resolver: nightly-2020-12-14

# User packages to be built.
# Various formats can be used as shown in the example below.
#
# packages:
# - some-directory
# - https://example.com/foo/bar/baz-0.0.2.tar.gz
# subdirs:
# - auto-update
# - wai
packages:
- '.'
extra-deps: []
resolver: lts-11.13
- examples/keyvalue
- .
- bench/pinch-bench
# Dependency packages to be pulled from upstream that are not in the resolver.
# These entries can reference officially published versions as well as
# forks / in-progress versions pinned to a git hash. For example:
#
# extra-deps:
# - acme-missiles-0.3
# - git: https://github.com/commercialhaskell/stack.git
# commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a
#
# extra-deps: []

# Override default flag values for local packages and extra-deps
# flags: {}

# Extra package databases containing global packages
# extra-package-dbs: []

# Control whether we use the GHC we find on the path
# system-ghc: true
#
# Require a specific version of stack, using version ranges
# require-stack-version: -any # Default
# require-stack-version: ">=2.3"
#
# Override the architecture used by stack, especially useful on Windows
# arch: i386
# arch: x86_64
#
# Extra directories used by stack for building
# extra-include-dirs: [/path/to/dir]
# extra-lib-dirs: [/path/to/dir]
#
# Allow a newer minor version of GHC than the snapshot specifies
# compiler-check: newer-minor
97 changes: 97 additions & 0 deletions tests/Pinch/TransportSpec.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Pinch.TransportSpec (spec) where

import Data.ByteString (ByteString)
import Data.IORef (IORef, newIORef, readIORef, writeIORef, modifyIORef)
import Test.Hspec
import Test.Hspec.QuickCheck
import Test.QuickCheck

import qualified Data.ByteString as BS
import qualified Data.Serialize.Get as G

import Pinch.Arbitrary (SomeByteString(..))
import Pinch.Transport (Transport(..), framedTransport, unframedTransport, Connection(..), ReadResult(..))

import qualified Pinch.Internal.Builder as B

data MemoryConnection = MemoryConnection
{ contents :: IORef ByteString
, maxChunkSize :: Int -- ^ how many bytes to maximally return for one cGetSome call
}

newMemoryConnection :: Int -> IO MemoryConnection
newMemoryConnection ch = MemoryConnection <$> newIORef mempty <*> pure ch

mGetContents :: MemoryConnection -> IO ByteString
mGetContents = readIORef . contents

instance Connection MemoryConnection where
cGetSome (MemoryConnection ref ch) n = do
bytes <- readIORef ref
let (left, right) = BS.splitAt (min ch n) bytes
writeIORef ref right
return left
cPut (MemoryConnection ref _) newBytes = do
modifyIORef ref (<> newBytes)

transportSpec :: (forall c . Connection c => c -> IO Transport) -> Spec
transportSpec t = do
prop "can roundtrip bytestrings" $ \(Positive c, SomeByteString bytes) ->
ioProperty $ do
buf <- newMemoryConnection c
transp <- t buf
writeMessage transp (B.byteString bytes)
actual <- readMessage transp (G.getBytes $ BS.length bytes)
pure $ actual === RRSuccess bytes

it "EOF handling" $ do
buf <- newMemoryConnection 10
transp <- t buf
r <- readMessage transp (G.getInt8)
r `shouldBe` RREOF


spec :: Spec
spec = do
describe "framedTransport" $ do
transportSpec framedTransport

it "read case" $ do
let payload = BS.pack [0x01, 0x05, 0x01, 0x08, 0xFF]
buf <- newMemoryConnection 1
transp <- framedTransport buf
cPut buf $ BS.pack [0x00, 0x00, 0x00, 0x05]
cPut buf payload
r <- readMessage transp (G.getBytes $ BS.length payload)
r `shouldBe` RRSuccess payload

it "write case" $ do
let payload = BS.pack [0x01, 0x05, 0x01, 0x08, 0xFF]
buf <- newMemoryConnection 1
transp <- framedTransport buf
writeMessage transp (B.byteString payload)
actual <- mGetContents buf
actual `shouldBe` (BS.pack [0x00, 0x00, 0x00, 0x05] <> payload)


describe "unframedTransport" $ do
transportSpec unframedTransport

prop "read cases" $ \(SomeByteString payload) ->
ioProperty $ do
buf <- newMemoryConnection 1
transp <- unframedTransport buf
cPut buf payload
r <- readMessage transp (G.getBytes $ BS.length payload)
pure $ r === RRSuccess payload

prop "write cases" $ \(SomeByteString payload) ->
ioProperty $ do
buf <- newMemoryConnection 1
transp <- unframedTransport buf
writeMessage transp (B.byteString payload)
actual <- mGetContents buf
pure $ actual === payload

0 comments on commit c867cf0

Please sign in to comment.