-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add thrift transport implementation (framed + unframed)
- Loading branch information
Showing
6 changed files
with
300 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
{-# LANGUAGE RankNTypes #-} | ||
|
||
module Pinch.Transport | ||
( Transport(..) | ||
, framedTransport | ||
, unframedTransport | ||
, Connection(..) | ||
, ReadResult(..) | ||
)where | ||
|
||
import Data.IORef (newIORef, readIORef, writeIORef) | ||
import Network.Socket (Socket) | ||
import Network.Socket.ByteString | ||
import System.IO (Handle) | ||
|
||
import qualified Data.ByteString as BS | ||
import qualified Data.Serialize.Get as G | ||
|
||
import qualified Pinch.Internal.Builder as B | ||
|
||
class Connection c where | ||
-- | Gets up to n bytes. Returns an empty bytestring if EOF is reached. | ||
cGetSome :: c -> Int -> IO BS.ByteString | ||
-- | Writes the given bytestring. | ||
cPut :: c -> BS.ByteString -> IO () | ||
|
||
instance Connection Handle where | ||
cPut = BS.hPut | ||
cGetSome = BS.hGetSome | ||
|
||
instance Connection Socket where | ||
cPut = sendAll | ||
cGetSome s n = recv s (min n 4096) | ||
|
||
data ReadResult a | ||
= RRSuccess a | ||
| RRFailure String | ||
| RREOF | ||
deriving (Eq, Show) | ||
|
||
-- | A bidirectional transport to read/write messages from/to. | ||
data Transport | ||
= Transport | ||
{ writeMessage :: B.Builder -> IO () | ||
, readMessage :: forall a . G.Get a -> IO (ReadResult a) | ||
} | ||
|
||
-- | Creates a thrift framed transport. See also https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md#framed-vs-unframed-transport . | ||
framedTransport :: Connection c => c -> IO Transport | ||
framedTransport c = pure $ Transport writeMsg readMsg where | ||
writeMsg msg = do | ||
cPut c $ B.runBuilder $ B.int32BE (fromIntegral $ B.getSize msg) | ||
cPut c $ B.runBuilder msg | ||
|
||
readMsg p = do | ||
szBs <- getExactly c 4 | ||
if BS.length szBs < 4 | ||
then | ||
pure $ RREOF | ||
else do | ||
let sz = fromIntegral <$> G.runGet G.getInt32be szBs | ||
case sz of | ||
Right x -> do | ||
msgBs <- getExactly c x | ||
if BS.length msgBs < x | ||
then | ||
-- less data has been returned than expected. This means we have reached EOF. | ||
pure $ RREOF | ||
else | ||
pure $ either RRFailure RRSuccess $ G.runGet p msgBs | ||
Left s -> pure $ RRFailure $ "Invalid frame size: " ++ show s | ||
|
||
-- | Creates a thrift unframed transport. See also https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md#framed-vs-unframed-transport . | ||
unframedTransport :: Connection c => c -> IO Transport | ||
unframedTransport c = do | ||
-- As we do not know how long messages are, | ||
-- we may read more data then the current message needs. | ||
-- We keep the leftovers in a buffer so that we may use them | ||
-- when reading the next message. | ||
readBuffer <- newIORef mempty | ||
pure $ Transport writeMsg (readMsg readBuffer) | ||
where | ||
writeMsg msg = cPut c $ B.runBuilder msg | ||
|
||
readMsg buf p = do | ||
bs <- readIORef buf | ||
bs' <- if BS.null bs then getSome else pure bs | ||
(leftOvers, r) <- runGetWith getSome p bs' | ||
writeIORef buf leftOvers | ||
pure $ r | ||
getSome = cGetSome c 1024 | ||
|
||
-- | Runs a Get parser incrementally, reading more input as necessary until a successful parse | ||
-- has been achieved. | ||
runGetWith :: IO BS.ByteString -> G.Get a -> BS.ByteString -> IO (BS.ByteString, ReadResult a) | ||
runGetWith getBs p initial = go (G.runGetPartial p initial) | ||
where | ||
go r = case r of | ||
G.Fail err bs -> do | ||
pure (bs, RRFailure err) | ||
G.Done a bs -> do | ||
pure (bs, RRSuccess a) | ||
G.Partial cont -> do | ||
bs <- getBs | ||
if BS.null bs | ||
then | ||
-- EOF | ||
pure (bs, RREOF) | ||
else | ||
go $ cont bs | ||
|
||
-- | Gets exactly n bytes. If EOF is reached, an empty string is returned. | ||
getExactly :: Connection c => c -> Int -> IO BS.ByteString | ||
getExactly c n = B.runBuilder <$> go n mempty | ||
where | ||
go :: Int -> B.Builder -> IO B.Builder | ||
go n b = do | ||
bs <- cGetSome c n | ||
let b' = b <> B.byteString bs | ||
case BS.length bs of | ||
-- EOF, return what data we might have gotten so far | ||
0 -> pure mempty | ||
n' | n' < n -> go (n - n') b' | ||
_ | otherwise -> pure b' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,68 @@ | ||
flags: {} | ||
# This file was automatically generated by 'stack init' | ||
# | ||
# Some commonly used options have been documented as comments in this file. | ||
# For advanced use and comprehensive documentation of the format, please see: | ||
# https://docs.haskellstack.org/en/stable/yaml_configuration/ | ||
|
||
# Resolver to choose a 'specific' stackage snapshot or a compiler version. | ||
# A snapshot resolver dictates the compiler version and the set of packages | ||
# to be used for project dependencies. For example: | ||
# | ||
# resolver: lts-3.5 | ||
# resolver: nightly-2015-09-21 | ||
# resolver: ghc-7.10.2 | ||
# | ||
# The location of a snapshot can be provided as a file or url. Stack assumes | ||
# a snapshot provided as a file might change, whereas a url resource does not. | ||
# | ||
# resolver: ./custom-snapshot.yaml | ||
# resolver: https://example.com/snapshots/2018-01-01.yaml | ||
resolver: nightly-2020-12-14 | ||
|
||
# User packages to be built. | ||
# Various formats can be used as shown in the example below. | ||
# | ||
# packages: | ||
# - some-directory | ||
# - https://example.com/foo/bar/baz-0.0.2.tar.gz | ||
# subdirs: | ||
# - auto-update | ||
# - wai | ||
packages: | ||
- '.' | ||
extra-deps: [] | ||
resolver: lts-11.13 | ||
- examples/keyvalue | ||
- . | ||
- bench/pinch-bench | ||
# Dependency packages to be pulled from upstream that are not in the resolver. | ||
# These entries can reference officially published versions as well as | ||
# forks / in-progress versions pinned to a git hash. For example: | ||
# | ||
# extra-deps: | ||
# - acme-missiles-0.3 | ||
# - git: https://github.com/commercialhaskell/stack.git | ||
# commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a | ||
# | ||
# extra-deps: [] | ||
|
||
# Override default flag values for local packages and extra-deps | ||
# flags: {} | ||
|
||
# Extra package databases containing global packages | ||
# extra-package-dbs: [] | ||
|
||
# Control whether we use the GHC we find on the path | ||
# system-ghc: true | ||
# | ||
# Require a specific version of stack, using version ranges | ||
# require-stack-version: -any # Default | ||
# require-stack-version: ">=2.3" | ||
# | ||
# Override the architecture used by stack, especially useful on Windows | ||
# arch: i386 | ||
# arch: x86_64 | ||
# | ||
# Extra directories used by stack for building | ||
# extra-include-dirs: [/path/to/dir] | ||
# extra-lib-dirs: [/path/to/dir] | ||
# | ||
# Allow a newer minor version of GHC than the snapshot specifies | ||
# compiler-check: newer-minor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
{-# LANGUAGE OverloadedStrings #-} | ||
{-# LANGUAGE RankNTypes #-} | ||
{-# LANGUAGE ScopedTypeVariables #-} | ||
module Pinch.TransportSpec (spec) where | ||
|
||
import Data.ByteString (ByteString) | ||
import Data.IORef (IORef, newIORef, readIORef, writeIORef, modifyIORef) | ||
import Test.Hspec | ||
import Test.Hspec.QuickCheck | ||
import Test.QuickCheck | ||
|
||
import qualified Data.ByteString as BS | ||
import qualified Data.Serialize.Get as G | ||
|
||
import Pinch.Arbitrary (SomeByteString(..)) | ||
import Pinch.Transport (Transport(..), framedTransport, unframedTransport, Connection(..), ReadResult(..)) | ||
|
||
import qualified Pinch.Internal.Builder as B | ||
|
||
data MemoryConnection = MemoryConnection | ||
{ contents :: IORef ByteString | ||
, maxChunkSize :: Int -- ^ how many bytes to maximally return for one cGetSome call | ||
} | ||
|
||
newMemoryConnection :: Int -> IO MemoryConnection | ||
newMemoryConnection ch = MemoryConnection <$> newIORef mempty <*> pure ch | ||
|
||
mGetContents :: MemoryConnection -> IO ByteString | ||
mGetContents = readIORef . contents | ||
|
||
instance Connection MemoryConnection where | ||
cGetSome (MemoryConnection ref ch) n = do | ||
bytes <- readIORef ref | ||
let (left, right) = BS.splitAt (min ch n) bytes | ||
writeIORef ref right | ||
return left | ||
cPut (MemoryConnection ref _) newBytes = do | ||
modifyIORef ref (<> newBytes) | ||
|
||
transportSpec :: (forall c . Connection c => c -> IO Transport) -> Spec | ||
transportSpec t = do | ||
prop "can roundtrip bytestrings" $ \(Positive c, SomeByteString bytes) -> | ||
ioProperty $ do | ||
buf <- newMemoryConnection c | ||
transp <- t buf | ||
writeMessage transp (B.byteString bytes) | ||
actual <- readMessage transp (G.getBytes $ BS.length bytes) | ||
pure $ actual === RRSuccess bytes | ||
|
||
it "EOF handling" $ do | ||
buf <- newMemoryConnection 10 | ||
transp <- t buf | ||
r <- readMessage transp (G.getInt8) | ||
r `shouldBe` RREOF | ||
|
||
|
||
spec :: Spec | ||
spec = do | ||
describe "framedTransport" $ do | ||
transportSpec framedTransport | ||
|
||
it "read case" $ do | ||
let payload = BS.pack [0x01, 0x05, 0x01, 0x08, 0xFF] | ||
buf <- newMemoryConnection 1 | ||
transp <- framedTransport buf | ||
cPut buf $ BS.pack [0x00, 0x00, 0x00, 0x05] | ||
cPut buf payload | ||
r <- readMessage transp (G.getBytes $ BS.length payload) | ||
r `shouldBe` RRSuccess payload | ||
|
||
it "write case" $ do | ||
let payload = BS.pack [0x01, 0x05, 0x01, 0x08, 0xFF] | ||
buf <- newMemoryConnection 1 | ||
transp <- framedTransport buf | ||
writeMessage transp (B.byteString payload) | ||
actual <- mGetContents buf | ||
actual `shouldBe` (BS.pack [0x00, 0x00, 0x00, 0x05] <> payload) | ||
|
||
|
||
describe "unframedTransport" $ do | ||
transportSpec unframedTransport | ||
|
||
prop "read cases" $ \(SomeByteString payload) -> | ||
ioProperty $ do | ||
buf <- newMemoryConnection 1 | ||
transp <- unframedTransport buf | ||
cPut buf payload | ||
r <- readMessage transp (G.getBytes $ BS.length payload) | ||
pure $ r === RRSuccess payload | ||
|
||
prop "write cases" $ \(SomeByteString payload) -> | ||
ioProperty $ do | ||
buf <- newMemoryConnection 1 | ||
transp <- unframedTransport buf | ||
writeMessage transp (B.byteString payload) | ||
actual <- mGetContents buf | ||
pure $ actual === payload |