Skip to content

Commit

Permalink
Factor out WebSocket upgrade validation to its own module & refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mtrudel committed Sep 22, 2023
1 parent 8e8cc1e commit 708fa57
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 215 deletions.
2 changes: 1 addition & 1 deletion .credo.exs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
{Credo.Check.Refactor.FilterCount, []},
{Credo.Check.Refactor.FilterFilter, []},
{Credo.Check.Refactor.RejectReject, []},
{Credo.Check.Refactor.RedundantWithClauseResult, []},

#
## Warnings
Expand Down Expand Up @@ -125,6 +124,7 @@
{Credo.Check.Refactor.NegatedIsNil, []},
{Credo.Check.Refactor.PassAsyncInTestCases, []},
{Credo.Check.Refactor.PipeChainStart, []},
{Credo.Check.Refactor.RedundantWithClauseResult, []},
{Credo.Check.Refactor.RejectFilter, []},
{Credo.Check.Refactor.VariableRebinding, []},
{Credo.Check.Warning.LazyLogging, []},
Expand Down
18 changes: 9 additions & 9 deletions lib/bandit/websocket/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ The HTTP request containing the upgrade request is first passed to the user's
application as a standard Plug call. After inspecting the request and deeming it
a suitable upgrade candidate (via whatever policy the application dictates), the
user indicates a desire to upgrade the connection to a WebSocket by calling
`Plug.Conn.upgrade_adapter/3` (this is most commonly done by calling
`WebSockAdapter.upgrade/4`, which wraps this underlying call in
a server-agnostic manner). At the conclusion of the `Plug.call/2` callback,
`Bandit.Pipeline` will then attempy to upgrade the underlying connection. As
part of this upgrade process, `Bandit.DelegatingHandler` will switch the
Handler for the connection to be `Bandit.WebSocket.Handler`. This will cause any
future communication after the upgrade process to be handled directly by
Bandit's WebSocket stack.
`WebSockAdapter.upgrade/4`, which checks that the request is a valid WebSocket
upgrade request, and then calls `Plug.Conn.upgrade_adapter/3` to signal to
Bandit that the connection should be upgraded at the conclusion of the request.
At the conclusion of the `Plug.call/2` callback, `Bandit.Pipeline` will then
attempt to upgrade the underlying connection. As part of this upgrade process,
`Bandit.DelegatingHandler` will switch the Handler for the connection to be
`Bandit.WebSocket.Handler`. This will cause any future communication after the
upgrade process to be handled directly by Bandit's WebSocket stack.

## Process model

Expand All @@ -41,7 +41,7 @@ modeled by the `Bandit.WebSocket.Connection` struct and module.

