Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add lazy isolation mode #111

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cereal.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@ test-suite test-cereal

build-depends: base == 4.*,
bytestring >= 0.9,
HUnit,
QuickCheck,
test-framework,
test-framework-hunit,
test-framework-quickcheck2,
cereal

Expand Down
143 changes: 112 additions & 31 deletions src/Data/Serialize/Get.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE PatternSynonyms #-}

-----------------------------------------------------------------------------
-- |
Expand Down Expand Up @@ -32,13 +34,14 @@ module Data.Serialize.Get (
, runGetLazyState

-- ** Incremental interface
, Result(..)
, Result(Fail, Partial, Done)
, runGetPartial
, runGetChunk

-- * Parsing
, ensure
, isolate
, isolateLazy
, label
, skip
, uncheckedSkip
Expand Down Expand Up @@ -128,7 +131,7 @@ import GHC.Word
#endif

-- | The result of a parse.
data Result r = Fail String B.ByteString
data Result r = FailRaw (String, [String]) B.ByteString
-- ^ The parse failed. The 'String' is the
-- message describing the error, if any.
| Partial (B.ByteString -> Result r)
Expand All @@ -140,13 +143,20 @@ data Result r = Fail String B.ByteString
-- input that had not yet been consumed (if any) when
-- the parse succeeded.

pattern Fail :: String -> B.ByteString -> Result r
pattern Fail msg bs <- FailRaw (formatFailure -> msg) bs
{-# COMPLETE Fail, Partial, Done #-}

formatFailure :: (String, [String]) -> String
formatFailure (err, stack) = unlines [err, formatTrace stack]

instance Show r => Show (Result r) where
show (Fail msg _) = "Fail " ++ show msg
show (FailRaw msg _) = "Fail " ++ show msg
show (Partial _) = "Partial _"
show (Done r bs) = "Done " ++ show r ++ " " ++ show bs

instance Functor Result where
fmap _ (Fail msg rest) = Fail msg rest
fmap _ (FailRaw a bs) = FailRaw a bs
fmap f (Partial k) = Partial (fmap f . k)
fmap f (Done r bs) = Done (f r) bs

Expand Down Expand Up @@ -275,7 +285,7 @@ finalK s _ _ _ a = Done a s

failK :: Failure a
failK s b _ ls msg =
Fail (unlines [msg, formatTrace ls]) (s `B.append` bufferBytes b)
FailRaw (msg, ls) (s `B.append` bufferBytes b)

-- | Run the Get monad applies a 'get'-based parser on the input ByteString
runGet :: Get a -> B.ByteString -> Either String a
Expand Down Expand Up @@ -364,9 +374,15 @@ runGetLazyState m lstr = case runGetLazy' m lstr of

-- | If at least @n@ bytes of input are available, return the current
-- input, otherwise fail.
{-# INLINE ensure #-}
ensure :: Int -> Get B.ByteString
ensure n0 = n0 `seq` Get $ \ s0 b0 m0 w0 kf ks -> let
ensure n
| n < 0 = fail "Attempted to ensure negative amount of bytes"
414owen marked this conversation as resolved.
Show resolved Hide resolved
| n == 0 = pure mempty
| otherwise = ensure' n

{-# INLINE ensure #-}
ensure' :: Int -> Get B.ByteString
ensure' n0 = n0 `seq` Get $ \ s0 b0 m0 w0 kf ks -> let
n' = n0 - B.length s0
in if n' <= 0
then ks s0 b0 m0 w0 s0
Expand Down Expand Up @@ -402,30 +418,93 @@ ensure n0 = n0 `seq` Get $ \ s0 b0 m0 w0 kf ks -> let
in ks s b m0 w0 s
else getMore n' s0 ss b0 m0 w0 kf ks

negativeIsolation :: Get a
negativeIsolation = fail "Attempted to isolate a negative number of bytes"

isolationUnderParse :: Get a
isolationUnderParse = fail "not all bytes parsed in isolate"

isolationUnderSupply :: Get a
isolationUnderSupply = failRaw "too few bytes" ["demandInput"]
414owen marked this conversation as resolved.
Show resolved Hide resolved

isolate0 :: Get a -> Get a
isolate0 parser = do
rest <- get
cur <- bytesRead
put mempty cur
a <- parser
put rest cur
pure a

-- | Isolate an action to operating within a fixed block of bytes. The action
-- is required to consume all the bytes that it is isolated to.
isolate :: Int -> Get a -> Get a
isolate n m = do
M.when (n < 0) (fail "Attempted to isolate a negative number of bytes")
s <- ensure n
let (s',rest) = B.splitAt n s
cur <- bytesRead
put s' cur
a <- m
used <- get
unless (B.null used) (fail "not all bytes parsed in isolate")
put rest (cur + n)
return a
isolate n m
| n < 0 = negativeIsolation
| n == 0 = isolate0 m
414owen marked this conversation as resolved.
Show resolved Hide resolved
| otherwise = do
s <- ensure' n
let (s',rest) = B.splitAt n s
cur <- bytesRead
put s' cur
a <- m
used <- get
unless (B.null used) isolationUnderParse
put rest (cur + n)
return a

getAtMost :: Int -> Get B.ByteString
getAtMost n = do
(bs, rest) <- B.splitAt n <$> ensure' 1
curr <- bytesRead
put rest (curr + B.length bs)
pure bs

-- | An incremental version of 'isolate', which doesn't try to read the input
-- into a buffer all at once.
isolateLazy :: Int -> Get a -> Get a
isolateLazy n parser
| n < 0 = negativeIsolation
| n == 0 = isolate0 parser
isolateLazy n parser = do
initialBytesRead <- bytesRead
go initialBytesRead . runGetPartial parser =<< getAtMost n
414owen marked this conversation as resolved.
Show resolved Hide resolved
where
go :: Int -> Result a -> Get a
go initialBytesRead r = case r of
FailRaw (msg, stack) bs -> bytesRead >>= put bs >> failRaw msg stack
414owen marked this conversation as resolved.
Show resolved Hide resolved
Done a bs
| otherwise -> do
414owen marked this conversation as resolved.
Show resolved Hide resolved
bytesRead' <- bytesRead
-- Technically this is both undersupply, and underparse
-- buyt we use undersupply to match strict isolation
unless (bytesRead' - initialBytesRead == n) isolationUnderSupply
unless (B.null bs) isolationUnderParse
pure a
Partial cont -> do
pos <- bytesRead
bs <- getAtMost $ n - (pos - initialBytesRead)
-- We want to give the inner parser a chance to determine
-- output, but if it returns a continuation, we'll throw
-- instead of recursing indefinitely
414owen marked this conversation as resolved.
Show resolved Hide resolved
if B.null bs
then case cont bs of
Partial cont -> isolationUnderSupply
a -> go initialBytesRead a
else go initialBytesRead $ cont bs

failRaw :: String -> [String] -> Get a
failRaw msg stack = Get (\s0 b0 m0 _ kf _ -> kf s0 b0 m0 stack msg)

failDesc :: String -> Get a
failDesc err = do
let msg = "Failed reading: " ++ err
Get (\s0 b0 m0 _ kf _ -> kf s0 b0 m0 [] msg)
failDesc err = failRaw ("Failed reading: " ++ err) []

-- | Skip ahead @n@ bytes. Fails if fewer than @n@ bytes are available.
skip :: Int -> Get ()
skip 0 = pure ()
skip n = do
s <- ensure n
M.when (n < 0) (fail "Attempted to skip a negative number of bytes")
s <- ensure' n
cur <- bytesRead
put (B.drop n s) (cur + n)

Expand Down Expand Up @@ -520,15 +599,17 @@ getShortByteString n = do

-- | Pull @n@ bytes from the input, as a strict ByteString.
getBytes :: Int -> Get B.ByteString
getBytes n | n < 0 = fail "getBytes: negative length requested"
getBytes n = do
s <- ensure n
let consume = B.unsafeTake n s
rest = B.unsafeDrop n s
-- (consume,rest) = B.splitAt n s
cur <- bytesRead
put rest (cur + n)
return consume
getBytes n
| n < 0 = fail "getBytes: negative length requested"
| n == 0 = pure mempty
| otherwise = do
s <- ensure' n
let consume = B.unsafeTake n s
rest = B.unsafeDrop n s
-- (consume,rest) = B.splitAt n s
cur <- bytesRead
put rest (cur + n)
return consume
{-# INLINE getBytes #-}


Expand Down
115 changes: 112 additions & 3 deletions tests/GetTests.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE OverloadedStrings #-}

module GetTests (tests) where

Expand All @@ -12,7 +14,12 @@ import qualified Data.ByteString.Lazy as LB
import Data.Serialize.Get
import Test.Framework (Test(),testGroup)
import Test.Framework.Providers.QuickCheck2 (testProperty)
import Test.QuickCheck as QC
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit (Assertion, (@=?), assertFailure)
import Test.QuickCheck hiding (Result)
import qualified Test.QuickCheck as QC
import Data.List (isInfixOf)
import Debug.Trace


-- Data to express Get parser to generate
Expand Down Expand Up @@ -123,10 +130,12 @@ isEmpty2 = do
pure True

-- Compare with chunks
(==~) :: Eq a => Get a -> Get a -> Property
(==~) :: (Eq a, Show a) => Get a -> Get a -> Property
p1 ==~ p2 =
conjoin
[ counterexample (show in0) $ R (runGetLazy p1 s) == R (runGetLazy p2 s)
[ let rl = runGetLazy p1 s
rr = runGetLazy p2 s
in counterexample (show (in0, n, s, rl, rr)) $ R rl == R rr
| n <- [0 .. testLength]
, let Chunks in0 = mkChunks n
s = LB.fromChunks [ BS.pack c | c <- in0 ]
Expand Down Expand Up @@ -254,6 +263,102 @@ alterDistr' p1 p2 p3 =
y = buildGet p2
z = buildGet p3

isolateLazyIsIncremental :: Assertion
isolateLazyIsIncremental = go (runGetPartial parser $ BS.replicate 11 0)
where
parser :: Get ()
parser = isolateLazy 100 $ do
skip 10
fail failStr
pure ()

failStr :: String
failStr = "no thanks"

go :: Result () -> IO ()
go r = case r of
Done () _ -> assertFailure "Impossible"
Fail failure _ -> unless (failStr `isInfixOf` failure) $ assertFailure "Wrong error!"
Partial cont -> assertFailure "Asked for more input!"

isolateLazyLeavesRemainingBytes :: Assertion
isolateLazyLeavesRemainingBytes = go (runGetPartial parser $ BS.replicate 11 0)
where
parser :: Get ()
parser = isolateLazy 100 $ do
skip 10
fail failStr
pure ()

failStr :: String
failStr = "no thanks"

go :: Result () -> IO ()
go r = case r of
Done () _ -> assertFailure "Impossible"
Fail failure _ -> unless (failStr `isInfixOf` failure) $ assertFailure "Wrong error!"
Partial cont -> assertFailure "Asked for more input!"

instance Arbitrary LB.ByteString where
arbitrary = LB.fromChunks . pure . BS.pack <$> arbitrary

newtype IsolationRes a = IRes (Either String a)
deriving Show

-- Sometimes lazy and strict isolations return different errors,
-- eg. when EOF is called before the end of an isolation which isn't prodided
-- enough input.
-- Strict sees it as a lack of bytes, Lazy sees it as a guard failure ("empty").
instance Eq a => Eq (IsolationRes a) where
IRes a == IRes b = case (a, b) of
(Left e1, Left e2) -> e1 == e2 || errsEqAsymmetric e1 e2 || errsEqAsymmetric e2 e1
_ -> a == b
where
errsEqAsymmetric e1 e2 = "too few bytes" `isInfixOf` e1 && "empty" `isInfixOf` e2

isolateAndIsolateLazy :: Int -> GetD -> LB.ByteString -> Property
isolateAndIsolateLazy n parser' bs
= IRes (runGetLazy (isolate n parser) bs) === IRes (runGetLazy (isolateLazy n parser) bs)
where
parser = buildGet parser'

isolateIsNotIncremental :: Assertion
isolateIsNotIncremental = go (runGetPartial parser $ BS.replicate 11 0)
where
parser :: Get ()
parser = isolate 100 $ do
skip 10
fail failStr
pure ()

failStr :: String
failStr = "no thanks"

go :: Result () -> IO ()
go r = case r of
Done () _ -> assertFailure "Impossible"
Fail failure _ -> assertFailure $ "Strict isolate was incremental: " <> failure
Partial cont -> pure ()

-- Checks return values, leftovers, fails for continuations
assertResultsMatch :: Eq a => Result a -> (Maybe a, BS.ByteString) -> Assertion
assertResultsMatch r1 r2 = case (r1, r2) of
(Partial _, _) -> assertFailure "Continuation received"
(Done a1 bs1, (Just a2, bs2)) -> do
unless (a1 == a2) $ assertFailure "Result mismatch"
unless (bs1 == bs2) $ assertFailure $ "Success leftover mismatch: " ++ show (bs1, bs2)
(Fail msg1 bs1, (Nothing, bs2)) ->
unless (bs1 == bs2) $ assertFailure $ "Failure leftovers mismatch: " ++ show (bs1, bs2)
_ -> assertFailure "Different result types"

isolateLazyDeterminesLeftovers :: Assertion
isolateLazyDeterminesLeftovers = do
assertResultsMatch (runGetPartial (isolateLazy 1 getWord8) "123") (Just $ toEnum $ fromEnum '1', "23")
assertResultsMatch (runGetPartial (isolateLazy 2 getWord8) "123") (Nothing, "3")
-- Note(414owen): I don't think this is the ideal behaviour, but it's the existing behaviour, so
-- I'll at least check that isolateLazy matches the behaviour of isolate...
assertResultsMatch (runGetPartial (isolate 2 $ fail "no thanks" *> pure ()) "123") (Nothing, "12")
414owen marked this conversation as resolved.
Show resolved Hide resolved
assertResultsMatch (runGetPartial (isolateLazy 2 $ fail "no thanks" *> pure ()) "123") (Nothing, "12")

tests :: Test
tests = testGroup "GetTests"
Expand All @@ -275,4 +380,8 @@ tests = testGroup "GetTests"
, testProperty "strict - alternative assoc" alterAssoc'
, testProperty "lazy - alternative distr" alterDistr
, testProperty "strict - alternative distr" alterDistr'
, testCase "isolate is not incremental" isolateIsNotIncremental
, testCase "isolateLazy is incremental" isolateLazyIsIncremental
, testProperty "isolations are equivalent" isolateAndIsolateLazy
, testCase "isolateLazy determines leftovers" isolateLazyDeterminesLeftovers
]