Skip to content

Commit

Permalink
Move tracking of content-length into stream process
Browse files Browse the repository at this point in the history
  • Loading branch information
mtrudel committed Jan 8, 2024
1 parent 9fa5d65 commit d28019a
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 69 deletions.
62 changes: 43 additions & 19 deletions lib/bandit/http2/adapter.ex
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ defmodule Bandit.HTTP2.Adapter do
end_stream: false,
method: nil,
content_encoding: nil,
pending_content_length: nil,
metrics: %{},
opts: []

Expand All @@ -23,6 +24,7 @@ defmodule Bandit.HTTP2.Adapter do
end_stream: boolean(),
method: Plug.Conn.method() | nil,
content_encoding: String.t() | nil,
pending_content_length: non_neg_integer() | nil,
metrics: map(),
opts: keyword()
}
Expand Down Expand Up @@ -75,34 +77,56 @@ defmodule Bandit.HTTP2.Adapter do
if remaining_length >= 0 do
do_read_req_body(adapter, timeout, remaining_length, acc)
else
bytes_read = IO.iodata_length(acc)

metrics =
adapter.metrics
|> Map.update(:req_body_bytes, bytes_read, &(&1 + bytes_read))

{:more, wrap_req_body(acc), %{adapter | metrics: metrics}}
return_more(acc, adapter)
end

:end_stream ->
bytes_read = IO.iodata_length(acc)

metrics =
adapter.metrics
|> Map.update(:req_body_bytes, bytes_read, &(&1 + bytes_read))
|> Map.put(:req_body_end_time, Bandit.Telemetry.monotonic_time())
pending_content_length =
case adapter.pending_content_length do
nil -> nil
pending_content_length -> pending_content_length - bytes_read
end

{:ok, wrap_req_body(acc), %{adapter | end_stream: true, metrics: metrics}}
if pending_content_length in [nil, 0] do
metrics =
adapter.metrics
|> Map.update(:req_body_bytes, bytes_read, &(&1 + bytes_read))
|> Map.put(:req_body_end_time, Bandit.Telemetry.monotonic_time())

{:ok, wrap_req_body(acc),
%{
adapter
| end_stream: true,
pending_content_length: pending_content_length,
metrics: metrics
}}
else
raise Bandit.HTTP2.Stream.StreamError,
message: "Received end of stream with #{pending_content_length} byte(s) pending",
method: adapter.method
end
after
timeout ->
bytes_read = IO.iodata_length(acc)
timeout -> return_more(acc, adapter)
end
end

metrics =
adapter.metrics
|> Map.update(:req_body_bytes, bytes_read, &(&1 + bytes_read))
defp return_more(data, adapter) do
bytes_read = IO.iodata_length(data)

{:more, wrap_req_body(acc), %{adapter | metrics: metrics}}
end
pending_content_length =
case adapter.pending_content_length do
nil -> nil
pending_content_length -> pending_content_length - bytes_read
end

metrics =
adapter.metrics
|> Map.update(:req_body_bytes, bytes_read, &(&1 + bytes_read))

{:more, wrap_req_body(data),
%{adapter | metrics: metrics, pending_content_length: pending_content_length}}
end

defp wrap_req_body(data) do
Expand Down
9 changes: 0 additions & 9 deletions lib/bandit/http2/connection.ex
Original file line number Diff line number Diff line change
Expand Up @@ -255,15 +255,6 @@ defmodule Bandit.HTTP2.Connection do
{:error, {:connection, error_code, error_message}} ->
shutdown_connection(error_code, error_message, socket, connection)

{:error, {:stream, stream_id, error_code, error_message}} ->
# If we're erroring out on a stream error, RFC9113§6.9 stipulates that we MUST take into
# account the sizes of errored frames. As such, ensure that we update our connection
# window to reflect that space taken up by this frame. We needn't worry about the stream's
# window since we're shutting it down anyway

connection = %{connection | recv_window_size: connection_recv_window_size}
handle_stream_error(stream_id, error_code, error_message, socket, connection)

{:error, error} ->
shutdown_connection(Errors.internal_error(), error, socket, connection)
end
Expand Down
45 changes: 6 additions & 39 deletions lib/bandit/http2/stream.ex
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ defmodule Bandit.HTTP2.Stream do
pid: nil,
recv_window_size: nil,
send_window_size: nil,
pending_content_length: nil,
span: nil

defmodule StreamError, do: defexception([:message, :method, :request_target, :status])
Expand All @@ -39,7 +38,6 @@ defmodule Bandit.HTTP2.Stream do
pid: pid() | nil,
recv_window_size: non_neg_integer(),
send_window_size: non_neg_integer(),
pending_content_length: non_neg_integer() | nil,
span: Bandit.Telemetry.t()
}

