Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use protocol default port in the event that no port is provided in host #228

Merged
merged 1 commit into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 7 additions & 22 deletions lib/bandit/pipeline.ex
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ defmodule Bandit.Pipeline do
defp build_conn({mod, req}, transport_info, method, request_target, headers) do
with {:ok, scheme} <- determine_scheme(transport_info, request_target),
version <- mod.get_http_protocol(req),
{:ok, host, port} <-
determine_host_and_port(transport_info, version, request_target, headers),
{:ok, host, port} <- determine_host_and_port(scheme, version, request_target, headers),
{path, query} <- determine_path_and_query(request_target) do
uri = %URI{scheme: scheme, host: host, port: port, path: path, query: query}
%Bandit.TransportInfo{peername: {remote_ip, _port}} = transport_info
Expand All @@ -55,25 +54,20 @@ defmodule Bandit.Pipeline do
end

@spec determine_host_and_port(
Bandit.TransportInfo.t(),
scheme :: binary(),
version :: atom(),
request_target(),
Plug.Conn.headers()
) ::
{:ok, Plug.Conn.host(), Plug.Conn.port_number()} | {:error, String.t()}
defp determine_host_and_port(
%Bandit.TransportInfo{sockname: local_info},
version,
{_, nil, nil, _},
headers
) do
defp determine_host_and_port(scheme, version, {_, nil, nil, _}, headers) do
with host_header when is_binary(host_header) <- Bandit.Headers.get_header(headers, "host"),
{:ok, host, port} <- Bandit.Headers.parse_hostlike_header(host_header) do
{:ok, host, port || determine_local_port(local_info)}
{:ok, host, port || URI.default_port(scheme)}
else
nil ->
case version do
:"HTTP/1.0" -> {:ok, "", determine_local_port(local_info)}
:"HTTP/1.0" -> {:ok, "", URI.default_port(scheme)}
_ -> {:error, "No host header"}
end

Expand All @@ -82,17 +76,8 @@ defmodule Bandit.Pipeline do
end
end

defp determine_host_and_port(
%Bandit.TransportInfo{sockname: local_info},
_version,
{_, host, port, _},
_headers
),
do: {:ok, to_string(host), port || determine_local_port(local_info)}

@spec determine_local_port(ThousandIsland.Transport.socket_info()) :: Plug.Conn.port_number()
defp determine_local_port({family, _}) when family in [:local, :unspec, :undefined], do: 0
defp determine_local_port({_ip, port}), do: port
defp determine_host_and_port(scheme, _version, {_, host, port, _}, _headers),
do: {:ok, to_string(host), port || URI.default_port(scheme)}

@spec determine_path_and_query(request_target()) :: {String.t(), nil | String.t()}
defp determine_path_and_query({_, _, _, :*}), do: {"*", nil}
Expand Down
19 changes: 6 additions & 13 deletions test/bandit/http1/request_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,18 @@ defmodule HTTP1RequestTest do
assert {:ok, "400 Bad Request", _headers, _body} = SimpleHTTP1Client.recv_reply(client)
end

test "derives port from underlying transport if no port specified in host header", context do
test "derives port from schema default if no port specified in host header", context do
client = SimpleHTTP1Client.tcp_client(context)
SimpleHTTP1Client.send(client, "GET", "/echo_components", ["host: banana"])
assert {:ok, "200 OK", _headers, body} = SimpleHTTP1Client.recv_reply(client)
assert Jason.decode!(body)["port"] == context[:port]
assert Jason.decode!(body)["port"] == 80
end

test "derives port from underlying transport if no host header set in HTTP/1.0", context do
test "derives port from schema default if no host header set in HTTP/1.0", context do
client = SimpleHTTP1Client.tcp_client(context)
SimpleHTTP1Client.send(client, "GET", "/echo_components", [], "1.0")
assert {:ok, "200 OK", _headers, body} = SimpleHTTP1Client.recv_reply(client)
assert Jason.decode!(body)["port"] == context[:port]
assert Jason.decode!(body)["port"] == 80
end

test "sets path and query string properly when no query string is present", context do
Expand Down Expand Up @@ -193,13 +193,6 @@ defmodule HTTP1RequestTest do
end

describe "absolute-form request target (RFC9112§3.2.2)" do
test "derives scheme from underlying transport", context do
client = SimpleHTTP1Client.tcp_client(context)
SimpleHTTP1Client.send(client, "GET", "http://banana/echo_components")
assert {:ok, "200 OK", _headers, body} = SimpleHTTP1Client.recv_reply(client)
assert Jason.decode!(body)["scheme"] == "http"
end

@tag capture_log: true
test "uses request-line scheme even if it does not match the transport", context do
client = SimpleHTTP1Client.tcp_client(context)
Expand Down Expand Up @@ -252,11 +245,11 @@ defmodule HTTP1RequestTest do
assert Jason.decode!(body)["port"] == 1234
end

test "derives port from underlying transport if no port specified in the URI", context do
test "derives port from schema default if no port specified in the URI", context do
client = SimpleHTTP1Client.tcp_client(context)
SimpleHTTP1Client.send(client, "GET", "http://banana/echo_components", ["host: banana"])
assert {:ok, "200 OK", _headers, body} = SimpleHTTP1Client.recv_reply(client)
assert Jason.decode!(body)["port"] == context[:port]
assert Jason.decode!(body)["port"] == 80
end

test "sets path and query string properly when no query string is present", context do
Expand Down
8 changes: 4 additions & 4 deletions test/bandit/http2/protocol_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1982,7 +1982,7 @@ defmodule HTTP2ProtocolTest do
assert SimpleH2Client.recv_rst_stream(socket) == {:ok, 1, 1}
end

test "derives port from underlying transport if no port specified in host header", context do
test "derives port from schema default if no port specified in host header", context do
socket = SimpleH2Client.setup_connection(context)

headers = [
Expand All @@ -1996,7 +1996,7 @@ defmodule HTTP2ProtocolTest do

assert SimpleH2Client.successful_response?(socket, 1, false)
{:ok, 1, true, body} = SimpleH2Client.recv_body(socket)
assert Jason.decode!(body)["port"] == context.port
assert Jason.decode!(body)["port"] == 443
end

test "sets path and query string properly when no query string is present", context do
Expand Down Expand Up @@ -2218,7 +2218,7 @@ defmodule HTTP2ProtocolTest do
assert Jason.decode!(body)["port"] == 1234
end

test "derives port from underlying transport if no port specified in host header", context do
test "derives port from schema default if no port specified in host header", context do
socket = SimpleH2Client.setup_connection(context)

headers = [
Expand All @@ -2232,7 +2232,7 @@ defmodule HTTP2ProtocolTest do

assert SimpleH2Client.successful_response?(socket, 1, false)
{:ok, 1, true, body} = SimpleH2Client.recv_body(socket)
assert Jason.decode!(body)["port"] == context.port
assert Jason.decode!(body)["port"] == 443
end

test "sets path and query string properly when no query string is present", context do
Expand Down