Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Add zero and negative byte length checks to ensure #110

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions src/Data/Serialize/Get.hs
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,15 @@ runGetLazyState m lstr = case runGetLazy' m lstr of

-- | If at least @n@ bytes of input are available, return the current
-- input, otherwise fail.
{-# INLINE ensure #-}
ensure :: Int -> Get B.ByteString
ensure n0 = n0 `seq` Get $ \ s0 b0 m0 w0 kf ks -> let
ensure n
| n < 0 = fail "Attempted to ensure negative amount of bytes"
| n == 0 = pure mempty
| otherwise = ensure' n

{-# INLINE ensure #-}
ensure' :: Int -> Get B.ByteString
ensure' n0 = n0 `seq` Get $ \ s0 b0 m0 w0 kf ks -> let
n' = n0 - B.length s0
in if n' <= 0
then ks s0 b0 m0 w0 s0
Expand Down Expand Up @@ -405,9 +411,16 @@ ensure n0 = n0 `seq` Get $ \ s0 b0 m0 w0 kf ks -> let
-- | Isolate an action to operating within a fixed block of bytes. The action
-- is required to consume all the bytes that it is isolated to.
isolate :: Int -> Get a -> Get a
isolate 0 m = do
rest <- get
cur <- bytesRead
put mempty cur
a <- m
put rest cur
pure a
isolate n m = do
M.when (n < 0) (fail "Attempted to isolate a negative number of bytes")
s <- ensure n
s <- ensure' n
let (s',rest) = B.splitAt n s
cur <- bytesRead
put s' cur
Expand All @@ -424,8 +437,10 @@ failDesc err = do

-- | Skip ahead @n@ bytes. Fails if fewer than @n@ bytes are available.
skip :: Int -> Get ()
skip 0 = pure ()
skip n = do
s <- ensure n
M.when (n < 0) (fail "Attempted to skip a negative number of bytes")
s <- ensure' n
cur <- bytesRead
put (B.drop n s) (cur + n)

Expand Down Expand Up @@ -520,15 +535,17 @@ getShortByteString n = do

-- | Pull @n@ bytes from the input, as a strict ByteString.
getBytes :: Int -> Get B.ByteString
getBytes n | n < 0 = fail "getBytes: negative length requested"
getBytes n = do
s <- ensure n
let consume = B.unsafeTake n s
rest = B.unsafeDrop n s
-- (consume,rest) = B.splitAt n s
cur <- bytesRead
put rest (cur + n)
return consume
getBytes n
| n < 0 = fail "getBytes: negative length requested"
| n == 0 = pure mempty
| otherwise = do
s <- ensure' n
let consume = B.unsafeTake n s
rest = B.unsafeDrop n s
-- (consume,rest) = B.splitAt n s
cur <- bytesRead
put rest (cur + n)
return consume
{-# INLINE getBytes #-}


Expand Down