Skip to content

Commit 3d5994c

Browse files
Merge pull request #52 from avieth/avieth/limit_length
Optional limits on address and data length
2 parents 6f8bf9d + 270aa07 commit 3d5994c

File tree

3 files changed

+158
-27
lines changed

3 files changed

+158
-27
lines changed

src/Network/Transport/TCP.hs

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ import Network.Transport.TCP.Internal
5656
, decodeConnectionRequestResponse
5757
, forkServer
5858
, recvWithLength
59+
, recvWithLengthFold
5960
, recvWord32
6061
, encodeWord32
6162
, tryCloseSocket
@@ -482,6 +483,18 @@ data TCPParameters = TCPParameters {
482483
, transportConnectTimeout :: Maybe Int
483484
-- | Create a QDisc for an EndPoint.
484485
, tcpNewQDisc :: forall t . IO (QDisc t)
486+
-- | Maximum length (in bytes) for a peer's address.
487+
-- If a peer attempts to send an address of length exceeding the limit,
488+
-- the connection will be refused (socket will close).
489+
, tcpMaxAddressLength :: Word32
490+
-- | Maximum length (in bytes) to receive from a peer.
491+
-- If a peer attempts to send data on a lightweight connection exceeding
492+
-- the limit, the heavyweight connection which carries that lightweight
493+
-- connection will go down. The peer and the local node will get an
494+
-- EventConnectionLost.
495+
, tcpMaxReceiveLength :: Word32
496+
-- | Maximum length (in bytes) of a 'Received' event payload.
497+
, tcpMaxChunkSize :: Word32
485498
}
486499

487500
-- | Internal functionality we expose for unit testing
@@ -588,6 +601,9 @@ defaultTCPParameters = TCPParameters {
588601
, tcpUserTimeout = Nothing
589602
, tcpNewQDisc = simpleUnboundedQDisc
590603
, transportConnectTimeout = Nothing
604+
, tcpMaxAddressLength = maxBound
605+
, tcpMaxReceiveLength = maxBound
606+
, tcpMaxChunkSize = maxBound
591607
}
592608

593609
--------------------------------------------------------------------------------
@@ -864,7 +880,9 @@ handleConnectionRequest transport sock = handle handleException $ do
864880
forM_ (tcpUserTimeout $ transportParams transport) $
865881
N.setSocketOption sock N.UserTimeout
866882
ourEndPointId <- recvWord32 sock
867-
theirAddress <- EndPointAddress . BS.concat <$> recvWithLength sock
883+
let maxAddressLength = tcpMaxAddressLength $ transportParams transport
884+
theirAddress <- EndPointAddress . BS.concat <$>
885+
recvWithLength maxAddressLength sock
868886
let ourAddress = encodeEndPointAddress (transportHost transport)
869887
(transportPort transport)
870888
ourEndPointId
@@ -914,7 +932,7 @@ handleConnectionRequest transport sock = handle handleException $ do
914932
-- been recorded as part of the remote endpoint. Either way, we no longer
915933
-- have to worry about closing the socket on receiving an asynchronous
916934
-- exception from this point forward.
917-
forM_ mEndPoint $ handleIncomingMessages . (,) ourEndPoint
935+
forM_ mEndPoint $ handleIncomingMessages (transportParams transport) . (,) ourEndPoint
918936

919937
handleException :: SomeException -> IO ()
920938
handleException ex = do
@@ -957,8 +975,8 @@ handleConnectionRequest transport sock = handle handleException $ do
957975
--
958976
-- Returns only if the remote party closes the socket or if an error occurs.
959977
-- This runs in a thread that will never be killed.
960-
handleIncomingMessages :: EndPointPair -> IO ()
961-
handleIncomingMessages (ourEndPoint, theirEndPoint) = do
978+
handleIncomingMessages :: TCPParameters -> EndPointPair -> IO ()
979+
handleIncomingMessages params (ourEndPoint, theirEndPoint) = do
962980
mSock <- withMVar theirState $ \st ->
963981
case st of
964982
RemoteEndPointInvalid _ ->
@@ -1175,7 +1193,8 @@ handleIncomingMessages (ourEndPoint, theirEndPoint) = do
11751193
-- overhead
11761194
readMessage :: N.Socket -> LightweightConnectionId -> IO ()
11771195
readMessage sock lcid =
1178-
recvWithLength sock >>= qdiscEnqueue' ourQueue theirAddr . Received (connId lcid)
1196+
recvWithLengthFold recvLimit chunkLimit sock () $ \bs _ ->
1197+
qdiscEnqueue' ourQueue theirAddr (Received (connId lcid) bs)
11791198

