diff --git a/lib/bandit.ex b/lib/bandit.ex index 5760bbe6..c43daadd 100644 --- a/lib/bandit.ex +++ b/lib/bandit.ex @@ -20,7 +20,12 @@ defmodule Bandit do @typedoc """ Possible top-level options to configure a Bandit server - * `plug`: The Plug to use to handle connections. Can be specified as `MyPlug` or `{MyPlug, plug_opts}` + * `plug`: The Plug to use to handle connections. Can be specified as: + * `MyPlug` + * `{MyPlug, plug_opts}` + * `{&fun/2, plug_opts}` + * `&fun/2` + * `&fun/1` * `scheme`: One of `:http` or `:https`. If `:https` is specified, you will also need to specify valid `certfile` and `keyfile` values (or an equivalent value within `thousand_island_options.transport_options`). Defaults to `:http` @@ -219,8 +224,8 @@ defmodule Bandit do Keyword.get(arg, :websocket_options, []) |> validate_options(@websocket_keys, :websocket_options) - {plug_mod, _} = plug = plug(arg) - display_plug = Keyword.get(arg, :display_plug, plug_mod) + {plug_mod_or_fun, _} = plug = plug(arg) + display_plug = Keyword.get(arg, :display_plug, plug_mod_or_fun) startup_log = Keyword.get(arg, :startup_log, :info) {http_1_enabled, http_1_options} = Keyword.pop(http_1_options, :enabled, true) @@ -307,10 +312,11 @@ defmodule Bandit do |> Keyword.get(:plug) |> case do nil -> raise "A value is required for :plug" - {plug_fn, plug_options} when is_function(plug_fn, 2) -> {plug_fn, plug_options} - plug_fn when is_function(plug_fn) -> {plug_fn, []} - {plug, plug_options} when is_atom(plug) -> validate_plug(plug, plug_options) plug when is_atom(plug) -> validate_plug(plug, []) + {plug, plug_options} when is_atom(plug) -> validate_plug(plug, plug_options) + {plug_fn, plug_options} when is_function(plug_fn, 2) -> {plug_fn, plug_options} + plug_fn when is_function(plug_fn, 2) -> {plug_fn, []} + plug_fn when is_function(plug_fn, 1) -> {fn conn, [] -> plug_fn.(conn) end, []} other -> raise "Invalid value for plug: #{inspect(other)}" end end diff --git a/test/bandit/http1/request_test.exs b/test/bandit/http1/request_test.exs index eec81aef..3bea2cd8 100644 --- a/test/bandit/http1/request_test.exs +++ b/test/bandit/http1/request_test.exs @@ -11,7 +11,7 @@ defmodule HTTP1RequestTest do setup :req_http1_client describe "plug definitions" do - test "runs module plugs", context do + test "runs plug: module", context do response = Req.get!(context.req, url: "/hello_world") assert response.status == 200 assert response.body == "OK module" @@ -21,15 +21,31 @@ defmodule HTTP1RequestTest do send_resp(conn, 200, "OK module") end - test "runs function plugs", context do + test "runs plug: {&fun/2, options}", context do + context = + context + |> http_server(plug: {fn conn, string -> send_resp(conn, 200, string) end, "hello"}) + |> Enum.into(context) + + assert Req.get!(context.req, url: "/", base_url: context.base).body == "hello" + end + + test "runs plug: &fun/2", context do context = context |> http_server(plug: fn conn, _ -> send_resp(conn, 200, "OK function") end) |> Enum.into(context) - response = Req.get!(context.req, url: "/", base_url: context.base) - assert response.status == 200 - assert response.body == "OK function" + assert Req.get!(context.req, url: "/", base_url: context.base).body == "OK function" + end + + test "runs plug: &fun/1", context do + context = + context + |> http_server(plug: fn conn -> send_resp(conn, 200, "OK function") end) + |> Enum.into(context) + + assert Req.get!(context.req, url: "/", base_url: context.base).body == "OK function" end end diff --git a/test/bandit/http2/plug_test.exs b/test/bandit/http2/plug_test.exs index 25cf1ef6..c953908e 100644 --- a/test/bandit/http2/plug_test.exs +++ b/test/bandit/http2/plug_test.exs @@ -11,7 +11,7 @@ defmodule HTTP2PlugTest do setup :req_h2_client describe "plug definitions" do - test "runs module plugs", context do + test "runs plug: module", context do response = Req.get!(context.req, url: "/hello_world") assert response.status == 200 assert response.body == "OK module" @@ -21,7 +21,7 @@ defmodule HTTP2PlugTest do send_resp(conn, 200, "OK module") end - test "runs function plugs", context do + test "runs plug: &fun/2", context do context = context |> https_server(plug: fn conn, _ -> send_resp(conn, 200, "OK function") end)