All data subsequently received by the underlying [Thousand
Island](https://github.com/mtrudel/thousand_island) library will result in
a call to `Bandit.WebSocket.Handler.handle_data/3`, which will then attmept to
a call to `Bandit.WebSocket.Handler.handle_data/3`, which will then attempt to
parse the data into one or more WebSocket frames. Once a frame has been
constructed, it is them passed through to the configured `WebSock` handler by
way of the underlying `Bandit.WebSocket.Connection`.
Expand Down
44 changes: 1 addition & 43 deletions lib/bandit/websocket/handshake.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,10 @@ defmodule Bandit.WebSocket.Handshake do

@type extensions :: [{String.t(), [{String.t(), String.t() | true}]}]

@spec valid_upgrade?(Plug.Conn.t()) :: boolean()
def valid_upgrade?(%Plug.Conn{} = conn) do
validate_upgrade(conn) == :ok
end

@spec validate_upgrade(Plug.Conn.t()) :: :ok | {:error, String.t()}
defp validate_upgrade(conn) do
# Cases from RFC6455§4.2.1
with {:http_version, :"HTTP/1.1"} <- {:http_version, get_http_protocol(conn)},
{:method, "GET"} <- {:method, conn.method},
{:host_header, header} when header != [] <- {:host_header, get_req_header(conn, "host")},
{:upgrade_header, true} <-
{:upgrade_header, header_contains(conn, "upgrade", "websocket")},
{:connection_header, true} <-
{:connection_header, header_contains(conn, "connection", "upgrade")},
{:sec_websocket_key_header, true} <-
{:sec_websocket_key_header,
match?([<<_::binary>>], get_req_header(conn, "sec-websocket-key"))},
{:sec_websocket_version_header, ["13"]} <-
{:sec_websocket_version_header, get_req_header(conn, "sec-websocket-version")} do
:ok
else
{step, detail} ->
{:error, "WebSocket upgrade failed: error in #{step} check: #{inspect(detail)}"}
end
end

@spec handshake(Plug.Conn.t(), keyword(), keyword()) ::
{:ok, Plug.Conn.t(), Keyword.t()} | {:error, String.t()}
def handshake(%Plug.Conn{} = conn, opts, websocket_opts) do
with :ok <- validate_upgrade(conn) do
with :ok <- Bandit.WebSocket.UpgradeValidation.validate_upgrade(conn) do
do_handshake(conn, opts, websocket_opts)
end
end
Expand Down Expand Up @@ -126,19 +99,4 @@ defmodule Bandit.WebSocket.Handshake do

put_resp_header(conn, "sec-websocket-extensions", extensions)
end

@spec header_contains(Plug.Conn.t(), field :: String.t(), value :: String.t()) ::
true | binary()
defp header_contains(conn, field, value) do
downcase_value = String.downcase(value, :ascii)
header = get_req_header(conn, field)

header
|> Enum.flat_map(&Plug.Conn.Utils.list/1)
|> Enum.any?(&(String.downcase(&1, :ascii) == downcase_value))
|> case do
true -> true
false -> "Did not find '#{value}' in '#{header}'"
end
end
end
65 changes: 65 additions & 0 deletions lib/bandit/websocket/upgrade_validation.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
defmodule Bandit.WebSocket.UpgradeValidation do
@moduledoc false
# Provides validation of WebSocket upgrade requests as described in RFC6455§4.2

# Validates that the request satisfies the requirements to issue a WebSocket upgrade response.
# Validations are performed based on the clauses laid out in RFC6455§4.2
#
# This function does not actually perform an upgrade or change the connection in any way
#
# Returns `:ok` if the connection satisfies the requirements for a WebSocket upgrade, and
# `{:error, reason}` if not
#
@spec validate_upgrade(Plug.Conn.t()) :: :ok | {:error, String.t()}
def validate_upgrade(conn) do
case Plug.Conn.get_http_protocol(conn) do
:"HTTP/1.1" -> validate_upgrade_http1(conn)
other -> {:error, "HTTP version #{other} unsupported"}
end
end

# Validate the conn per RFC6455§4.2.1
defp validate_upgrade_http1(conn) do
with :ok <- assert_method(conn, "GET"),
:ok <- assert_header_nonempty(conn, "host"),
:ok <- assert_header_contains(conn, "connection", "upgrade"),
:ok <- assert_header_contains(conn, "upgrade", "websocket"),
:ok <- assert_header_nonempty(conn, "sec-websocket-key"),
:ok <- assert_header_equals(conn, "sec-websocket-version", "13") do
:ok
end
end

defp assert_method(conn, verb) do
case conn.method do
^verb -> :ok
other -> {:error, "HTTP method #{other} unsupported"}
end
end

defp assert_header_nonempty(conn, header) do
case Plug.Conn.get_req_header(conn, header) do
[] -> {:error, "'#{header}' header is absent"}
_ -> :ok
end
end

defp assert_header_equals(conn, header, expected) do
case Plug.Conn.get_req_header(conn, header) |> Enum.map(&String.downcase(&1, :ascii)) do
[^expected] -> :ok
value -> {:error, "'#{header}' header must equal '#{expected}', got #{inspect(value)}"}
end
end

defp assert_header_contains(conn, header, needle) do
haystack = Plug.Conn.get_req_header(conn, header)

haystack
|> Enum.flat_map(&Plug.Conn.Utils.list/1)
|> Enum.any?(&(String.downcase(&1, :ascii) == needle))
|> case do
true -> :ok
false -> {:error, "'#{header}' header must contain '#{needle}', got #{inspect(haystack)}"}
end
end
end
25 changes: 10 additions & 15 deletions test/bandit/http1/request_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -638,13 +638,12 @@ defmodule HTTP1RequestTest do
)