11801199
-- Stop probing a connection as a result of receiving a probe ack.
11811200
stopProbing :: IO ()
@@ -1190,6 +1209,8 @@ handleIncomingMessages (ourEndPoint, theirEndPoint) = do
11901209
ourQueue = localQueue ourEndPoint
11911210
theirState = remoteState theirEndPoint
11921211
theirAddr = remoteAddress theirEndPoint
1212+
recvLimit = tcpMaxReceiveLength params
1213+
chunkLimit = tcpMaxChunkSize params
11931214

11941215
-- Deal with a premature exit
11951216
prematureExit :: N.Socket -> IOException -> IO ()
@@ -1365,7 +1386,7 @@ setupRemoteEndPoint params (ourEndPoint, theirEndPoint) connTimeout = do
13651386
return False
13661387

13671388
when didAccept $ void $ forkIO $
1368-
handleIncomingMessages (ourEndPoint, theirEndPoint)
1389+
handleIncomingMessages params (ourEndPoint, theirEndPoint)
13691390
return $ either (const Nothing) (Just . snd) result
13701391
where
13711392
ourAddress = localAddress ourEndPoint

src/Network/Transport/TCP/Internal.hs

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
{-# LANGUAGE BangPatterns #-}
2+
13
-- | Utility functions for TCP sockets
24
module Network.Transport.TCP.Internal
35
( ControlHeader(..)
@@ -8,6 +10,7 @@ module Network.Transport.TCP.Internal
810
, decodeConnectionRequestResponse
911
, forkServer
1012
, recvWithLength
13+
, recvWithLengthFold
1114
, recvExact
1215
, recvWord32
1316
, encodeWord32
@@ -59,9 +62,10 @@ import qualified Network.Socket.ByteString as NBS (recv)
5962
import Control.Concurrent (ThreadId)
6063
import Data.Word (Word32)
6164

62-
import Control.Monad (forever, when)
65+
import Control.Monad (forever, when, unless)
6366
import Control.Exception (SomeException, catch, bracketOnError, throwIO, mask_)
6467
import Control.Applicative ((<$>), (<*>))
68+
import Data.Word (Word32)
6569
import Data.ByteString (ByteString)
6670
import qualified Data.ByteString as BS (length, concat, null)
6771
import Data.ByteString.Lazy.Internal (smallChunkSize)
@@ -179,13 +183,44 @@ forkServer host port backlog reuseAddr terminationHandler requestHandler = do
179183
(tryCloseSocket . fst)
180184
(requestHandler . fst)
181185

186+
-- | Read a length, then 1 or more payloads each less than some maximum
187+
-- length in bytes, such that the sum of their lengths is the length that was
188+
-- read.
189+
recvWithLengthFold
190+
:: Word32 -- ^ Maximum total size.
191+
-> Word32 -- ^ Maximum chunk size.
192+
-> N.Socket
193+
-> t -- ^ Start element for the fold.
194+
-> ([ByteString] -> t -> IO t) -- ^ Run this every time we get data of at
195+
-- most the maximum size.
196+
-> IO t
197+
recvWithLengthFold maxSize maxChunk sock base folder = do
198+
len <- recvWord32 sock
199+
when (len > maxSize) $
200+
throwIO (userError "recvWithLengthFold: limit exceeded")
201+
loop base len
202+
where
203+
loop !base !total = do
204+
(bs, received) <- recvExact sock (min maxChunk total)
205+
base' <- folder bs base
206+
let remaining = total - received
207+
when (received > total) $ throwIO (userError "recvWithLengthFold: got more bytes than requested")
208+
if remaining == 0
209+
then return base'
210+
else loop base' remaining
211+
182212
-- | Read a length and then a payload of that length
183-
recvWithLength :: N.Socket -> IO [ByteString]
184-
recvWithLength sock = recvWord32 sock >>= recvExact sock
213+
recvWithLength
214+
:: Word32 -- ^ Maximum total size.
215+
-> N.Socket
216+
-> IO [ByteString]
217+
recvWithLength maxSize sock = fmap (concat . reverse) $
218+
recvWithLengthFold maxSize maxBound sock [] $
219+
\bs lst -> return (bs : lst)
185220

186221
-- | Receive a 32-bit unsigned integer
187222
recvWord32 :: N.Socket -> IO Word32
188-
recvWord32 = fmap (decodeWord32 . BS.concat) . flip recvExact 4
223+
recvWord32 = fmap (decodeWord32 . BS.concat . fst) . flip recvExact 4
189224

190225
-- | Close a socket, ignoring I/O exceptions.
191226
tryCloseSocket :: N.Socket -> IO ()
@@ -196,16 +231,22 @@ tryCloseSocket sock = void . tryIO $
196231
--
197232
-- Throws an I/O exception if the socket closes before the specified
198233
-- number of bytes could be read
199-
recvExact :: N.Socket -- ^ Socket to read from
200-
-> Word32 -- ^ Number of bytes to read
201-
-> IO [ByteString]
234+
recvExact :: N.Socket -- ^ Socket to read from
235+
-> Word32 -- ^ Number of bytes to read
236+
-> IO ([ByteString], Word32) -- ^ Data and number of bytes read
202237
recvExact _ len | len < 0 = throwIO (userError "recvExact: Negative length")
203-
recvExact sock len = go [] len
238+
recvExact sock len = go [] 0 len
204239
where
205-
go :: [ByteString] -> Word32 -> IO [ByteString]
206-
go acc 0 = return (reverse acc)
207-
go acc l = do
240+
go :: [ByteString] -> Word32 -> Word32 -> IO ([ByteString], Word32)
241+
go acc !n 0 = return (reverse acc, n)
242+
go acc !n l = do
208243
bs <- NBS.recv sock (fromIntegral l `min` smallChunkSize)
209244
if BS.null bs
210245
then throwIO (userError "recvExact: Socket closed")
211-
else go (bs : acc) (l - fromIntegral (BS.length bs))
246+
else do
247+
let received = fromIntegral (BS.length bs)
248+
remaining = l - received
249+
total = n + received
250+
-- Check for underflow. Shouldn't be possible but let's make sure.
251+
when (received > l) $ throwIO (userError "recvExact: got more bytes than requested")
252+
go (bs : acc) total remaining

tests/TestTCP.hs

Lines changed: 78 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import Network.Transport.TCP ( createTransport
1717
, createTransportExposeInternals
1818
, TransportInternals(..)
1919
, encodeEndPointAddress
20+
, TCPParameters(..)
2021
, defaultTCPParameters
2122
, LightweightConnectionId
2223
)
@@ -164,7 +165,7 @@ testEarlyDisconnect = do
164165
(clientPort, _) <- forkServer "127.0.0.1" "0" 5 True throwIO $ \sock -> do
165166
-- Initial setup
166167
0 <- recvWord32 sock
167-
_ <- recvWithLength sock
168+
_ <- recvWithLength maxBound sock
168169
sendMany sock [encodeWord32 (encodeConnectionRequestResponse ConnectionRequestAccepted)]
169170

170171
-- Server opens a logical connection
@@ -173,7 +174,7 @@ testEarlyDisconnect = do
173174

174175
-- Server sends a message
175176
1024 <- recvWord32 sock
176-
["ping"] <- recvWithLength sock
177+
["ping"] <- recvWithLength maxBound sock
177178

178179
-- Reply
179180
sendMany sock [
@@ -276,7 +277,7 @@ testEarlyCloseSocket = do
276277
(clientPort, _) <- forkServer "127.0.0.1" "0" 5 True throwIO $ \sock -> do
277278
-- Initial setup
278279
0 <- recvWord32 sock
279-
_ <- recvWithLength sock
280+
_ <- recvWithLength maxBound sock
280281
sendMany sock [encodeWord32 (encodeConnectionRequestResponse ConnectionRequestAccepted)]
281282

282283
-- Server opens a logical connection
@@ -285,7 +286,7 @@ testEarlyCloseSocket = do
285286

286287
-- Server sends a message
287288
1024 <- recvWord32 sock
288-
["ping"] <- recvWithLength sock
289+
["ping"] <- recvWithLength maxBound sock
289290

290291
-- Reply
291292
sendMany sock [
@@ -619,7 +620,7 @@ testReconnect = do
619620
(serverPort, _) <- forkServer "127.0.0.1" "0" 5 True throwIO $ \sock -> do
620621
-- Accept the connection
621622
Right 0 <- tryIO $ recvWord32 sock
622-
Right _ <- tryIO $ recvWithLength sock
623+
Right _ <- tryIO $ recvWithLength maxBound sock
623624

624625
-- The first time we close the socket before accepting the logical connection
625626
count <- modifyMVar counter $ \i -> return (i + 1, i)
@@ -638,7 +639,7 @@ testReconnect = do
638639
-- Client sends a message
639640
Right connId' <- tryIO $ (recvWord32 sock :: IO LightweightConnectionId)
640641
True <- return $ connId == connId'
641-
Right ["ping"] <- tryIO $ recvWithLength sock
642+
Right ["ping"] <- tryIO $ recvWithLength maxBound sock
642643
putMVar serverDone ()
643644

644645
Right () <- tryIO $ N.sClose sock
@@ -711,15 +712,15 @@ testUnidirectionalError = do
711712
-- would shutdown the socket in the other direction)
712713
void . (try :: IO () -> IO (Either SomeException ())) $ do
713714
0 <- recvWord32 sock
714-
_ <- recvWithLength sock
715+
_ <- recvWithLength maxBound sock
715716
() <- sendMany sock [encodeWord32 (encodeConnectionRequestResponse ConnectionRequestAccepted)]
716717

717718
Just CreatedNewConnection <- decodeControlHeader <$> recvWord32 sock
718719
connId <- recvWord32 sock :: IO LightweightConnectionId
719720

720721
connId' <- recvWord32 sock :: IO LightweightConnectionId
721722
True <- return $ connId == connId'
722-
["ping"] <- recvWithLength sock
723+
["ping"] <- recvWithLength maxBound sock
723724
putMVar serverGotPing ()
724725

725726
-- Client
@@ -831,10 +832,77 @@ testUseRandomPort = do
831832
putMVar testDone ()
832833
takeMVar testDone
833834

835+
-- | Verify that if a peer sends an address or data which exceeds the maximum
836+
-- length, that peer's connection will be terminated, but other peers will
837+
-- not be affected.
838+
testMaxLength :: IO ()
839+
testMaxLength = do
840+
841+
Right serverTransport <- createTransport "127.0.0.1" "9998" ((,) "127.0.0.1") $ defaultTCPParameters {
842+
-- 17 bytes should fit every valid address at 127.0.0.1.
843+
-- Port is at most 5 bytes (65536) and id is a base-10 Word32 so
844+
-- at most 10 bytes. We'll have one client with a 5-byte port to push it
845+
-- over the chosen limit of 16
846+
tcpMaxAddressLength = 16
847+
, tcpMaxReceiveLength = 8
848+
}
849+
Right goodClientTransport <- createTransport "127.0.0.1" "9999" ((,) "127.0.0.1") defaultTCPParameters
850+
Right badClientTransport <- createTransport "127.0.0.1" "10000" ((,) "127.0.0.1") defaultTCPParameters
851+
852+
serverAddress <- newEmptyMVar
853+
testDone <- newEmptyMVar
854+
goodClientConnected <- newEmptyMVar
855+
goodClientDone <- newEmptyMVar
856+
badClientDone <- newEmptyMVar
857+
858+
forkTry $ do
859+
Right serverEp <- newEndPoint serverTransport
860+
putMVar serverAddress (address serverEp)
861+
readMVar badClientDone
862+
ConnectionOpened _ _ _ <- receive serverEp
863+
Received _ _ <- receive serverEp
864+
-- Will lose the connection when the good client sends 9 bytes.
865+
ErrorEvent (TransportError (EventConnectionLost _) _) <- receive serverEp
866+
readMVar goodClientDone
867+
putMVar testDone ()
868+
869+
forkTry $ do
870+
Right badClientEp <- newEndPoint badClientTransport
871+
address <- readMVar serverAddress
872+
-- Wait until the good client connects, then try to connect. It'll fail,
873+
-- but the good client should still be OK.
874+
readMVar goodClientConnected
875+
Left (TransportError ConnectFailed _)
876+
<- connect badClientEp address ReliableOrdered defaultConnectHints
877+
closeEndPoint badClientEp
878+
putMVar badClientDone ()
879+
880+
forkTry $ do
881+
Right goodClientEp <- newEndPoint goodClientTransport
882+
address <- readMVar serverAddress
883+
Right conn <- connect goodClientEp address ReliableOrdered defaultConnectHints
884+
putMVar goodClientConnected ()
885+
-- Wait until the bad client has tried and failed to connect before
886+
-- attempting a send, to ensure that its failure did not affect us.
887+
readMVar badClientDone
888+
Right () <- send conn ["00000000"]
889+
-- The send which breaches the limit does not appear to fail, but the
890+
-- (heavyweight) connection is now severed. We can reliably determine that
891+
-- by receiving.
892+
Right () <- send conn ["000000000"]
893+
ErrorEvent (TransportError (EventConnectionLost _) _) <- receive goodClientEp
894+
closeEndPoint goodClientEp
895+
putMVar goodClientDone ()
896+
897+
readMVar testDone
898+
closeTransport badClientTransport
899+
closeTransport goodClientTransport
900+
closeTransport serverTransport
901+
834902
main :: IO ()
835903
main = do
836904
tcpResult <- tryIO $ runTests
837-
[ ("Use random port" , testUseRandomPort)
905+
[ ("Use random port", testUseRandomPort)
838906
, ("EarlyDisconnect", testEarlyDisconnect)
839907
, ("EarlyCloseSocket", testEarlyCloseSocket)
840908
, ("IgnoreCloseSocket", testIgnoreCloseSocket)
@@ -847,6 +915,7 @@ main = do
847915
, ("Reconnect", testReconnect)
848916
, ("UnidirectionalError", testUnidirectionalError)
849917
, ("InvalidCloseConnection", testInvalidCloseConnection)
918+
, ("MaxLength", testMaxLength)
850919
]
851920
-- Run the generic tests even if the TCP specific tests failed..
852921
testTransport (either (Left . show) (Right) <$>

0 commit comments

Comments
 (0)