Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add thrift transport implementation (framed + unframed) #29

Merged
merged 1 commit into from
Jan 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies:
- bytestring >= 0.10 && < 0.11
- cereal >= 0.5.8.1 && < 0.6
- containers >= 0.5 && < 0.7
- network >= 3.1 && < 3.2
- text >= 1.2 && < 1.3
- unordered-containers >= 0.2 && < 0.3
- vector >= 0.10 && < 0.13
Expand All @@ -59,6 +60,7 @@ library:
- Pinch.Protocol
- Pinch.Protocol.Binary
- Pinch.Protocol.Compact
- Pinch.Transport
dependencies:
- array >= 0.5
- deepseq >= 1.3 && < 1.5
Expand Down
4 changes: 4 additions & 0 deletions pinch.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ library
Pinch.Protocol
Pinch.Protocol.Binary
Pinch.Protocol.Compact
Pinch.Transport
other-modules:
Pinch.Internal.Bits
Pinch.Internal.Pinchable.Parser
Expand All @@ -74,6 +75,7 @@ library
, deepseq >=1.3 && <1.5
, ghc-prim
, hashable >=1.2 && <1.4
, network >=3.1 && <3.2
, semigroups >=0.18 && <0.20
, text >=1.2 && <1.3
, unordered-containers >=0.2 && <0.3
Expand All @@ -96,6 +98,7 @@ test-suite pinch-spec
Pinch.Internal.ValueSpec
Pinch.Protocol.BinarySpec
Pinch.Protocol.CompactSpec
Pinch.TransportSpec
Paths_pinch
hs-source-dirs:
tests
Expand All @@ -109,6 +112,7 @@ test-suite pinch-spec
, cereal >=0.5.8.1 && <0.6
, containers >=0.5 && <0.7
, hspec >=2.0
, network >=3.1 && <3.2
, pinch
, semigroups >=0.18 && <0.20
, text >=1.2 && <1.3
Expand Down
6 changes: 6 additions & 0 deletions src/Pinch/Internal/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ module Pinch.Internal.Builder
, doubleBE
, doubleLE
, byteString

, getSize
) where

