1
+ {-# LANGUAGE BangPatterns #-}
2
+
1
3
-- | Utility functions for TCP sockets
2
4
module Network.Transport.TCP.Internal
3
5
( ControlHeader (.. )
@@ -8,6 +10,7 @@ module Network.Transport.TCP.Internal
8
10
, decodeConnectionRequestResponse
9
11
, forkServer
10
12
, recvWithLength
13
+ , recvWithLengthFold
11
14
, recvExact
12
15
, recvWord32
13
16
, encodeWord32
@@ -59,7 +62,7 @@ import qualified Network.Socket.ByteString as NBS (recv)
59
62
import Control.Concurrent (ThreadId )
60
63
import Data.Word (Word32 )
61
64
62
- import Control.Monad (forever , when )
65
+ import Control.Monad (forever , when , unless )
63
66
import Control.Exception (SomeException , catch , bracketOnError , throwIO , mask_ )
64
67
import Control.Applicative ((<$>) , (<*>) )
65
68
import Data.Word (Word32 )
@@ -180,20 +183,44 @@ forkServer host port backlog reuseAddr terminationHandler requestHandler = do
180
183
(tryCloseSocket . fst )
181
184
(requestHandler . fst )
182
185
183
- -- | Read a length and then a payload of that length, subject to an optional
184
- -- limit on the length.
185
- recvWithLength :: Maybe Word32 -> N. Socket -> IO [ByteString ]
186
- recvWithLength mlimit sock = case mlimit of
187
- Nothing -> recvWord32 sock >>= recvExact sock
188
- Just limit -> do
189
- length <- recvWord32 sock
190
- when (length > limit) $
191
- throwIO (userError " recvWithLength: limit exceeded" )
192
- recvExact sock length
186
+ -- | Read a length, then 1 or more payloads each less than some maximum
187
+ -- length in bytes, such that the sum of their lengths is the length that was
188
+ -- read.
189
+ recvWithLengthFold
190
+ :: Word32 -- ^ Maximum total size.
191
+ -> Word32 -- ^ Maximum chunk size.
192
+ -> N. Socket
193
+ -> t -- ^ Start element for the fold.
194
+ -> ([ByteString ] -> t -> IO t ) -- ^ Run this every time we get data of at
195
+ -- most the maximum size.
196
+ -> IO t
197
+ recvWithLengthFold maxSize maxChunk sock base folder = do
198
+ len <- recvWord32 sock
199
+ when (len > maxSize) $
200
+ throwIO (userError " recvWithLengthFold: limit exceeded" )
201
+ loop base len
202
+ where
203
+ loop ! base ! total = do
204
+ (bs, received) <- recvExact sock (min maxChunk total)
205
+ base' <- folder bs base
206
+ let remaining = total - received
207
+ when (received > total) $ throwIO (userError " recvWithLengthFold: got more bytes than requested" )
208
+ if remaining == 0
209
+ then return base'
210
+ else loop base' remaining
211
+
212
+ -- | Read a length and then a payload of that length
213
+ recvWithLength
214
+ :: Word32 -- ^ Maximum total size.
215
+ -> N. Socket
216
+ -> IO [ByteString ]
217
+ recvWithLength maxSize sock = fmap (concat . reverse ) $
218
+ recvWithLengthFold maxSize maxBound sock [] $
219
+ \ bs lst -> return (bs : lst)
193
220
194
221
-- | Receive a 32-bit unsigned integer
195
222
recvWord32 :: N. Socket -> IO Word32
196
- recvWord32 = fmap (decodeWord32 . BS. concat ) . flip recvExact 4
223
+ recvWord32 = fmap (decodeWord32 . BS. concat . fst ) . flip recvExact 4
197
224
198
225
-- | Close a socket, ignoring I/O exceptions.
199
226
tryCloseSocket :: N. Socket -> IO ()
@@ -204,16 +231,22 @@ tryCloseSocket sock = void . tryIO $
204
231
--
205
232
-- Throws an I/O exception if the socket closes before the specified
206
233
-- number of bytes could be read
207
- recvExact :: N. Socket -- ^ Socket to read from
208
- -> Word32 -- ^ Number of bytes to read
209
- -> IO [ByteString ]
234
+ recvExact :: N. Socket -- ^ Socket to read from
235
+ -> Word32 -- ^ Number of bytes to read
236
+ -> IO ( [ByteString ], Word32 ) -- ^ Data and number of bytes read
210
237
recvExact _ len | len < 0 = throwIO (userError " recvExact: Negative length" )
211
- recvExact sock len = go [] len
238
+ recvExact sock len = go [] 0 len
212
239
where
213
- go :: [ByteString ] -> Word32 -> IO [ByteString ]
214
- go acc 0 = return (reverse acc)
215
- go acc l = do
240
+ go :: [ByteString ] -> Word32 -> Word32 -> IO ( [ByteString ], Word32 )
241
+ go acc ! n 0 = return (reverse acc, n )
242
+ go acc ! n l = do
216
243
bs <- NBS. recv sock (fromIntegral l `min` smallChunkSize)
217
244
if BS. null bs
218
245
then throwIO (userError " recvExact: Socket closed" )
219
- else go (bs : acc) (l - fromIntegral (BS. length bs))
246
+ else do
247
+ let received = fromIntegral (BS. length bs)
248
+ remaining = l - received
249
+ total = n + received
250
+ -- Check for underflow. Shouldn't be possible but let's make sure.
251
+ when (received > l) $ throwIO (userError " recvExact: got more bytes than requested" )
252
+ go (bs : acc) total remaining
0 commit comments