From 160316d53c8a757695542b1071e251be41ae913a Mon Sep 17 00:00:00 2001
From: mmmray <142015632+mmmray@users.noreply.github.com>
Date: Sat, 17 Aug 2024 13:01:58 +0200
Subject: [PATCH] SplitHTTP: Do not produce too large upload (#3691)

---
 transport/internet/splithttp/dialer.go        | 46 +++++++++++++--
 .../internet/splithttp/splithttp_test.go      | 56 ++++++++++++++++++-
 transport/pipe/impl.go                        |  8 +++
 transport/pipe/writer.go                      |  4 ++
 4 files changed, 108 insertions(+), 6 deletions(-)

diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go
index 0d487b58d2b6..a95ab34b30e8 100644
--- a/transport/internet/splithttp/dialer.go
+++ b/transport/internet/splithttp/dialer.go
@@ -227,7 +227,11 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 
 	httpClient := getHTTPClient(ctx, dest, streamSettings)
 
-	uploadPipeReader, uploadPipeWriter := pipe.New(pipe.WithSizeLimit(scMaxEachPostBytes.roll()))
+	maxUploadSize := scMaxEachPostBytes.roll()
+	// WithSizeLimit(0) will still allow single bytes to pass, and a lot of
+	// code relies on this behavior. Subtract 1 so that, together with the
+	// uploadWriter wrapper, exact size limits can be enforced.
+	uploadPipeReader, uploadPipeWriter := pipe.New(pipe.WithSizeLimit(maxUploadSize - 1))
 
 	go func() {
 		requestsLimiter := semaphore.New(int(scMaxConcurrentPosts.roll()))
@@ -318,12 +322,13 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 		},
 	}
 
-	// necessary in order to send larger chunks in upload
-	bufferedUploadPipeWriter := buf.NewBufferedWriter(uploadPipeWriter)
-	bufferedUploadPipeWriter.SetBuffered(false)
+	writer := uploadWriter{
+		uploadPipeWriter,
+		maxUploadSize,
+	}
 
 	conn := splitConn{
-		writer:     bufferedUploadPipeWriter,
+		writer:     writer,
 		reader:     lazyDownload,
 		remoteAddr: remoteAddr,
 		localAddr:  localAddr,
@@ -331,3 +336,34 @@
 
 	return stat.Connection(&conn), nil
 }
+
+// A wrapper around pipe that ensures the size limit is exactly honored.
+//
+// The MultiBuffer pipe accepts any single WriteMultiBuffer call even if that
+// single MultiBuffer exceeds the size limit, and then starts blocking on the
+// next WriteMultiBuffer call. This means that ReadMultiBuffer can return more
+// bytes than the size limit. We work around this by splitting a potentially
+// too large write into multiple smaller writes.
+type uploadWriter struct {
+	*pipe.Writer
+	maxLen int32
+}
+
+func (w uploadWriter) Write(b []byte) (int, error) {
+	capacity := int(w.maxLen - w.Len())
+	if capacity > 0 && capacity < len(b) {
+		b = b[:capacity]
+	}
+
+	buffer := buf.New()
+	n, err := buffer.Write(b)
+	if err != nil {
+		return 0, err
+	}
+
+	err = w.WriteMultiBuffer([]*buf.Buffer{buffer})
+	if err != nil {
+		return 0, err
+	}
+	return n, nil
+}
diff --git a/transport/internet/splithttp/splithttp_test.go b/transport/internet/splithttp/splithttp_test.go
index acb4addc4a40..30f92c7ffcb9 100644
--- a/transport/internet/splithttp/splithttp_test.go
+++ b/transport/internet/splithttp/splithttp_test.go
@@ -388,7 +388,7 @@ func Test_queryString(t *testing.T) {
 	ctx := context.Background()
 	streamSettings := &internet.MemoryStreamConfig{
 		ProtocolName:     "splithttp",
-		ProtocolSettings: &Config{Path: "sh"},
+		ProtocolSettings: &Config{Path: "sh?ed=2048"},
 	}
 
 	conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
@@ -407,3 +407,57 @@
 	common.Must(conn.Close())
 	common.Must(listen.Close())
 }
+
+func Test_maxUpload(t *testing.T) {
+	listenPort := tcp.PickPort()
+	streamSettings := &internet.MemoryStreamConfig{
+		ProtocolName: "splithttp",
+		ProtocolSettings: &Config{
+			Path: "/sh",
+			ScMaxEachPostBytes: &RandRangeConfig{
+				From: 100,
+				To:   100,
+			},
+		},
+	}
+
+	var uploadSize int
+	listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
+		go func(c stat.Connection) {
+			defer c.Close()
+			var b [1024]byte
+			c.SetReadDeadline(time.Now().Add(2 * time.Second))
+			n, err := c.Read(b[:])
+			if err != nil {
+				return
+			}
+
+			uploadSize = n
+
+			common.Must2(c.Write([]byte("Response")))
+		}(conn)
+	})
+	common.Must(err)
+	ctx := context.Background()
+
+	conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
+	common.Must(err)
+
+	// send a slightly too large upload
+	var upload [101]byte
+	_, err = conn.Write(upload[:])
+	common.Must(err)
+
+	var b [1024]byte
+	n, _ := io.ReadFull(conn, b[:])
+	if string(b[:n]) != "Response" {
+		t.Error("response: ", string(b[:n]))
+	}
+	common.Must(conn.Close())
+
+	if uploadSize > 100 || uploadSize == 0 {
+		t.Error("incorrect upload size: ", uploadSize)
+	}
+
+	common.Must(listen.Close())
+}
diff --git a/transport/pipe/impl.go b/transport/pipe/impl.go
index dbdb050ef368..8bf58a34e9e0 100644
--- a/transport/pipe/impl.go
+++ b/transport/pipe/impl.go
@@ -46,6 +46,14 @@ var (
 	errSlowDown   = errors.New("slow down")
 )
 
+func (p *pipe) Len() int32 {
+	data := p.data
+	if data == nil {
+		return 0
+	}
+	return data.Len()
+}
+
 func (p *pipe) getState(forRead bool) error {
 	switch p.state {
 	case open:
diff --git a/transport/pipe/writer.go b/transport/pipe/writer.go
index 8230ec7607c3..0a192ca05523 100644
--- a/transport/pipe/writer.go
+++ b/transport/pipe/writer.go
@@ -19,6 +19,10 @@ func (w *Writer) Close() error {
 	return w.pipe.Close()
 }
 
+func (w *Writer) Len() int32 {
+	return w.pipe.Len()
+}
+
 // Interrupt implements common.Interruptible.
 func (w *Writer) Interrupt() {
 	w.pipe.Interrupt()
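Illustrative note (not part of the patch): the uploadWriter doc comment above describes turning one oversized Write into several writes that each stay within the limit. The standalone sketch below mirrors that idea with a hypothetical cappedWriter and writeAll helper; these names do not exist in the repository, and the real uploadWriter caps each call against the pipe's remaining capacity (maxLen - Len()) before handing the data to WriteMultiBuffer.

// sketch only: a toy writer that caps each Write call at maxLen bytes,
// mirroring the splitting idea behind uploadWriter (assumed simplification:
// capacity is a fixed maxLen rather than the pipe's remaining space).
package main

import (
	"bytes"
	"fmt"
)

type cappedWriter struct {
	dst    *bytes.Buffer
	maxLen int
	chunks []int // sizes of the individual writes, recorded for demonstration
}

// Write accepts at most maxLen bytes per call and reports how many it took;
// the caller is expected to retry with the remaining bytes (a short write).
func (w *cappedWriter) Write(b []byte) (int, error) {
	if len(b) > w.maxLen {
		b = b[:w.maxLen]
	}
	w.chunks = append(w.chunks, len(b))
	return w.dst.Write(b)
}

// writeAll loops until the whole payload has been handed to the writer,
// roughly what the calling copy loop is assumed to do for uploadWriter.
func writeAll(w *cappedWriter, b []byte) error {
	for len(b) > 0 {
		n, err := w.Write(b)
		if err != nil {
			return err
		}
		b = b[n:]
	}
	return nil
}

func main() {
	w := &cappedWriter{dst: &bytes.Buffer{}, maxLen: 100}
	payload := make([]byte, 101) // slightly too large, like the upload in Test_maxUpload
	if err := writeAll(w, payload); err != nil {
		panic(err)
	}
	fmt.Println(w.chunks) // [100 1]: no single write exceeds the limit
}

Running the sketch prints [100 1]: a 101-byte payload is delivered as a 100-byte chunk followed by a 1-byte chunk, which is the behavior the new test checks, namely that no single POST body exceeds scMaxEachPostBytes.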