import Data.ByteString (ByteString)
Expand Down Expand Up @@ -145,3 +147,7 @@ byteString (BI.PS fp off len) =
withForeignPtr fp $ \src ->
BI.memcpy dst (src `plusPtr` off) len
{-# INLINE byteString #-}

-- | Reports how many bytes the given builder carries, as recorded in the
-- size field of the 'B' constructor.
getSize :: Builder -> Int
getSize builder =
  case builder of
    B size _ -> size
124 changes: 124 additions & 0 deletions src/Pinch/Transport.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
{-# LANGUAGE RankNTypes #-}

module Pinch.Transport
( Transport(..)
, framedTransport
, unframedTransport
, Connection(..)
, ReadResult(..)
) where

import Data.IORef (newIORef, readIORef, writeIORef)
import Network.Socket (Socket)
import Network.Socket.ByteString (sendAll, recv)
import System.IO (Handle)

import qualified Data.ByteString as BS
import qualified Data.Serialize.Get as G

import qualified Pinch.Internal.Builder as B

-- | A byte-oriented connection that a 'Transport' reads from and writes to.
class Connection c where
  -- | Reads at most @n@ bytes from the connection. An empty bytestring
  -- signals that EOF has been reached.
  cGetSome :: c -> Int -> IO BS.ByteString
  -- | Writes the given bytestring to the connection.
  cPut :: c -> BS.ByteString -> IO ()

-- | Handles are read with 'BS.hGetSome' and written with 'BS.hPut'.
instance Connection Handle where
  cGetSome h n = BS.hGetSome h n
  cPut h bytes = BS.hPut h bytes

-- | Sockets are written with 'sendAll'. A single read requests at most
-- 4096 bytes from 'recv', even when the caller asks for more.
instance Connection Socket where
  cGetSome sock n = recv sock (n `min` 4096)
  cPut sock bytes = sendAll sock bytes

-- | Outcome of reading one message from a 'Transport'.
data ReadResult a
  = RRSuccess a
    -- ^ The message was read and parsed; carries the parsed value.
  | RRFailure String
    -- ^ Reading or parsing failed; carries the error message.
  | RREOF
    -- ^ The connection reached EOF before a complete message was available.
  deriving (Eq, Show)

-- | A bidirectional channel for exchanging messages: one action to send
-- a message, one to receive and parse a message.
data Transport = Transport
  { writeMessage :: B.Builder -> IO ()
    -- ^ Sends the message described by the builder.
  , readMessage :: forall a . G.Get a -> IO (ReadResult a)
    -- ^ Reads a message and runs the given parser on it.
  }

-- | Creates a thrift framed transport. See also <https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md#framed-vs-unframed-transport>.
--
-- Every message is preceded by its length, encoded as a 4-byte big-endian
-- signed integer.
--
-- Writing sends the length prefix and the payload with a single 'cPut' so
-- that a frame is never split across two separate writes.
--
-- Reading yields:
--
-- * 'RREOF' if the connection ends before a complete frame was received,
-- * 'RRFailure' if the length prefix cannot be decoded or is negative, or
--   if the parser fails on the frame's payload,
-- * 'RRSuccess' with the parsed value otherwise.
framedTransport :: Connection c => c -> IO Transport
framedTransport c = pure $ Transport writeMsg readMsg
  where
    writeMsg msg =
      cPut c $ B.runBuilder $
        B.int32BE (fromIntegral $ B.getSize msg) <> msg

    readMsg p = do
      szBs <- getExactly c 4
      if BS.length szBs < 4
        then pure RREOF
        else case G.runGet G.getInt32be szBs of
          Left s -> pure $ RRFailure $ "Invalid frame size: " ++ show s
          Right sz
            -- A negative length can never be valid; reject it here instead
            -- of handing it to the connection, where it would raise an
            -- IO exception.
            | sz < 0 -> pure $ RRFailure $ "Invalid frame size: " ++ show sz
            -- An empty frame needs no further input from the connection.
            | sz == 0 -> pure $ parseFrame p BS.empty
            | otherwise -> do
                let expected = fromIntegral sz
                msgBs <- getExactly c expected
                pure $
                  if BS.length msgBs < expected
                    -- Less data has been returned than expected.
                    -- This means we have reached EOF.
                    then RREOF
                    else parseFrame p msgBs

    -- Runs the parser over the payload of one complete frame.
    parseFrame p payload = either RRFailure RRSuccess $ G.runGet p payload

-- | Creates a thrift unframed transport. See also <https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md#framed-vs-unframed-transport>.
--
-- Messages are written as-is, without any length prefix. When reading, the
-- parser is fed chunks of up to 1024 bytes until it finishes, fails, or the
-- connection reaches EOF.
unframedTransport :: Connection c => c -> IO Transport
unframedTransport c = do
  -- Message boundaries are unknown, so a read may pull in more data than
  -- the current message needs. The surplus is kept in this buffer and is
  -- consumed first when the next message is read.
  pending <- newIORef mempty
  pure Transport
    { writeMessage = cPut c . B.runBuilder
    , readMessage = readFrom pending
    }
  where
    fetch = cGetSome c 1024

    readFrom ref parser = do
      buffered <- readIORef ref
      -- Only touch the connection up front when nothing is buffered.
      input <- if BS.null buffered then fetch else pure buffered
      (rest, result) <- runGetWith fetch parser input
      writeIORef ref rest
      pure result

-- | Drives a 'G.Get' parser incrementally: starts it on the initial input
-- and keeps supplying chunks from the given action while the parser asks
-- for more.
--
-- Returns the unconsumed input together with the outcome. If the action
-- yields an empty chunk while the parser still needs input, this is taken
-- as EOF and the result is 'RREOF' with no leftover input.
runGetWith :: IO BS.ByteString -> G.Get a -> BS.ByteString -> IO (BS.ByteString, ReadResult a)
runGetWith getBs p initial = step (G.runGetPartial p initial)
  where
    step (G.Fail msg rest) = pure (rest, RRFailure msg)
    step (G.Done val rest) = pure (rest, RRSuccess val)
    step (G.Partial resume) = do
      chunk <- getBs
      if BS.null chunk
        -- EOF: the chunk is empty, so nothing is left over.
        then pure (chunk, RREOF)
        else step (resume chunk)

-- | Gets exactly @n@ bytes, issuing as many reads as necessary.
--
-- If EOF is reached before @n@ bytes have been received, an empty
-- bytestring is returned and any partially read data is discarded.
-- Callers detect this by comparing the length of the result with @n@.
--
-- For @n <= 0@ an empty bytestring is returned without touching the
-- connection, since there is nothing to read and a non-positive read
-- size is not a meaningful request to the underlying connection.
getExactly :: Connection c => c -> Int -> IO BS.ByteString
getExactly c n
  | n <= 0 = pure BS.empty
  | otherwise = B.runBuilder <$> go n mempty
  where
    -- @remaining@ is the number of bytes still missing, @acc@ holds the
    -- data collected so far.
    go :: Int -> B.Builder -> IO B.Builder
    go remaining acc = do
      chunk <- cGetSome c remaining
      let got = BS.length chunk
          acc' = acc <> B.byteString chunk
      if got == 0
        -- EOF before the request was satisfied: drop the partial data.
        then pure mempty
        else if got < remaining
          then go (remaining - got) acc'
          else pure acc'
71 changes: 67 additions & 4 deletions stack.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,68 @@
flags: {}
# This file was automatically generated by 'stack init'
#
# Some commonly used options have been documented as comments in this file.
# For advanced use and comprehensive documentation of the format, please see:
# https://docs.haskellstack.org/en/stable/yaml_configuration/

# Resolver to choose a 'specific' stackage snapshot or a compiler version.
# A snapshot resolver dictates the compiler version and the set of packages
# to be used for project dependencies. For example:
#
# resolver: lts-3.5
# resolver: nightly-2015-09-21
# resolver: ghc-7.10.2
#
# The location of a snapshot can be provided as a file or url. Stack assumes
# a snapshot provided as a file might change, whereas a url resource does not.
#
# resolver: ./custom-snapshot.yaml
# resolver: https://example.com/snapshots/2018-01-01.yaml
resolver: nightly-2020-12-14

# User packages to be built.
# Various formats can be used as shown in the example below.
#
# packages:
# - some-directory
# - https://example.com/foo/bar/baz-0.0.2.tar.gz
# subdirs:
# - auto-update
# - wai
packages:
- '.'
extra-deps: []
resolver: lts-11.13
- examples/keyvalue
- .
- bench/pinch-bench
# Dependency packages to be pulled from upstream that are not in the resolver.
# These entries can reference officially published versions as well as
# forks / in-progress versions pinned to a git hash. For example:
#
# extra-deps:
# - acme-missiles-0.3
# - git: https://github.com/commercialhaskell/stack.git
# commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a
#
# extra-deps: []

# Override default flag values for local packages and extra-deps
# flags: {}

# Extra package databases containing global packages
# extra-package-dbs: []

# Control whether we use the GHC we find on the path
# system-ghc: true
#
# Require a specific version of stack, using version ranges
# require-stack-version: -any # Default
# require-stack-version: ">=2.3"
#
# Override the architecture used by stack, especially useful on Windows
# arch: i386
# arch: x86_64
#
# Extra directories used by stack for building
# extra-include-dirs: [/path/to/dir]
# extra-lib-dirs: [/path/to/dir]
#
# Allow a newer minor version of GHC than the snapshot specifies
# compiler-check: newer-minor
97 changes: 97 additions & 0 deletions tests/Pinch/TransportSpec.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Pinch.TransportSpec (spec) where

import Data.ByteString (ByteString)
import Data.IORef (IORef, newIORef, readIORef, writeIORef, modifyIORef)
import Test.Hspec
import Test.Hspec.QuickCheck
import Test.QuickCheck

import qualified Data.ByteString as BS
import qualified Data.Serialize.Get as G

import Pinch.Arbitrary (SomeByteString(..))
import Pinch.Transport (Transport(..), framedTransport, unframedTransport, Connection(..), ReadResult(..))

import qualified Pinch.Internal.Builder as B

-- | A 'Connection' backed by an in-memory buffer, used to exercise
-- transports in tests.
data MemoryConnection = MemoryConnection
  { contents :: IORef ByteString
    -- ^ bytes that have been written and not yet read
  , maxChunkSize :: Int
    -- ^ upper bound on the number of bytes returned by one cGetSome call
  }

-- | Creates an empty in-memory connection with the given maximum chunk size.
newMemoryConnection :: Int -> IO MemoryConnection
newMemoryConnection ch = do
  ref <- newIORef mempty
  pure (MemoryConnection ref ch)

-- | Returns the bytes currently held by the connection without consuming them.
mGetContents :: MemoryConnection -> IO ByteString
mGetContents conn = readIORef (contents conn)

-- | Reads take bytes from the front of the buffer, never more than the
-- connection's maximum chunk size; writes append to the end of the buffer.
instance Connection MemoryConnection where
  cGetSome conn n = do
    available <- readIORef (contents conn)
    let (taken, remaining) = BS.splitAt (min (maxChunkSize conn) n) available
    writeIORef (contents conn) remaining
    pure taken
  cPut conn newBytes =
    modifyIORef (contents conn) (\old -> old <> newBytes)

-- | Checks shared by every transport implementation: a written message
-- can be read back unchanged, and reading from an empty connection
-- reports EOF.
transportSpec :: (forall c . Connection c => c -> IO Transport) -> Spec
transportSpec mkTransport = do
  prop "can roundtrip bytestrings" roundtrips

  it "EOF handling" $ do
    conn <- newMemoryConnection 10
    transp <- mkTransport conn
    result <- readMessage transp G.getInt8
    result `shouldBe` RREOF
  where
    -- Writes the bytes through the transport and reads them back, using a
    -- connection that hands out at most @chunk@ bytes per read.
    roundtrips (Positive chunk, SomeByteString bytes) = ioProperty $ do
      conn <- newMemoryConnection chunk
      transp <- mkTransport conn
      writeMessage transp (B.byteString bytes)
      actual <- readMessage transp (G.getBytes (BS.length bytes))
      pure $ actual === RRSuccess bytes


-- | Test suite for the framed and unframed transports. All
-- transport-specific cases use a connection that returns a single byte
-- per read.
spec :: Spec
spec = do
  describe "framedTransport" $ do
    transportSpec framedTransport

    it "read case" $ do
      conn <- newMemoryConnection 1
      transp <- framedTransport conn
      cPut conn frameHeader
      cPut conn samplePayload
      result <- readMessage transp (G.getBytes (BS.length samplePayload))
      result `shouldBe` RRSuccess samplePayload

    it "write case" $ do
      conn <- newMemoryConnection 1
      transp <- framedTransport conn
      writeMessage transp (B.byteString samplePayload)
      written <- mGetContents conn
      written `shouldBe` (frameHeader <> samplePayload)

  describe "unframedTransport" $ do
    transportSpec unframedTransport

    prop "read cases" $ \(SomeByteString payload) ->
      ioProperty $ do
        conn <- newMemoryConnection 1
        transp <- unframedTransport conn
        cPut conn payload
        result <- readMessage transp (G.getBytes (BS.length payload))
        pure $ result === RRSuccess payload

    prop "write cases" $ \(SomeByteString payload) ->
      ioProperty $ do
        conn <- newMemoryConnection 1
        transp <- unframedTransport conn
        writeMessage transp (B.byteString payload)
        written <- mGetContents conn
        pure $ written === payload
  where
    -- Five-byte payload used by the framed read/write cases.
    samplePayload = BS.pack [0x01, 0x05, 0x01, 0x08, 0xFF]
    -- Length prefix of 'samplePayload': the value 5 as four big-endian bytes.
    frameHeader = BS.pack [0x00, 0x00, 0x00, 0x05]