From 4ea5442833add2add7763cc6e811666b1cd9fefd Mon Sep 17 00:00:00 2001 From: Mat Trudel Date: Wed, 18 Dec 2024 10:36:53 -0500 Subject: [PATCH] Fixup POST pipeline support (#442) * Add test coverage for pipelined POST requests * Improve HTTP/1's bytes_remaining semantics Previously this field represented the bytes still to read off the wire, as opposed to the bytes still required per the request's content length. This made it hard to support cases where we *wanted* to maintain a buffer of pipelined requests after the current request's body (previously, we just errored loudly on this in the assumption that it represented an intentionally misbehaving client). Fix this by using an `unread_content_length` parameter on HTTP/1 requests that instead represents the pending bytes *on this request* so that we can more easily handle the case where the socket buffer contains bytes belonging to the subsequent request --- lib/bandit/http1/socket.ex | 66 ++++++++++++++++-------------- test/bandit/http1/request_test.exs | 48 +++++++++++----------- 2 files changed, 59 insertions(+), 55 deletions(-) diff --git a/lib/bandit/http1/socket.ex b/lib/bandit/http1/socket.ex index e8130c8b..57aec4a5 100644 --- a/lib/bandit/http1/socket.ex +++ b/lib/bandit/http1/socket.ex @@ -12,7 +12,7 @@ defmodule Bandit.HTTP1.Socket do buffer: <<>>, read_state: :unread, write_state: :unsent, - bytes_remaining: nil, + unread_content_length: nil, body_encoding: nil, version: :"HTTP/1.0", send_buffer: nil, @@ -31,7 +31,7 @@ defmodule Bandit.HTTP1.Socket do buffer: iodata(), read_state: read_state(), write_state: write_state(), - bytes_remaining: non_neg_integer() | :chunked | nil, + unread_content_length: non_neg_integer() | :chunked | nil, body_encoding: nil | binary(), version: nil | :"HTTP/1.1" | :"HTTP/1.0", send_buffer: iolist(), @@ -49,20 +49,19 @@ defmodule Bandit.HTTP1.Socket do def read_headers(%@for{read_state: :unread} = socket) do {method, request_target, socket} = do_read_request_line!(socket) 
{headers, socket} = do_read_headers!(socket) - body_size = get_content_length!(headers) + content_length = get_content_length!(headers) body_encoding = Bandit.Headers.get_header(headers, "transfer-encoding") connection = Bandit.Headers.get_header(headers, "connection") keepalive = should_keepalive?(socket.version, connection) socket = %{socket | keepalive: keepalive} - case {body_size, body_encoding} do + case {content_length, body_encoding} do {nil, nil} -> # No body, so just go straight to 'read' {:ok, method, request_target, headers, %{socket | read_state: :read}} - {body_size, nil} -> - bytes_remaining = body_size - byte_size(socket.buffer) - socket = %{socket | read_state: :headers_read, bytes_remaining: bytes_remaining} + {content_length, nil} -> + socket = %{socket | read_state: :headers_read, unread_content_length: content_length} {:ok, method, request_target, headers, socket} {nil, body_encoding} -> @@ -173,17 +172,19 @@ defmodule Bandit.HTTP1.Socket do defp should_keepalive?(_, _), do: false def read_data( - %@for{read_state: :headers_read, bytes_remaining: bytes_remaining} = socket, + %@for{read_state: :headers_read, unread_content_length: unread_content_length} = socket, opts ) - when is_number(bytes_remaining) do - {to_return, buffer, bytes_remaining} = - do_read_content_length_data!(socket.socket, socket.buffer, bytes_remaining, opts) + when is_number(unread_content_length) do + {to_return, buffer, remaining_unread_content_length} = + do_read_content_length_data!(socket.socket, socket.buffer, unread_content_length, opts) - if byte_size(buffer) == 0 && bytes_remaining == 0 do - {:ok, to_return, %{socket | read_state: :read, buffer: <<>>, bytes_remaining: 0}} + socket = %{socket | buffer: buffer, unread_content_length: remaining_unread_content_length} + + if remaining_unread_content_length == 0 do + {:ok, to_return, %{socket | read_state: :read}} else - {:more, to_return, %{socket | buffer: buffer, bytes_remaining: bytes_remaining}} + {:more, to_return, 
socket} end end @@ -207,32 +208,35 @@ defmodule Bandit.HTTP1.Socket do def read_data(%@for{} = socket, _opts), do: {:ok, <<>>, socket} @dialyzer {:no_improper_lists, do_read_content_length_data!: 4} - defp do_read_content_length_data!(socket, buffer, bytes_remaining, opts) do - max_desired_bytes = Keyword.get(opts, :length, 8_000_000) + defp do_read_content_length_data!(socket, buffer, unread_content_length, opts) do + max_to_return = min(unread_content_length, Keyword.get(opts, :length, 8_000_000)) cond do - bytes_remaining < 0 -> - # We have read more bytes than content-length suggested should have been sent. This is - # veering into request smuggling territory and should never happen with a well behaved - # client. The safest thing to do is just error - request_error!("Excess body read") + max_to_return == 0 -> + # We have already satisfied our content length + {<<>>, buffer, unread_content_length} - byte_size(buffer) >= max_desired_bytes || bytes_remaining == 0 -> + byte_size(buffer) >= max_to_return -> # We can satisfy the read request entirely from our buffer - bytes_to_return = min(max_desired_bytes, byte_size(buffer)) - <<to_return::binary-size(bytes_to_return), rest::binary>> = buffer - {to_return, rest, bytes_remaining} + <<to_return::binary-size(max_to_return), rest::binary>> = buffer + {to_return, rest, unread_content_length - max_to_return} - true -> + byte_size(buffer) < max_to_return -> # We need to read off the wire - bytes_to_read = min(max_desired_bytes - byte_size(buffer), bytes_remaining) read_size = Keyword.get(opts, :read_length, 1_000_000) read_timeout = Keyword.get(opts, :read_timeout) - iolist = read!(socket, bytes_to_read, [], read_size, read_timeout) - to_return = IO.iodata_to_binary([buffer | iolist]) - bytes_remaining = bytes_remaining - (byte_size(to_return) - byte_size(buffer)) - {to_return, <<>>, bytes_remaining} + to_return = + read!(socket, max_to_return - byte_size(buffer), [buffer], read_size, read_timeout) + |> IO.iodata_to_binary() + + # We may have read more than we need to return + if byte_size(to_return) >= max_to_return do + <<to_return::binary-size(max_to_return), rest::binary>> = 
to_return + {to_return, rest, unread_content_length - max_to_return} + else + {to_return, <<>>, unread_content_length - byte_size(to_return)} + end end end diff --git a/test/bandit/http1/request_test.exs b/test/bandit/http1/request_test.exs index cfeeefe5..976c87da 100644 --- a/test/bandit/http1/request_test.exs +++ b/test/bandit/http1/request_test.exs @@ -247,7 +247,7 @@ defmodule HTTP1RequestTest do Transport.send( client, - String.duplicate("GET /hello_world HTTP/1.1\r\nHost: localhost\r\n\r\n", 50) + String.duplicate("GET /send_ok HTTP/1.1\r\nHost: localhost\r\n\r\n", 50) ) for _ <- 1..50 do @@ -258,6 +258,29 @@ defmodule HTTP1RequestTest do end end + test "handles pipeline requests with unread POST bodies", context do + client = SimpleHTTP1Client.tcp_client(context) + + Transport.send( + client, + String.duplicate( + "POST /send_ok HTTP/1.1\r\nHost: localhost\r\nContent-Length:3\r\n\r\nABC", + 50 + ) + ) + + for _ <- 1..50 do + # Need to read the exact size of the expected response because SimpleHTTP1Client + # doesn't track 'rest' bytes and ends up throwing a bunch of responses on the floor + {:ok, bytes} = Transport.recv(client, 152) + assert({:ok, "200 OK", _, _} = SimpleHTTP1Client.parse_response(client, bytes)) + end + end + + def send_ok(conn) do + send_resp(conn, 200, "OK") + end + test "closes connection after max_requests is reached", context do context = http_server(context, http_1_options: [max_requests: 3]) client = SimpleHTTP1Client.tcp_client(context) @@ -1002,29 +1025,6 @@ defmodule HTTP1RequestTest do raise "Shouldn't get here" end - test "handles the case where the declared content length is less than what is sent", - context do - output = - capture_log(fn -> - client = SimpleHTTP1Client.tcp_client(context) - - Transport.send( - client, - "POST /long_body HTTP/1.1\r\nhost: localhost\r\ncontent-length: 3\r\n\r\nABCDE" - ) - - assert {:ok, "400 Bad Request", _, ""} = SimpleHTTP1Client.recv_reply(client) - Process.sleep(100) - end) - - assert 
output =~ "(Bandit.HTTPError) Excess body read" - end - - def long_body(conn) do - Plug.Conn.read_body(conn) - raise "should not get here" - end - test "reading request body multiple times works as expected", context do response = Req.post!(context.req, url: "/multiple_body_read", body: "OK")