diff --git a/.gitignore b/.gitignore index 2c50f15..c88a919 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ cabal.sandbox.config *.o *.hi *.swp +cabal.project.local diff --git a/cereal.cabal b/cereal.cabal index bb8c7c0..ceb34f0 100644 --- a/cereal.cabal +++ b/cereal.cabal @@ -66,8 +66,10 @@ test-suite test-cereal build-depends: base == 4.*, bytestring >= 0.9, + HUnit, QuickCheck, test-framework, + test-framework-hunit, test-framework-quickcheck2, cereal diff --git a/src/Data/Serialize/Get.hs b/src/Data/Serialize/Get.hs index 692eea4..ecb65ef 100644 --- a/src/Data/Serialize/Get.hs +++ b/src/Data/Serialize/Get.hs @@ -2,6 +2,8 @@ {-# LANGUAGE MagicHash #-} {-# LANGUAGE Rank2Types #-} {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE PatternSynonyms #-} ----------------------------------------------------------------------------- -- | @@ -32,13 +34,14 @@ module Data.Serialize.Get ( , runGetLazyState -- ** Incremental interface - , Result(..) + , Result(Fail, Partial, Done) , runGetPartial , runGetChunk -- * Parsing , ensure , isolate + , isolateLazy , label , skip , uncheckedSkip @@ -128,7 +131,7 @@ import GHC.Word #endif -- | The result of a parse. -data Result r = Fail String B.ByteString +data Result r = FailRaw (String, [String]) B.ByteString -- ^ The parse failed. The 'String' is the -- message describing the error, if any. | Partial (B.ByteString -> Result r) @@ -140,13 +143,20 @@ data Result r = Fail String B.ByteString -- input that had not yet been consumed (if any) when -- the parse succeeded. 
+pattern Fail :: String -> B.ByteString -> Result r -- match-only view of 'FailRaw'
+pattern Fail msg bs <- FailRaw (formatFailure -> msg) bs -- msg is rendered on demand by 'formatFailure'
+{-# COMPLETE Fail, Partial, Done #-}
+
+formatFailure :: (String, [String]) -> String -- (error message, label stack) -> rendered text
+formatFailure (err, stack) = unlines [err, formatTrace stack]
+
 instance Show r => Show (Result r) where
-  show (Fail msg _) = "Fail " ++ show msg
+  show (FailRaw raw _) = "Fail " ++ show (formatFailure raw) -- show the rendered message, as the old 'Fail' constructor did, not the raw tuple
   show (Partial _) = "Partial _"
   show (Done r bs) = "Done " ++ show r ++ " " ++ show bs
 
 instance Functor Result where
-  fmap _ (Fail msg rest) = Fail msg rest
+  fmap _ (FailRaw a bs) = FailRaw a bs
   fmap f (Partial k) = Partial (fmap f . k)
   fmap f (Done r bs) = Done (f r) bs
 
@@ -275,7 +285,7 @@ finalK s _ _ _ a = Done a s
 
 failK :: Failure a
 failK s b _ ls msg =
-  Fail (unlines [msg, formatTrace ls]) (s `B.append` bufferBytes b)
+  FailRaw (msg, ls) (s `B.append` bufferBytes b) -- keep message and label stack separate; rendering happens in 'formatFailure'
 
 -- | Run the Get monad applies a 'get'-based parser on the input ByteString
 runGet :: Get a -> B.ByteString -> Either String a
@@ -364,9 +374,14 @@ runGetLazyState m lstr = case runGetLazy' m lstr of
 
 -- | If at least @n@ bytes of input are available, return the current
 -- input, otherwise fail.
-{-# INLINE ensure #-}
 ensure :: Int -> Get B.ByteString
-ensure n0 = n0 `seq` Get $ \ s0 b0 m0 w0 kf ks -> let
+ensure n
+  | n < 0 = fail "Attempted to ensure negative amount of bytes" -- public entry point validates the count
+  | otherwise = ensure' n
+
+{-# INLINE ensure' #-}
+ensure' :: Int -> Get B.ByteString -- unchecked worker: callers must already have rejected negative counts
+ensure' n0 = n0 `seq` Get $ \ s0 b0 m0 w0 kf ks -> let
     n' = n0 - B.length s0
     in if n' <= 0
         then ks s0 b0 m0 w0 s0
@@ -402,30 +417,83 @@ ensure n0 = n0 `seq` Get $ \ s0 b0 m0 w0 kf ks -> let
                 in ks s b m0 w0 s
             else getMore n' s0 ss b0 m0 w0 kf ks
 
+negativeIsolation :: Get a -- shared failure for 'isolate' and 'isolateLazy' when n < 0
+negativeIsolation = fail "Attempted to isolate a negative number of bytes"
+
+isolationUnderParse :: Get a -- the isolated parser left bytes unconsumed
+isolationUnderParse = fail "Isolated parser didn't consume all input"
+
+isolationUnderSupply :: Get a -- the isolated parser asked for more bytes than the isolation allows
+isolationUnderSupply = fail "Too few bytes supplied to isolated parser"
+
 -- | Isolate an action to operating within a fixed block of bytes.  The action
 -- is required to consume all the bytes that it is isolated to.
 isolate :: Int -> Get a -> Get a
-isolate n m = do
-  M.when (n < 0) (fail "Attempted to isolate a negative number of bytes")
-  s <- ensure n
-  let (s',rest) = B.splitAt n s
-  cur <- bytesRead
-  put s' cur
-  a <- m
-  used <- get
-  unless (B.null used) (fail "not all bytes parsed in isolate")
-  put rest (cur + n)
-  return a
+isolate n m
+  | n < 0 = negativeIsolation
+  | otherwise = do
+      s <- ensure' n -- n is known to be non-negative here, so skip the re-check
+      let (s',rest) = B.splitAt n s
+      cur <- bytesRead
+      put s' cur -- run the action against the first n bytes only
+      a <- m
+      used <- get
+      unless (B.null used) isolationUnderParse
+      put rest (cur + n) -- restore the bytes that followed the isolated block
+      return a
+
+getAtMost :: Int -> Get B.ByteString -- take up to n bytes of the current input, at least one unless n is 0
+getAtMost 0 = pure mempty
+getAtMost n = do
+  (bs, rest) <- B.splitAt n <$> ensure' 1
+  curr <- bytesRead
+  put rest (curr + B.length bs) -- advance by what was actually taken, which may be fewer than n
+  pure bs
+
+-- | An incremental version of 'isolate', which doesn't try to read the input
+-- into a buffer all at once.
+isolateLazy :: Int -> Get a -> Get a
+isolateLazy n _ | n < 0 = negativeIsolation
+isolateLazy n parser = do
+  initialBytesRead <- bytesRead
+  bs <- getAtMost n
+  go initialBytesRead $ runGetPartial parser bs
+  where
+    go :: Int -> Result a -> Get a
+    go initialBytesRead r = case r of
+      FailRaw (msg, stack) bs -> do -- inner failure: re-raise it with the inner leftovers as current input
+        m <- bytesRead
+        put bs m
+        failRaw msg stack
+      Done a bs -> do -- success only counts if exactly n bytes were taken and none are left over
+        bytesRead' <- bytesRead
+        unless (bytesRead' - initialBytesRead == n && B.null bs)
+          $ fail $ "Isolated parser didn't consume all input. "
+          <> "Internal leftovers: " <> show bs
+          <> ", bytesRead: " <> show (bytesRead' - initialBytesRead)
+          <> ", isolation amt: " <> show n
+        pure a
+      Partial cont -> do -- inner parser wants more: feed at most what remains of the isolation
+        pos <- bytesRead
+        let bytesLeft = n - (pos - initialBytesRead)
+        if bytesLeft == 0
+          then isolationUnderSupply
+          else do
+            bs <- getAtMost bytesLeft
+            go initialBytesRead $ cont bs
+
+failRaw :: String -> [String] -> Get a -- fail with an explicit message and label stack
+failRaw msg stack = Get (\s0 b0 m0 _ kf _ -> kf s0 b0 m0 stack msg)
 
 failDesc :: String -> Get a
-failDesc err = do
-    let msg = "Failed reading: " ++ err
-    Get (\s0 b0 m0 _ kf _ -> kf s0 b0 m0 [] msg)
+failDesc err = failRaw ("Failed reading: " ++ err) []
 
 -- | Skip ahead @n@ bytes. Fails if fewer than @n@ bytes are available.
 skip :: Int -> Get ()
+skip 0 = pure ()
 skip n = do
-  s <- ensure n
+  M.when (n < 0) (fail "Attempted to skip a negative number of bytes")
+  s <- ensure' n
   cur <- bytesRead
   put (B.drop n s) (cur + n)
 
@@ -520,15 +588,16 @@ getShortByteString n = do
 
 -- | Pull @n@ bytes from the input, as a strict ByteString.
 getBytes :: Int -> Get B.ByteString
-getBytes n | n < 0 = fail "getBytes: negative length requested"
-getBytes n = do
-  s <- ensure n
-  let consume = B.unsafeTake n s
-      rest    = B.unsafeDrop n s
-      -- (consume,rest) = B.splitAt n s
-  cur <- bytesRead
-  put rest (cur + n)
-  return consume
+getBytes n
+  | n < 0 = fail "getBytes: negative length requested"
+  | otherwise = do
+      s <- ensure' n
+      let consume = B.unsafeTake n s
+          rest    = B.unsafeDrop n s
+          -- (consume,rest) = B.splitAt n s
+      cur <- bytesRead
+      put rest (cur + n)
+      return consume
 {-# INLINE getBytes #-}
 
 
diff --git a/tests/GetTests.hs b/tests/GetTests.hs
index 3686279..f3df9f8 100644
--- a/tests/GetTests.hs
+++ b/tests/GetTests.hs
@@ -1,5 +1,7 @@
 {-# OPTIONS_GHC -fno-warn-orphans #-}
 {-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# LANGUAGE OverloadedStrings #-}
 
 module GetTests (tests) where
 
@@ -12,7 +14,15 @@ import qualified Data.ByteString.Lazy as LB
 import Data.Serialize.Get
 import Test.Framework (Test(),testGroup)
 import Test.Framework.Providers.QuickCheck2 (testProperty)
-import Test.QuickCheck as QC
+import Test.Framework.Providers.HUnit (testCase)
+import Test.HUnit (Assertion, (@=?), (@?=), assertFailure, (@?))
+import Test.QuickCheck hiding (Result)
+import qualified Test.QuickCheck as QC
+import Data.List (isInfixOf)
+import Debug.Trace
+import Data.Either (isLeft)
+import Data.Bifunctor (bimap, Bifunctor (..))
+import Data.Functor (($>))
 
 -- Data to express Get parser to generate
 
@@ -123,10 +133,12 @@ isEmpty2 = do
   pure True
 
 -- Compare with chunks
-(==~) :: Eq a => Get a -> Get a -> Property
+(==~) :: (Eq a, Show a) => Get a -> Get a -> Property
 p1 ==~ p2 =
   conjoin
-  [ counterexample (show in0) $ R (runGetLazy p1 s) == R (runGetLazy p2 s)
+  [ let rl = runGetLazy p1 s
+        rr = runGetLazy p2 s
+    in counterexample (show (in0, n, s, rl, rr)) $ R rl == R rr
   | n <- [0 .. testLength]
   , let Chunks in0 = mkChunks n
         s = LB.fromChunks [ BS.pack c | c <- in0 ]
@@ -254,6 +266,135 @@ alterDistr' p1 p2 p3 =
     y = buildGet p2
     z = buildGet p3
 
+isolateLazyIsIncremental :: Assertion
+isolateLazyIsIncremental = go (runGetPartial parser $ BS.replicate 11 0)
+  where
+    parser :: Get ()
+    parser = isolateLazy 100 $ do
+      skip 10
+      fail failStr
+      pure ()
+
+    failStr :: String
+    failStr = "no thanks"
+
+    go :: Result () -> IO ()
+    go r = case r of
+      Done () _ -> assertFailure "Impossible"
+      Fail failure _ -> unless (failStr `isInfixOf` failure) $ assertFailure "Wrong error!"
+      Partial _ -> assertFailure "Asked for more input!"
+
+isolateLazyLeavesRemainingBytes :: Assertion
+isolateLazyLeavesRemainingBytes = go (runGetPartial parser $ BS.replicate 11 0)
+  where
+    parser :: Get ()
+    parser = isolateLazy 100 $ do
+      skip 10
+      fail failStr
+      pure ()
+
+    failStr :: String
+    failStr = "no thanks"
+
+    go :: Result () -> IO ()
+    go r = case r of
+      Done () _ -> assertFailure "Impossible"
+      Fail failure leftover -> (failStr `isInfixOf` failure, leftover) @?= (True, BS.replicate 1 0) -- 11 bytes fed, 10 skipped: exactly one byte must remain
+      Partial _ -> assertFailure "Asked for more input!"
+
+instance Arbitrary LB.ByteString where
+  arbitrary = LB.fromChunks . pure . BS.pack <$> arbitrary
+
+newtype IsolationRes a = IRes (Either String a)
+  deriving Show
+
+-- Sometimes lazy and strict isolations return different errors,
+-- eg. when EOF is called before the end of an isolation which isn't provided
+-- enough input.
+-- Strict sees it as a lack of bytes, Lazy sees it as a guard failure ("empty").
+instance Eq a => Eq (IsolationRes a) where
+  IRes a == IRes b = case (a, b) of
+    (Left _, Left _) -> True -- any two failures count as equal; their messages may legitimately differ
+    _ -> a == b
+
+isolateAndIsolateLazy :: Int -> GetD -> LB.ByteString -> Property
+isolateAndIsolateLazy n parser' bs
+  = IRes (runGetLazy (isolate n parser) bs) === IRes (runGetLazy (isolateLazy n parser) bs)
+  where
+    parser = buildGet parser'
+
+isolateIsNotIncremental :: Assertion
+isolateIsNotIncremental = go (runGetPartial parser $ BS.replicate 11 0)
+  where
+    parser :: Get ()
+    parser = isolate 100 $ do
+      skip 10
+      fail failStr
+      pure ()
+
+    failStr :: String
+    failStr = "no thanks"
+
+    go :: Result () -> IO ()
+    go r = case r of
+      Done () _ -> assertFailure "Impossible"
+      Fail failure _ -> assertFailure $ "Strict isolate was incremental: " <> failure
+      Partial _ -> pure () -- strict isolate asks for all 100 bytes before running the parser
+
+isolate0 :: Assertion
+isolate0 = do
+  runGet parseSucceed "hello" @?= Right (42, "hello")
+  first (const ()) (runGet parseFail "hello") @?= Left ()
+
+  where
+    parseSucceed :: Get (Int, BS.ByteString)
+    parseSucceed = do
+      a <- isolate 0 $ pure 42
+      b <- getByteString 5
+      pure (a, b)
+
+    parseFail :: Get (Word8, BS.ByteString)
+    parseFail = do
+      a <- isolate 0 getWord8
+      b <- getByteString 5
+      pure (a, b)
+
+isolate2 :: Assertion
+isolate2 = runGet parser "hello" @?= Right ("he", "llo")
+  where
+    parser :: Get (BS.ByteString, BS.ByteString)
+    parser = do
+      a <- isolate 2 $ getByteString 2
+      b <- getByteString 3
+      pure (a, b)
+
+testEnsure :: Assertion
+testEnsure = do
+  runGet parser "hello" @?= Right (replicate 3 "hello")
+  where
+    parser = do
+      a <- ensure 0
+      b <- ensure 2
+      c <- ensure 5
+      pure [a, b, c]
+
+-- Checks return values, leftovers, fails for continuations
+assertResultsMatch :: Eq a => Result a -> (Maybe a, BS.ByteString) -> Assertion
+assertResultsMatch r1 r2 = case (r1, r2) of
+  (Partial _, _) -> assertFailure "Continuation received"
+  (Done a1 bs1, (Just a2, bs2)) -> do
+    unless (a1 == a2) $ assertFailure "Result mismatch"
+    unless (bs1 == bs2) $ assertFailure $ "Success leftover mismatch: " ++ show (bs1, bs2)
+  (Fail _ bs1, (Nothing, bs2)) ->
+    unless (bs1 == bs2) $ assertFailure $ "Failure leftovers mismatch: " ++ show (bs1, bs2)
+  _ -> assertFailure "Different result types"
+
+isolateLazyDeterminesLeftovers :: Assertion
+isolateLazyDeterminesLeftovers = do
+  assertResultsMatch (runGetPartial (isolateLazy 1 getWord8) "123") (Just $ toEnum $ fromEnum '1', "23")
+  assertResultsMatch (runGetPartial (isolateLazy 2 getWord8) "123") (Nothing, "3")
+  assertResultsMatch (runGetPartial (isolate 2 $ fail "no thanks" *> pure ()) "123") (Nothing, "12")
+  assertResultsMatch (runGetPartial (isolateLazy 2 $ fail "no thanks" *> pure ()) "123") (Nothing, "12")
 
 tests :: Test
 tests = testGroup "GetTests"
@@ -264,7 +405,7 @@ tests = testGroup "GetTests"
   , testProperty "lazy - monad assoc" monadAssoc
   , testProperty "strict - monad assoc" monadAssoc'
   , testProperty "strict lazy - equality" eqStrictLazy
-  , testProperty "strict lazy - remaining equality"remainingStrictLazy
+  , testProperty "strict lazy - remaining equality" remainingStrictLazy
   , testProperty "lazy - two eof" eqEof
   , testProperty "strict - two eof" eqEof'
   , testProperty "lazy - alternative left Id" alterIdL
@@ -275,4 +416,12 @@ tests = testGroup "GetTests"
   , testProperty "strict - alternative assoc" alterAssoc'
   , testProperty "lazy - alternative distr" alterDistr
   , testProperty "strict - alternative distr" alterDistr'
+  , testCase "isolate is not incremental" isolateIsNotIncremental
+  , testCase "ensure" testEnsure
+  , testCase "isolate 0" isolate0
+  , testCase "isolate 2" isolate2
+  , testCase "isolateLazy is incremental" isolateLazyIsIncremental
+  , testCase "isolateLazy leaves remaining bytes" isolateLazyLeavesRemainingBytes
+  , testProperty "isolations are equivalent" isolateAndIsolateLazy
+  , testCase "isolateLazy determines leftovers" isolateLazyDeterminesLeftovers
   ]