Expand Down Expand Up @@ -81,7 +79,6 @@ defmodule Bandit.HTTP2.Stream do
) do
with :ok <- stream_id_is_valid_client(stream.stream_id),
span <- start_span(connection_span, stream.stream_id),
{:ok, content_length} <- get_content_length(headers, stream.stream_id),
content_encoding <- negotiate_content_encoding(headers, opts),
req <-
Bandit.HTTP2.Adapter.init(
Expand All @@ -92,8 +89,7 @@ defmodule Bandit.HTTP2.Stream do
opts
),
{:ok, pid} <- StreamProcess.start_link(req, transport_info, headers, plug, span) do
{:ok,
%{stream | state: :open, pid: pid, pending_content_length: content_length, span: span}}
{:ok, %{stream | state: :open, pid: pid, span: span}}
else
:ignore -> {:error, "Unable to start stream process"}
other -> other
Expand Down Expand Up @@ -128,14 +124,6 @@ defmodule Bandit.HTTP2.Stream do
})
end

# RFC9113§8.1.1 - content length must be valid
defp get_content_length(headers, stream_id) do
case Bandit.Headers.get_content_length(headers) do
{:ok, content_length} -> {:ok, content_length}
{:error, reason} -> {:error, {:stream, stream_id, Errors.protocol_error(), reason}}
end
end

defp negotiate_content_encoding(headers, opts) do
Bandit.Compression.negotiate_content_encoding(
Bandit.Headers.get_header(headers, "accept-encoding"),
Expand All @@ -160,15 +148,7 @@ defmodule Bandit.HTTP2.Stream do
{new_window, increment} =
FlowControl.compute_recv_window(stream.recv_window_size, byte_size(data))

pending_content_length =
case stream.pending_content_length do
nil -> nil
pending_content_length -> pending_content_length - byte_size(data)
end

{:ok,
%{stream | recv_window_size: new_window, pending_content_length: pending_content_length},
increment}
{:ok, %{stream | recv_window_size: new_window}, increment}
end

def recv_data(%__MODULE__{} = stream, _data) do
Expand Down Expand Up @@ -205,17 +185,13 @@ defmodule Bandit.HTTP2.Stream do
@spec recv_end_of_stream(t(), boolean()) ::
{:ok, t()} | {:error, Connection.error()}
def recv_end_of_stream(%__MODULE__{state: :open} = stream, true) do
with :ok <- verify_content_length(stream) do
StreamProcess.recv_end_of_stream(stream.pid)
{:ok, %{stream | state: :remote_closed}}
end
StreamProcess.recv_end_of_stream(stream.pid)
{:ok, %{stream | state: :remote_closed}}
end

def recv_end_of_stream(%__MODULE__{state: :local_closed} = stream, true) do
with :ok <- verify_content_length(stream) do
StreamProcess.recv_end_of_stream(stream.pid)
{:ok, %{stream | state: :closed, pid: nil}}
end
StreamProcess.recv_end_of_stream(stream.pid)
{:ok, %{stream | state: :closed, pid: nil}}
end

def recv_end_of_stream(%__MODULE__{}, true) do
Expand All @@ -226,15 +202,6 @@ defmodule Bandit.HTTP2.Stream do
{:ok, stream}
end

defp verify_content_length(%__MODULE__{pending_content_length: nil}), do: :ok
defp verify_content_length(%__MODULE__{pending_content_length: 0}), do: :ok

defp verify_content_length(%__MODULE__{} = stream) do
{:error,
{:stream, stream.stream_id, Errors.protocol_error(),
"Received end of stream with #{stream.pending_content_length} byte(s) pending"}}
end

@spec owner?(t(), pid()) :: :ok | {:error, :not_owner}
def owner?(%__MODULE__{pid: pid}, pid), do: :ok
def owner?(%__MODULE__{}, _pid), do: {:error, :not_owner}
Expand Down
2 changes: 2 additions & 0 deletions lib/bandit/http2/stream_process.ex
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ defmodule Bandit.HTTP2.StreamProcess do
:ok <- headers_all_lowercase(headers),
:ok <- no_connection_headers(headers),
:ok <- valid_te_header(headers),
{:ok, content_length} <- Bandit.Headers.get_content_length(headers),
req <- %{req | pending_content_length: content_length},
headers <- combine_cookie_crumbs(headers),
req <- Bandit.HTTP2.Adapter.add_end_header_metric(req),
adapter <- {Bandit.HTTP2.Adapter, req},
Expand Down
2 changes: 0 additions & 2 deletions test/bandit/http2/protocol_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -848,8 +848,6 @@ defmodule HTTP2ProtocolTest do
]

SimpleH2Client.send_headers(socket, 1, false, headers)
SimpleH2Client.send_body(socket, 1, true, String.duplicate("a", 8_000))

assert SimpleH2Client.recv_rst_stream(socket) == {:ok, 1, 1}
end

Expand Down

0 comments on commit d28019a

Please sign in to comment.