diff --git a/src/cowboy_req.erl b/src/cowboy_req.erl index 8f0a04b52..cd6e159f9 100644 --- a/src/cowboy_req.erl +++ b/src/cowboy_req.erl @@ -67,6 +67,8 @@ -export([has_resp_header/2]). -export([has_resp_body/1]). -export([delete_resp_header/2]). +-export([set_cors_headers/2]). +-export([set_cors_preflight_headers/2]). -export([reply/2]). -export([reply/3]). -export([reply/4]). @@ -305,6 +307,8 @@ parse_header_fun(<<"accept">>) -> fun cow_http_hd:parse_accept/1; parse_header_fun(<<"accept-charset">>) -> fun cow_http_hd:parse_accept_charset/1; parse_header_fun(<<"accept-encoding">>) -> fun cow_http_hd:parse_accept_encoding/1; parse_header_fun(<<"accept-language">>) -> fun cow_http_hd:parse_accept_language/1; +parse_header_fun(<<"access-control-request-headers">>) -> fun cow_http_hd:parse_access_control_request_headers/1; +parse_header_fun(<<"access-control-request-method">>) -> fun cow_http_hd:parse_access_control_request_method/1; parse_header_fun(<<"authorization">>) -> fun cow_http_hd:parse_authorization/1; parse_header_fun(<<"connection">>) -> fun cow_http_hd:parse_connection/1; parse_header_fun(<<"content-length">>) -> fun cow_http_hd:parse_content_length/1; @@ -315,6 +319,7 @@ parse_header_fun(<<"if-match">>) -> fun cow_http_hd:parse_if_match/1; parse_header_fun(<<"if-modified-since">>) -> fun cow_http_hd:parse_if_modified_since/1; parse_header_fun(<<"if-none-match">>) -> fun cow_http_hd:parse_if_none_match/1; parse_header_fun(<<"if-unmodified-since">>) -> fun cow_http_hd:parse_if_unmodified_since/1; +parse_header_fun(<<"origin">>) -> fun cow_http_hd:parse_origin/1; parse_header_fun(<<"range">>) -> fun cow_http_hd:parse_range/1; parse_header_fun(<<"sec-websocket-extensions">>) -> fun cow_http_hd:parse_sec_websocket_extensions/1; parse_header_fun(<<"sec-websocket-protocol">>) -> fun cow_http_hd:parse_sec_websocket_protocol_req/1; @@ -666,6 +671,126 @@ delete_resp_header(Name, Req=#http_req{resp_headers=RespHeaders}) -> RespHeaders2 = lists:keydelete(Name, 1, RespHeaders), Req#http_req{resp_headers=RespHeaders2}. +-spec set_cors_headers(map(), Req) -> Req when Req :: req(). +set_cors_headers(M, Req) -> + try + AllowedOrigins = maps:get(origins, M, []), + Origin = + match_cors_origin( + %% Validating each origin in the list, picking up the first. + case parse_header(<<"origin">>, Req) of + undefined -> throw({bad_origin, undefined, AllowedOrigins}); + [H|T] -> _ = [match_cors_origin(Val, AllowedOrigins) || Val <- T], H; + L -> throw({bad_origin, L, AllowedOrigins}) + end, + AllowedOrigins), + + Req2 = set_cors_allow_credentials(maps:get(credentials, M, false), Origin, Req), + set_cors_exposed_headers(maps:get(exposed_headers, M, []), Req2) + catch throw:_Reason -> + Req + end. + +-spec set_cors_preflight_headers(map(), Req) -> Req when Req :: req(). +set_cors_preflight_headers(M, Req) -> + try + AllowedOrigins = maps:get(origins, M, []), + Origin = + match_cors_origin( + %% The Origin header can only contain a single origin as the user agent will not follow redirects. + case parse_header(<<"origin">>, Req) of + undefined -> throw({bad_origin, undefined, AllowedOrigins}); + [H] -> H; + L -> throw({bad_origin, L, AllowedOrigins}) + end, + AllowedOrigins), + Method = + match_cors_method( + parse_header(<<"access-control-request-method">>, Req), + maps:get(methods, M, [])), + Headers = + match_cors_headers( + parse_header(<<"access-control-request-headers">>, Req, []), + maps:get(headers, M, [])), + + Req2 = set_cors_allow_credentials(maps:get(credentials, M, false), Origin, Req), + Req3 = set_cors_max_age(maps:get(max_age, M, undefined), Req2), + Req4 = set_cors_allowed_methods([Method], Req3), + set_cors_allowed_headers(Headers, Req4) + catch throw:_Reason -> + Req + end. + +-spec set_cors_allow_credentials(boolean(), {binary(), binary(), 0..65535} | reference(), Req) -> Req when Req :: req(). +set_cors_allow_credentials(Credentials, Origin, Req) -> + case match_cors_credentials(Credentials, Origin) of + true -> + Req2 = set_resp_header(<<"access-control-allow-origin">>, cow_http_hd:access_control_allow_origin(Origin), Req), + set_resp_header(<<"access-control-allow-credentials">>, cow_http_hd:access_control_allow_credentials(), Req2); + _ -> + set_resp_header(<<"access-control-allow-origin">>, cow_http_hd:access_control_allow_origin(Origin), Req) + end. + +-spec set_cors_max_age(non_neg_integer() | undefined, Req) -> Req when Req :: req(). +set_cors_max_age(undefined, Req) -> + Req; +set_cors_max_age(Val, Req) -> + set_resp_header(<<"access-control-max-age">>, cow_http_hd:access_control_max_age(Val), Req). + +-spec set_cors_allowed_methods([binary()], Req) -> Req when Req :: req(). +set_cors_allowed_methods(L, Req) -> + set_resp_header(<<"access-control-allow-methods">>, cow_http_hd:access_control_allow_methods(L), Req). + +-spec set_cors_allowed_headers([binary()], Req) -> Req when Req :: req(). +set_cors_allowed_headers([], Req) -> + Req; +set_cors_allowed_headers(L, Req) -> + set_resp_header(<<"access-control-allow-headers">>, cow_http_hd:access_control_allow_headers(L), Req). + +-spec set_cors_exposed_headers([binary()], Req) -> Req when Req :: req(). +set_cors_exposed_headers([], Req) -> + Req; +set_cors_exposed_headers(L, Req) -> + set_resp_header(<<"access-control-expose-headers">>, cow_http_hd:access_control_expose_headers(L), Req). + +-spec match_cors_origin(Origin | reference(), [Origin] | Origin | '*') + -> Origin | '*' when Origin :: {binary(), binary(), 0..65535}. +match_cors_origin(Val, '*') when is_reference(Val) -> + '*'; +match_cors_origin(Val, '*') -> + Val; +match_cors_origin(Val, Val) -> + Val; +match_cors_origin(Val, AllowedOrigins) when is_list(AllowedOrigins) -> + case lists:member(Val, AllowedOrigins) of + true -> Val; + _ -> throw({nomatch_origin, Val, AllowedOrigins}) + end; +match_cors_origin(Val, AllowedOrigins) -> + throw({nomatch_origin, Val, AllowedOrigins}). + +-spec match_cors_method(binary() | undefined, [binary()]) -> binary(). +match_cors_method(undefined, Methods) -> + throw({bad_method, undefined, Methods}); +match_cors_method(Val, AllowedMethods) -> + case lists:member(Val, AllowedMethods) of + true -> Val; + _ -> throw({nomatch_method, Val, AllowedMethods}) + end. + +-spec match_cors_headers([binary()], [binary()]) -> [binary()]. +match_cors_headers(L, AllowedHeaders) -> + [case lists:member(Header, AllowedHeaders) of + false -> throw({nomatch_header, Header, AllowedHeaders}); + _ -> Header + end || Header <- L]. + +-spec match_cors_credentials(boolean(), {binary(), binary(), 0..65535} | reference() | '*') -> boolean(). +match_cors_credentials(true, '*') -> + throw({bad_credentials, true, '*'}); +match_cors_credentials(Val, _) -> + Val. + -spec reply(cowboy:http_status(), Req) -> Req when Req::req(). reply(Status, Req=#http_req{resp_body=Body}) -> reply(Status, [], Body, Req). diff --git a/test/cors_SUITE.erl b/test/cors_SUITE.erl new file mode 100644 index 000000000..9e3d04ccc --- /dev/null +++ b/test/cors_SUITE.erl @@ -0,0 +1,256 @@ +%% Copyright (c) 2016, Andrei Nesterov +%% +%% Permission to use, copy, modify, and/or distribute this software for any +%% purpose with or without fee is hereby granted, provided that the above +%% copyright notice and this permission notice appear in all copies. +%% +%% THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +%% WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +%% MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +%% ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +%% WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +%% ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +%% OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +-module(cors_SUITE). +-compile(export_all). + +-import(ct_helper, [config/2]). +-import(cowboy_test, [gun_open/1]). +-import(cowboy_test, [gun_open/2]). +-import(cowboy_test, [gun_down/1]). + +%% Definitions. +-define(ORIGIN_URI, <<"http://example.org">>). +-define(REQUEST_METHOD, <<"PUT">>). + +%% ct. + +all() -> + [ + {group, http}, + {group, https} + ]. + +groups() -> + Tests = ct_helper:all(?MODULE), + [ + {http, [parallel], Tests}, + {https, [parallel], Tests} + ]. + +init_per_group(Name = http, Config) -> + cowboy_test:init_http(Name, [ + {env, [{dispatch, init_dispatch(Config)}]} + ], Config); +init_per_group(Name = https, Config) -> + cowboy_test:init_https(Name, [ + {env, [{dispatch, init_dispatch(Config)}]} + ], Config). + +end_per_group(Name, _) -> + ok = cowboy:stop_listener(Name). + +%% Dispatch configuration. + +init_dispatch(_Config) -> + OriginsVal = {<<"http">>, <<"example.org">>, 80}, + OriginsAny = '*', + OriginsList = + [{<<"http">>, <<"example.com">>, 80}, + {<<"http">>, <<"example.org">>, 80}, + {<<"http">>, <<"example.org">>, 8080}], + Methods = [<<"GET">>, <<"PUT">>], + ExposedHeaders = Headers = [<<"h1">>, <<"h2">>, <<"h3">>], + MaxAge = 0, + + cowboy_router:compile([ + {"localhost", [ + {"/origins/val", cors_echo, + [{hs, #{origins => OriginsVal}}, + {phs, #{origins => OriginsVal, methods => Methods}}]}, + {"/origins/any", cors_echo, + [{hs, #{origins => OriginsAny}}, + {phs, #{origins => OriginsAny, methods => Methods}}]}, + {"/origins/list", cors_echo, + [{hs, #{origins => OriginsList}}, + {phs, #{origins => OriginsList, methods => Methods}}]}, + {"/credentials/false", cors_echo, + [{hs, #{origins => OriginsVal, credentials => false}}, + {phs, #{origins => OriginsVal, credentials => false, methods => Methods}}]}, + {"/credentials/true", cors_echo, + [{hs, #{origins => OriginsVal, credentials => true}}, + {phs, #{origins => OriginsVal, credentials => true, methods => Methods}}]}, + {"/credentials/true/origins/any", cors_echo, + [{hs, #{origins => OriginsAny, credentials => true}}, + {phs, #{origins => OriginsAny, credentials => true, methods => Methods}}]}, + {"/exposed_headers/undef", cors_echo, + [{hs, #{origins => OriginsVal}}, + {phs, #{origins => OriginsVal, methods => Methods}}]}, + {"/exposed_headers/list", cors_echo, + [{hs, #{origins => OriginsVal, exposed_headers => ExposedHeaders}}, + {phs, #{origins => OriginsVal, methods => Methods}}]}, + {"/max_age/undef", cors_echo, + [{hs, #{origins => OriginsVal}}, + {phs, #{origins => OriginsVal, methods => Methods}}]}, + {"/max_age/val", cors_echo, + [{hs, #{origins => OriginsVal}}, + {phs, #{origins => OriginsVal, max_age => MaxAge, methods => Methods}}]}, + {"/methods/list", cors_echo, + [{hs, #{origins => OriginsVal}}, + {phs, #{origins => OriginsVal, methods => Methods}}]}, + {"/headers/list", cors_echo, + [{hs, #{origins => OriginsVal}}, + {phs, #{origins => OriginsVal, headers => Headers, methods => Methods}}]} + ]} + ]). + +%% Convenience functions. + +do_request(Path, Headers, Config) -> + do_request(?REQUEST_METHOD, Path, Headers, Config). + +do_request(Method, Path, Headers, Config) -> + ConnPid = gun_open(Config), + Ref = gun:request(ConnPid, Method, Path, Headers), + {response, fin, 200, RespHeaders} = gun:await(ConnPid, Ref), + RespHeaders. + +do_preflight_request(Path, Headers, Config) -> + do_preflight_request(?REQUEST_METHOD, Path, Headers, Config). + +do_preflight_request(Method, Path, Headers, Config) -> + Headers2 = [{<<"access-control-request-method">>, Method}|Headers], + do_request(<<"OPTIONS">>, Path, Headers2, Config). + +do_find_header(Key, Headers) -> + case lists:keyfind(Key, 1, Headers) of + false -> error; + {_, Val} -> {ok, Val} + end. + +%% Tests. + +origins(Config) -> + OriginH = {<<"origin">>, ?ORIGIN_URI}, + OriginNoMatchH = {<<"origin">>, <<"null">>}, + Tests = + [%% Origin isn't presented + {"/origins/val", [], error}, + % Origin isn't allowed + {"/origins/val", [OriginNoMatchH], error}, + %% Single origin value is allowed + {"/origins/val", [OriginH], {ok, ?ORIGIN_URI}}, + %% Any origin is allowed + {"/origins/any", [OriginH], {ok, ?ORIGIN_URI}}, + %% Origin is presented in the allowed origins list + {"/origins/list", [OriginH], {ok, ?ORIGIN_URI}}, + %% Origin isn't presented in the allowed origins list + {"/origins/list", [OriginNoMatchH], error}], + + %% cors requests + [begin + RespHeaders = do_request(Path, Headers, Config), + MaybeOrigin = do_find_header(<<"access-control-allow-origin">>, RespHeaders) + end || {Path, Headers, MaybeOrigin} <- Tests], + + %% cors preflight requests + [begin + RespHeaders = do_preflight_request(Path, Headers, Config), + MaybeOrigin = do_find_header(<<"access-control-allow-origin">>, RespHeaders) + end || {Path, Headers, MaybeOrigin} <- Tests]. + +credentials(Config) -> + OriginH = {<<"origin">>, ?ORIGIN_URI}, + Tests = + [%% Credentials aren't supported + {"/credentials/false", [OriginH], error}, + %% Credentials are supported for this particular origin + {"/credentials/true", [OriginH], {ok, <<"true">>}}, + %% Credentials are supported for any origin + {"/credentials/true/origins/any", [OriginH], {ok, <<"true">>}}], + + %% cors requests + [begin + RespHeaders = do_request(Path, Headers, Config), + MaybeCredentials = do_find_header(<<"access-control-allow-credentials">>, RespHeaders), + {ok, ?ORIGIN_URI} = do_find_header(<<"access-control-allow-origin">>, RespHeaders) + end || {Path, Headers, MaybeCredentials} <- Tests], + + %% cors preflight requests + [begin + RespHeaders = do_preflight_request(Path, Headers, Config), + MaybeCredentials = do_find_header(<<"access-control-allow-credentials">>, RespHeaders), + {ok, ?ORIGIN_URI} = do_find_header(<<"access-control-allow-origin">>, RespHeaders) + end || {Path, Headers, MaybeCredentials} <- Tests]. + +exposed_headers(Config) -> + OriginH = {<<"origin">>, ?ORIGIN_URI}, + Tests = + [%% Exposed headers isn't set + {"/exposed_headers/undef", [OriginH], error}, + %% Exposed headers is set + {"/exposed_headers/list", [OriginH], {ok, <<"h1,h2,h3">>}}], + + %% cors requests + [begin + RespHeaders = do_request(Path, Headers, Config), + MaybeExposedHeaders = do_find_header(<<"access-control-expose-headers">>, RespHeaders), + {ok, ?ORIGIN_URI} = do_find_header(<<"access-control-allow-origin">>, RespHeaders) + end || {Path, Headers, MaybeExposedHeaders} <- Tests]. + +max_age(Config) -> + OriginH = {<<"origin">>, ?ORIGIN_URI}, + Tests = + [%% Max age isn't set + {"/max_age/undef", [OriginH], error}, + %% Max age is set + {"/max_age/val", [OriginH], {ok, <<"0">>}}], + + %% cors preflight requests + [begin + RespHeaders = do_preflight_request(Path, Headers, Config), + MaybeMaxAge = do_find_header(<<"access-control-max-age">>, RespHeaders), + {ok, ?ORIGIN_URI} = do_find_header(<<"access-control-allow-origin">>, RespHeaders) + end || {Path, Headers, MaybeMaxAge} <- Tests]. + +methods(Config) -> + OriginH = {<<"origin">>, ?ORIGIN_URI}, + MethodH = fun(Val) -> {<<"access-control-request-method">>, Val} end, + Tests = + [%% Method isn't presented + {"/methods/list", [OriginH], error, error}, + %% Method isn't allowed + {"/methods/list", [OriginH, MethodH(<<"PATCH">>)], error, error}, + %% Method is allowed + {"/methods/list", [OriginH, MethodH(?REQUEST_METHOD)], {ok, ?REQUEST_METHOD}, {ok, ?ORIGIN_URI}}], + + %% cors preflight requests + [begin + RespHeaders = do_request(<<"OPTIONS">>, Path, Headers, Config), + MaybeMethods = do_find_header(<<"access-control-allow-methods">>, RespHeaders), + MaybeOrigin = do_find_header(<<"access-control-allow-origin">>, RespHeaders) + end || {Path, Headers, MaybeMethods, MaybeOrigin} <- Tests]. + +headers(Config) -> + OriginH = {<<"origin">>, ?ORIGIN_URI}, + HeadersH = fun(Val) -> {<<"access-control-request-headers">>, Val} end, + Tests = + [%% Headers aren't presented + {"/headers/list", [OriginH], error, {ok, ?ORIGIN_URI}}, + %% Headers arent't allowed + {"/headers/list", [OriginH, HeadersH(<<"h8">>)], error, error}, + {"/headers/list", [OriginH, HeadersH(<<"h8,h9">>)], error, error}, + {"/headers/list", [OriginH, HeadersH(<<"h1,h9">>)], error, error}, + %% Headers are allowed + {"/headers/list", [OriginH, HeadersH(<<>>)], error, {ok, ?ORIGIN_URI}}, + {"/headers/list", [OriginH, HeadersH(<<"h1">>)], {ok, <<"h1">>}, {ok, ?ORIGIN_URI}}, + {"/headers/list", [OriginH, HeadersH(<<"h1,h2">>)], {ok, <<"h1,h2">>}, {ok, ?ORIGIN_URI}}], + + %% cors preflight requests + [begin + RespHeaders = do_preflight_request(Path, Headers, Config), + MaybeHeaders = do_find_header(<<"access-control-allow-headers">>, RespHeaders), + MaybeOrigin = do_find_header(<<"access-control-allow-origin">>, RespHeaders) + end || {Path, Headers, MaybeHeaders, MaybeOrigin} <- Tests]. + diff --git a/test/cors_SUITE_data/cors_echo.erl b/test/cors_SUITE_data/cors_echo.erl new file mode 100644 index 000000000..541c8c165 --- /dev/null +++ b/test/cors_SUITE_data/cors_echo.erl @@ -0,0 +1,16 @@ +%% Feel free to use, reuse and abuse the code in this file. + +-module(cors_echo). + +-export([init/2]). + +init(Req, Opts) -> + {_, Hs} = lists:keyfind(hs, 1, Opts), + {_, PHs} = lists:keyfind(phs, 1, Opts), + Req2 = + case cowboy_req:method(Req) of + <<"OPTIONS">> -> cowboy_req:set_cors_preflight_headers(PHs, Req); + _ -> cowboy_req:set_cors_headers(Hs, Req) + end, + {ok, cowboy_req:reply(200, Req2), Opts}. +