assert SimpleHTTP1Client.recv_reply(client)
~> {:ok, "400 Bad Request", list(),
"WebSocket upgrade failed: error in method check: \"POST\""}
~> {:ok, "400 Bad Request", list(), "HTTP method POST unsupported"}

Process.sleep(100)
end)

assert errors =~ "WebSocket upgrade failed: error in method check: \\\"POST\\\""
assert errors =~ "HTTP method POST unsupported"
end

test "returns a 400 and errors loudly in cases where an upgrade is indicated but upgrade header is incorrect",
Expand All @@ -668,13 +667,12 @@ defmodule HTTP1RequestTest do

assert SimpleHTTP1Client.recv_reply(client)
~> {:ok, "400 Bad Request", list(),
"WebSocket upgrade failed: error in upgrade_header check: \"Did not find 'websocket' in 'NOPE'\""}
"'upgrade' header must contain 'websocket', got [\"NOPE\"]"}

Process.sleep(100)
end)

assert errors =~
"WebSocket upgrade failed: error in upgrade_header check: \\\"Did not find 'websocket' in 'NOPE'\\\""
assert errors =~ "'upgrade' header must contain 'websocket', got [\\\"NOPE\\\"]"
end

test "returns a 400 and errors loudly in cases where an upgrade is indicated but connection header is incorrect",
Expand All @@ -698,13 +696,12 @@ defmodule HTTP1RequestTest do

assert SimpleHTTP1Client.recv_reply(client)
~> {:ok, "400 Bad Request", list(),
"WebSocket upgrade failed: error in connection_header check: \"Did not find 'upgrade' in 'NOPE'\""}
"'connection' header must contain 'upgrade', got [\"NOPE\"]"}

Process.sleep(100)
end)

assert errors =~
"WebSocket upgrade failed: error in connection_header check: \\\"Did not find 'upgrade' in 'NOPE'\\\""
assert errors =~ "'connection' header must contain 'upgrade', got [\\\"NOPE\\\"]"
end

test "returns a 400 and errors loudly in cases where an upgrade is indicated but key header is incorrect",
Expand All @@ -726,13 +723,12 @@ defmodule HTTP1RequestTest do
)

assert SimpleHTTP1Client.recv_reply(client)
~> {:ok, "400 Bad Request", list(),
"WebSocket upgrade failed: error in sec_websocket_key_header check: false"}
~> {:ok, "400 Bad Request", list(), "'sec-websocket-key' header is absent"}

Process.sleep(100)
end)

assert errors =~ "WebSocket upgrade failed: error in sec_websocket_key_header check: false"
assert errors =~ "'sec-websocket-key' header is absent"
end

test "returns a 400 and errors loudly in cases where an upgrade is indicated but version header is incorrect",
Expand All @@ -756,13 +752,12 @@ defmodule HTTP1RequestTest do

assert SimpleHTTP1Client.recv_reply(client)
~> {:ok, "400 Bad Request", list(),
"WebSocket upgrade failed: error in sec_websocket_version_header check: [\"99\"]"}
"'sec-websocket-version' header must equal '13', got [\"99\"]"}

Process.sleep(100)
end)

assert errors =~
"WebSocket upgrade failed: error in sec_websocket_version_header check: [\\\"99\\\"]"
assert errors =~ "'sec-websocket-version' header must equal '13', got [\\\"99\\\"]"
end

test "returns a 400 and errors loudly if websocket support is not enabled", context do
Expand Down
5 changes: 1 addition & 4 deletions test/bandit/websocket/autobahn_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ defmodule WebsocketAutobahnTest do

@impl Plug
def call(conn, _opts) do
case Bandit.WebSocket.Handshake.valid_upgrade?(conn) do
true -> Plug.Conn.upgrade_adapter(conn, :websocket, {EchoWebSock, :ok, compress: true})
false -> Plug.Conn.send_resp(conn, 204, <<>>)
end
Plug.Conn.upgrade_adapter(conn, :websocket, {EchoWebSock, :ok, compress: true})
end

@tag capture_log: true
Expand Down
Loading

0 comments on commit 708fa57

Please sign in to comment.