diff --git a/lib/bandit/websocket/connection.ex b/lib/bandit/websocket/connection.ex index 8b07a626..628f21ed 100644 --- a/lib/bandit/websocket/connection.ex +++ b/lib/bandit/websocket/connection.ex @@ -64,12 +64,12 @@ defmodule Bandit.WebSocket.Connection do when not is_nil(fragment_frame) do case frame do %Frame.Continuation{fin: true} = frame -> - data = connection.fragment_frame.data <> frame.data + data = IO.iodata_to_binary([connection.fragment_frame.data | frame.data]) frame = %{connection.fragment_frame | fin: true, data: data} handle_frame(frame, socket, %{connection | fragment_frame: nil}) %Frame.Continuation{fin: false} = frame -> - data = connection.fragment_frame.data <> frame.data + data = [connection.fragment_frame.data | frame.data] frame = %{connection.fragment_frame | fin: true, data: data} {:continue, %{connection | fragment_frame: frame}} diff --git a/lib/bandit/websocket/frame.ex b/lib/bandit/websocket/frame.ex index a4938003..157bd01d 100644 --- a/lib/bandit/websocket/frame.ex +++ b/lib/bandit/websocket/frame.ex @@ -95,18 +95,39 @@ defmodule Bandit.WebSocket.Frame do defp mask_and_length(length) when length <= 65_535, do: <<0::1, 126::7, length::16>> defp mask_and_length(length), do: <<0::1, 127::7, length::64>> + # Masking is done @mask_size bits at a time until there is less than that number of bits left. + # We then go 32 bits at a time until there is less than 32 bits left. We then go 8 bits at + # a time. This yields some significant perforamnce gains for only marginally more complexity + @mask_size 512 + # Note that masking is an involution, so we don't need a separate unmask function - def mask(payload, mask, acc \\ <<>>) + def mask(payload, mask) when bit_size(payload) >= @mask_size do + payload + |> do_mask(String.duplicate(<>, div(@mask_size, 32)), []) + |> IO.iodata_to_binary() + end - def mask(payload, mask, acc) when is_integer(mask), do: mask(payload, <>, acc) + def mask(payload, mask) do + payload + |> do_mask(<>, []) + |> IO.iodata_to_binary() + end + + defp do_mask( + <>, + <> = mask, + acc + ) do + do_mask(rest, mask, [acc, <>]) + end - def mask(<>, <>, acc) do - mask(rest, mask, acc <> <>) + defp do_mask(<>, <> = mask, acc) do + do_mask(rest, mask, [acc, <>]) end - def mask(<>, <>, acc) do - mask(rest, <>, acc <> <>) + defp do_mask(<>, <>, acc) do + do_mask(rest, <>, [acc, <>]) end - def mask(<<>>, _mask, acc), do: acc + defp do_mask(<<>>, _mask, acc), do: acc end diff --git a/test/bandit/websocket/frame_deserialization_test.exs b/test/bandit/websocket/frame_deserialization_test.exs index 043e97af..f94c814e 100644 --- a/test/bandit/websocket/frame_deserialization_test.exs +++ b/test/bandit/websocket/frame_deserialization_test.exs @@ -20,6 +20,26 @@ defmodule WebSocketFrameDeserializationTest do end describe "frame size" do + test "parses 2 byte frames" do + payload = String.duplicate("a", 2) + masked_payload = Frame.mask(payload, 1234) + + frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 2::7, 1234::32, masked_payload::binary>> + + assert Frame.deserialize(frame) == + {{:ok, %Frame.Text{fin: true, compressed: false, data: payload}}, <<>>} + end + + test "parses 10 byte frames" do + payload = String.duplicate("a", 10) + masked_payload = Frame.mask(payload, 1234) + + frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 10::7, 1234::32, masked_payload::binary>> + + assert Frame.deserialize(frame) == + {{:ok, %Frame.Text{fin: true, compressed: false, data: payload}}, <<>>} + end + test "parses frames up to 125 bytes" do payload = String.duplicate("a", 125) masked_payload = Frame.mask(payload, 1234) diff --git a/test/support/simple_http1_client.ex b/test/support/simple_http1_client.ex index 910c7024..8b36aa77 100644 --- a/test/support/simple_http1_client.ex +++ b/test/support/simple_http1_client.ex @@ -2,7 +2,8 @@ defmodule SimpleHTTP1Client do @moduledoc false def tcp_client(context) do - {:ok, socket} = :gen_tcp.connect(~c"localhost", context[:port], active: false, mode: :binary) + {:ok, socket} = + :gen_tcp.connect(~c"localhost", context[:port], active: false, mode: :binary, nodelay: true) socket end @@ -12,6 +13,7 @@ defmodule SimpleHTTP1Client do :ssl.connect(~c"localhost", context[:port], active: false, mode: :binary, + nodelay: true, verify: :verify_peer, cacertfile: Path.join(__DIR__, "../support/ca.pem"), alpn_advertised_protocols: protocols