diff --git a/big_tests/tests/sasl2_SUITE.erl b/big_tests/tests/sasl2_SUITE.erl index 8a604affc28..2cd7d15eb95 100644 --- a/big_tests/tests/sasl2_SUITE.erl +++ b/big_tests/tests/sasl2_SUITE.erl @@ -22,7 +22,8 @@ groups() -> {all_tests, [parallel], [ {group, basic}, - {group, scram} + {group, scram}, + {group, stream_management} ]}, {basic, [parallel], [ @@ -40,6 +41,14 @@ groups() -> authenticate_with_scram_bad_abort, authenticate_with_scram_bad_response, authenticate_with_scram + ]}, + {stream_management, [parallel], + [ + sm_failure_missing_previd_does_not_stop_sasl2, + sm_failure_invalid_h_does_not_stop_sasl2, + sm_failure_exceeding_h_does_not_stop_sasl2, + sm_failure_unknown_smid_does_not_stop_sasl2, + sm_is_bound_at_sasl2_success ]} ]. @@ -78,7 +87,9 @@ end_per_testcase(Name, Config) -> load_sasl_extensible(Config) -> HostType = domain_helper:host_type(), Config1 = dynamic_modules:save_modules(HostType, Config), - dynamic_modules:ensure_modules(HostType, [{mod_sasl2, #{}}]), + Modules = [{mod_sasl2, config_parser_helper:default_mod_config(mod_sasl2)}, + {mod_stream_management, config_parser_helper:mod_config(mod_stream_management, #{ack_freq => never})}], + dynamic_modules:ensure_modules(HostType, Modules), Config1. %%-------------------------------------------------------------------- @@ -100,7 +111,7 @@ server_announces_sasl2_with_some_mechanism(Config) -> ?assertNotEqual([], Mechs). authenticate_stanza_has_invalid_mechanism(Config) -> - Steps = [connect_tls_user, start_stream_get_features, send_invalid_authenticate_stanza], + Steps = [connect_tls_user, start_stream_get_features, send_invalid_mech_auth_stanza], #{answer := Response} = sasl2_helper:apply_steps(Steps, Config), ?assertMatch(#xmlel{name = <<"failure">>, attrs = [{<<"xmlns">>, ?NS_SASL_2}]}, Response). @@ -159,3 +170,43 @@ authenticate_again_results_in_stream_error(Config) -> plain_authentication, receive_features, plain_authentication], #{answer := Response} = sasl2_helper:apply_steps(Steps, Config), escalus:assert(is_stream_error, [<<"policy-violation">>, <<>>], Response). + +sm_failure_missing_previd_does_not_stop_sasl2(Config) -> + Steps = [buffer_messages_and_die, connect_tls_user, start_stream_get_features, + auth_with_resumption_missing_previd, receive_features], + #{answer := Success} = sasl2_helper:apply_steps(Steps, Config), + ?assertMatch(#xmlel{name = <<"success">>, attrs = [{<<"xmlns">>, ?NS_SASL_2}]}, Success), + Resumed = exml_query:path(Success, [{element_with_ns, <<"failed">>, ?NS_STREAM_MGNT_3}]), + escalus:assert(is_sm_failed, [<<"bad-request">>], Resumed). + +sm_failure_invalid_h_does_not_stop_sasl2(Config) -> + Steps = [buffer_messages_and_die, connect_tls_user, start_stream_get_features, + auth_with_resumption_invalid_h, receive_features], + #{answer := Success} = sasl2_helper:apply_steps(Steps, Config), + ?assertMatch(#xmlel{name = <<"success">>, attrs = [{<<"xmlns">>, ?NS_SASL_2}]}, Success), + Resumed = exml_query:path(Success, [{element_with_ns, <<"failed">>, ?NS_STREAM_MGNT_3}]), + escalus:assert(is_sm_failed, [<<"bad-request">>], Resumed). + +sm_failure_exceeding_h_does_not_stop_sasl2(Config) -> + Steps = [buffer_messages_and_die, connect_tls_user, start_stream_get_features, + auth_with_resumption_exceeding_h, receive_features], + #{answer := Success} = sasl2_helper:apply_steps(Steps, Config), + ?assertMatch(#xmlel{name = <<"success">>, attrs = [{<<"xmlns">>, ?NS_SASL_2}]}, Success), + Resumed = exml_query:path(Success, [{element_with_ns, <<"failed">>, ?NS_STREAM_MGNT_3}]), + escalus:assert(is_sm_failed, [<<"bad-request">>], Resumed). + +sm_failure_unknown_smid_does_not_stop_sasl2(Config) -> + Steps = [buffer_messages_and_die, connect_tls_user, start_stream_get_features, + auth_with_resumption_unknown_smid, receive_features], + #{answer := Success} = sasl2_helper:apply_steps(Steps, Config), + ?assertMatch(#xmlel{name = <<"success">>, attrs = [{<<"xmlns">>, ?NS_SASL_2}]}, Success), + Resumed = exml_query:path(Success, [{element_with_ns, <<"failed">>, ?NS_STREAM_MGNT_3}]), + escalus:assert(is_sm_failed, [<<"item-not-found">>], Resumed). + +sm_is_bound_at_sasl2_success(Config) -> + Steps = [buffer_messages_and_die, connect_tls_user, start_stream_get_features, + auth_with_resumption, has_no_more_stanzas], + #{answer := Success, smid := SMID} = sasl2_helper:apply_steps(Steps, Config), + ?assertMatch(#xmlel{name = <<"success">>, attrs = [{<<"xmlns">>, ?NS_SASL_2}]}, Success), + Resumed = exml_query:path(Success, [{element_with_ns, <<"resumed">>, ?NS_STREAM_MGNT_3}]), + ?assert(escalus_pred:is_sm_resumed(SMID, Resumed)). diff --git a/big_tests/tests/sasl2_helper.erl b/big_tests/tests/sasl2_helper.erl index eebbfa93105..11b450920c2 100644 --- a/big_tests/tests/sasl2_helper.erl +++ b/big_tests/tests/sasl2_helper.erl @@ -3,6 +3,7 @@ -include_lib("exml/include/exml.hrl"). -include_lib("escalus/include/escalus.hrl"). +-include_lib("escalus/include/escalus_xmlns.hrl"). -define(NS_SASL_2, <<"urn:xmpp:sasl:2">>). @@ -42,7 +43,7 @@ start_stream_get_features(_Config, Client, Data) -> Features = escalus_connection:get_stanza(Client1, wait_for_features), {Client1, Data#{features => Features}}. -send_invalid_authenticate_stanza(_Config, Client, Data) -> +send_invalid_mech_auth_stanza(_Config, Client, Data) -> Authenticate = #xmlel{name = <<"authenticate">>, attrs = [{<<"xmlns">>, ?NS_SASL_2}, {<<"mechanism">>, <<"invalid-non-existent-mechanism">>}]}, @@ -50,6 +51,14 @@ send_invalid_authenticate_stanza(_Config, Client, Data) -> Answer = escalus_client:wait_for_stanza(Client), {Client, Data#{answer => Answer}}. +send_invalid_ns_auth_stanza(_Config, Client, Data) -> + Authenticate = #xmlel{name = <<"authenticate">>, + attrs = [{<<"xmlns">>, <<"bad-namespace">>}, + {<<"mechanism">>, <<"PLAIN">>}]}, + escalus:send(Client, Authenticate), + Answer = escalus_client:wait_for_stanza(Client), + {Client, Data#{answer => Answer}}. + send_bad_user_agent(_Config, Client, Data) -> InitialResponse = initial_response_elem(<<"some-random-payload">>), Agent = bad_user_agent_elem(), @@ -58,6 +67,37 @@ send_bad_user_agent(_Config, Client, Data) -> Answer = escalus_client:wait_for_stanza(Client), {Client, Data#{answer => Answer}}. +auth_with_resumption(Config, Client, #{smid := SMID, texts := Texts} = Data) -> + Resume = escalus_stanza:resume(SMID, 1), + {Client1, Data1} = plain_auth(Config, Client, Data, [Resume]), + Msgs = sm_helper:wait_for_messages(Client, Texts), + {Client1, Data1#{sm_storage => Msgs}}. + +auth_with_resumption_invalid_h(Config, Client, #{smid := SMID} = Data) -> + Resume = #xmlel{name = <<"resume">>, + attrs = [{<<"xmlns">>, ?NS_STREAM_MGNT_3}, + {<<"previd">>, SMID}, + {<<"h">>, <<"aaa">>}]}, + plain_auth(Config, Client, Data, [Resume]). + +auth_with_resumption_missing_previd(Config, Client, Data) -> + Resume = #xmlel{name = <<"resume">>, + attrs = [{<<"xmlns">>, ?NS_STREAM_MGNT_3}, + {<<"h">>, <<"aaa">>}]}, + plain_auth(Config, Client, Data, [Resume]). + +auth_with_resumption_exceeding_h(Config, Client, #{smid := SMID} = Data) -> + Resume = escalus_stanza:resume(SMID, 999), + plain_auth(Config, Client, Data, [Resume]). + +auth_with_resumption_unknown_smid(Config, Client, Data) -> + Resume = escalus_stanza:resume(<<"123456">>, 1), + plain_auth(Config, Client, Data, [Resume]). + +has_no_more_stanzas(_Config, Client, Data) -> + escalus_assert:has_no_stanzas(Client), + {Client, Data}. + plain_auth_user_agent_without_id(Config, Client, Data) -> plain_auth(Config, Client, Data, [user_agent_elem_without_id()]). @@ -124,6 +164,22 @@ scram_step_2(_Config, Client, {error, _, _} -> throw({auth_failed, SuccessStanza}) end. +buffer_messages_and_die(Config, _Client, Data) -> + Spec = escalus_fresh:create_fresh_user(Config, alice), + Client = sm_helper:connect_spec(Spec, sr_presence, manual), + C2SPid = mongoose_helper:get_session_pid(Client), + BobSpec = escalus_fresh:create_fresh_user(Config, bob), + {ok, Bob, _} = escalus_connection:start(BobSpec), + Texts = [ integer_to_binary(N) || N <- [1, 2, 3]], + sm_helper:send_messages(Bob, Client, Texts), + %% Client receives them, but doesn't ack. + sm_helper:wait_for_messages(Client, Texts), + %% Client's connection is violently terminated. + escalus_client:kill_connection(Config, Client), + sm_SUITE:wait_until_resume_session(C2SPid), + SMID = sm_helper:client_to_smid(Client), + {C2SPid, Data#{spec => Spec, smid => SMID, smh => 3, texts => Texts}}. + receive_features(_Config, Client, Data) -> Features = escalus_client:wait_for_stanza(Client), {Client, Data#{features => Features}}. diff --git a/src/c2s/mongoose_c2s_acc.erl b/src/c2s/mongoose_c2s_acc.erl index d7deeb35a44..0de27a7aa81 100644 --- a/src/c2s/mongoose_c2s_acc.erl +++ b/src/c2s/mongoose_c2s_acc.erl @@ -13,6 +13,7 @@ %% - `route': mongoose_acc elements to trigger the whole `handle_route' pipeline. %% - `flush': mongoose_acc elements to trigger the `handle_flush` pipeline. %% - `socket_send': xml elements to send on the socket to the user. +%% - `socket_send_first': xml elements to send on the socket to the user, but appended first. -module(mongoose_c2s_acc). -export([new/0, new/1, @@ -21,7 +22,16 @@ to_acc/3, to_acc_many/2 ]). --type key() :: state_mod | actions | c2s_state | c2s_data | stop | hard_stop | route | flush | socket_send. +-type key() :: state_mod + | actions + | c2s_state + | c2s_data + | stop + | hard_stop + | route + | flush + | socket_send + | socket_send_first. -type pairs() :: [pair()]. -type pair() :: {state_mod, {module(), term()}} | {actions, gen_statem:action()} @@ -31,7 +41,8 @@ | {hard_stop, term() | {shutdown, atom()}} | {route, mongoose_acc:t()} | {flush, mongoose_acc:t()} - | {socket_send, exml:element()}. + | {socket_send, exml:element() | [exml:element()]} + | {socket_send_first, exml:element() | [exml:element()]}. -type t() :: #{ state_mod := #{module() => term()}, @@ -54,7 +65,7 @@ socket_send => [exml:element()] }. --export_type([t/0]). +-export_type([t/0, pairs/0]). %% -------------------------------------------------------- %% API @@ -122,7 +133,8 @@ from_mongoose_acc(Acc, Key) -> (mongoose_acc:t(), stop, atom() | {shutdown, atom()}) -> mongoose_acc:t(); (mongoose_acc:t(), route, mongoose_acc:t()) -> mongoose_acc:t(); (mongoose_acc:t(), flush, mongoose_acc:t()) -> mongoose_acc:t(); - (mongoose_acc:t(), socket_send, exml:element()) -> mongoose_acc:t(). + (mongoose_acc:t(), socket_send, exml:element() | [exml:element()]) -> mongoose_acc:t(); + (mongoose_acc:t(), socket_send_first, exml:element() | [exml:element()]) -> mongoose_acc:t(). to_acc(Acc, Key, NewValue) -> C2SAcc = mongoose_acc:get_statem_acc(Acc), C2SAcc1 = to_c2s_acc(C2SAcc, {Key, NewValue}), @@ -161,6 +173,10 @@ to_c2s_acc(C2SAcc = #{socket_send := Stanzas}, {socket_send, NewStanzas}) when i C2SAcc#{socket_send := lists:reverse(NewStanzas) ++ Stanzas}; to_c2s_acc(C2SAcc = #{socket_send := Stanzas}, {socket_send, Stanza}) -> C2SAcc#{socket_send := [Stanza | Stanzas]}; +to_c2s_acc(C2SAcc = #{socket_send := Stanzas}, {socket_send_first, NewStanzas}) when is_list(NewStanzas) -> + C2SAcc#{socket_send := Stanzas ++ NewStanzas}; +to_c2s_acc(C2SAcc = #{socket_send := Stanzas}, {socket_send_first, Stanza}) -> + C2SAcc#{socket_send := Stanzas ++ [Stanza]}; to_c2s_acc(C2SAcc = #{actions := Actions}, {stop, Reason}) -> C2SAcc#{actions := [{next_event, cast, {stop, Reason}} | Actions]}; to_c2s_acc(C2SAcc, {Key, NewValue}) -> diff --git a/src/mod_sasl2.erl b/src/mod_sasl2.erl index 2677464c146..5003ea8f904 100644 --- a/src/mod_sasl2.erl +++ b/src/mod_sasl2.erl @@ -18,21 +18,22 @@ -export([callback_mode/0, init/1, handle_event/4, terminate/3]). %% hooks handlers --export([c2s_stream_features/3, - user_send_xmlel/3, - sasl2_success/3, - sasl2_continue/3, - sasl2_failure/3, - sasl2_error/3]). +-export([c2s_stream_features/3, user_send_xmlel/3]). + +%% helpers +-export([get_inline_request/2, put_inline_request/3, update_inline_request/4]). -type maybe_binary() :: undefined | binary(). +-type status() :: pending | success | failure. +-type inline_request() :: #{request := exml:element(), + response := undefined | exml:element(), + status := status()}. -type mod_state() :: #{authenticated := boolean(), id := not_provided | uuid:uuid(), software := not_provided | binary(), device := not_provided | binary()}. --type params() :: #{c2s_data => mongoose_c2s:data(), c2s_state => mongoose_c2s:state()}. --export_type([params/0]). +-export_type([inline_request/0]). %% gen_mod -spec start(mongooseim:host_type(), gen_mod:module_opts()) -> ok. @@ -52,11 +53,7 @@ hooks(HostType) -> -spec c2s_hooks(mongooseim:host_type()) -> gen_hook:hook_list(mongoose_c2s_hooks:fn()). c2s_hooks(HostType) -> [ - {user_send_xmlel, HostType, fun ?MODULE:user_send_xmlel/3, #{}, 50}, - {sasl2_success, HostType, fun ?MODULE:sasl2_success/3, #{}, 50}, - {sasl2_continue, HostType, fun ?MODULE:sasl2_continue/3, #{}, 50}, - {sasl2_failure, HostType, fun ?MODULE:sasl2_failure/3, #{}, 50}, - {sasl2_error, HostType, fun ?MODULE:sasl2_error/3, #{}, 50} + {user_send_xmlel, HostType, fun ?MODULE:user_send_xmlel/3, #{}, 50} ]. -spec supported_features() -> [atom()]. @@ -75,12 +72,8 @@ init(_) -> gen_statem:event_handler_result(mongoose_c2s:state(), mongoose_c2s:data()). handle_event(internal, #xmlel{name = <<"authenticate">>} = El, {wait_for_feature_before_auth, SaslAcc, Retries} = C2SState, C2SData) -> - case exml_query:attr(El, <<"xmlns">>) of - ?NS_SASL_2 -> - handle_auth_start(C2SData, C2SState, El, SaslAcc, Retries); - _ -> - mongoose_c2s:c2s_stream_error(C2SData, mongoose_xmpp_errors:invalid_namespace()) - end; + %% We don't verify the namespace here because to here we just jumped from user_send_xmlel + handle_auth_start(C2SData, C2SState, El, SaslAcc, Retries); handle_event(internal, #xmlel{name = <<"response">>} = El, {wait_for_sasl_response, SaslAcc, Retries} = C2SState, C2SData) -> case exml_query:attr(El, <<"xmlns">>) of @@ -120,26 +113,6 @@ c2s_stream_features(Acc, #{c2s_data := C2SData}, _) -> {ok, lists:keystore(feature_name(), #xmlel.name, Acc, Sasl2Feature)} end. --spec sasl2_success(mongoose_acc:t(), mongoose_c2s_hooks:params(), gen_hook:extra()) -> - mongoose_c2s_hooks:result(). -sasl2_success(Acc, _Params, _Extra) -> - {ok, Acc}. - --spec sasl2_continue(mongoose_acc:t(), mongoose_c2s_hooks:params(), gen_hook:extra()) -> - mongoose_c2s_hooks:result(). -sasl2_continue(Acc, _Params, _Extra) -> - {ok, Acc}. - --spec sasl2_failure(mongoose_acc:t(), mongoose_c2s_hooks:params(), gen_hook:extra()) -> - mongoose_c2s_hooks:result(). -sasl2_failure(Acc, _Params, _Extra) -> - {ok, Acc}. - --spec sasl2_error(mongoose_acc:t(), mongoose_c2s_hooks:params(), gen_hook:extra()) -> - mongoose_c2s_hooks:result(). -sasl2_error(Acc, _Params, _Extra) -> - {ok, Acc}. - -spec user_send_xmlel(mongoose_acc:t(), mongoose_c2s_hooks:params(), gen_hook:extra()) -> mongoose_c2s_hooks:result(). user_send_xmlel(Acc, Params, _Extra) -> @@ -235,33 +208,21 @@ handle_sasl_step(HookParams, {continue, NewSaslAcc, Result}, Retries) -> handle_sasl_continue(HookParams, NewSaslAcc, Result, Retries); handle_sasl_step(HookParams, {failure, NewSaslAcc, Result}, Retries) -> handle_sasl_failure(HookParams, NewSaslAcc, Result, Retries); -handle_sasl_step(HookParams, {error, NewSaslAcc, Result}, Retries) -> - handle_sasl_error(HookParams, NewSaslAcc, Result, Retries). +handle_sasl_step(HookParams, {error, NewSaslAcc, #{type := Type}}, Retries) -> + handle_sasl_failure(HookParams, NewSaslAcc, #{server_out => atom_to_binary(Type), + maybe_username => undefined}, Retries). -spec handle_sasl_success( mongoose_c2s_hooks:params(), mongoose_acc:t(), mongoose_c2s_sasl:success()) -> mongoose_c2s:fsm_res(). -handle_sasl_success(#{c2s_data := C2SData, c2s_state := C2SState} = HookParams, SaslAcc, - #{server_out := MaybeServerOut, jid := Jid, auth_module := AuthMod}) -> - C2SData1 = mongoose_c2s:set_jid(C2SData, Jid), - C2SData2 = mongoose_c2s:set_auth_module(C2SData1, AuthMod), - HostType = mongoose_c2s:get_host_type(C2SData2), +handle_sasl_success(#{c2s_data := C2SData, c2s_state := C2SState} = HookParams, + SaslAcc, #{server_out := MaybeServerOut, jid := Jid, auth_module := AuthMod}) -> + C2SData1 = build_final_c2s_data(C2SData, Jid, AuthMod), ?LOG_INFO(#{what => auth_success, text => <<"Accepted SASL authentication">>, - user => jid:to_binary(Jid), c2s_state => C2SData2}), - ModState = get_mod_state(C2SData2), - Actions = [pop_callback_module, - {next_event, internal, mongoose_c2s_stanzas:stream_header(C2SData2)}, - mongoose_c2s:state_timeout(C2SData2)], - SuccessStanza = success_stanza(Jid, MaybeServerOut), - StreamFeaturesStanza = mongoose_c2s_stanzas:stream_features_after_auth(C2SData2), - C2SState1 = {wait_for_feature_after_auth, ?BIND_RETRIES}, - ToAcc = [{socket_send, [SuccessStanza, StreamFeaturesStanza]}, - {actions, Actions}, - {state_mod, {?MODULE, ModState#{authenticated := true}}}, - {c2s_state, C2SState1}], - SaslAcc1 = mongoose_c2s_acc:to_acc_many(SaslAcc, ToAcc), - SaslAcc2 = mongoose_hooks:sasl2_success(HostType, SaslAcc1, HookParams), - mongoose_c2s:handle_state_after_packet(C2SData2, C2SState, SaslAcc2). + user => jid:to_binary(Jid), c2s_state => C2SData1}), + HostType = mongoose_c2s:get_host_type(C2SData), + SaslAcc1 = mongoose_hooks:sasl2_success(HostType, SaslAcc, HookParams), + process_sasl2_success(SaslAcc1, C2SData1, C2SState, Jid, MaybeServerOut). -spec handle_sasl_continue( mongoose_c2s_hooks:params(), mongoose_acc:t(), mongoose_c2s_sasl:continue(), mongoose_c2s:retries()) -> @@ -297,25 +258,67 @@ handle_sasl_failure(#{c2s_data := C2SData, c2s_state := C2SState} = HookParams, mongoose_c2s:handle_state_after_packet(C2SData, C2SState1, SaslAcc2) end. --spec handle_sasl_error( - mongoose_c2s_hooks:params(), mongoose_acc:t(), mongoose_c2s_sasl:error(), mongoose_c2s:retries()) -> +%% Append to the c2s_data both the new jid and the auth module. +%% Note that further inline requests can later on append a new jid if a resource is negotiated. +-spec build_final_c2s_data(mongoose_c2s:data(), jid:jid(), module()) -> mongoose_c2s:data(). +build_final_c2s_data(C2SData, Jid, AuthMod) -> + C2SData1 = mongoose_c2s:set_jid(C2SData, Jid), + mongoose_c2s:set_auth_module(C2SData1, AuthMod). + +-spec process_sasl2_success(mongoose_acc:t(), mongoose_c2s:data(), mongoose_c2s:state(), jid:jid(), maybe_binary()) -> mongoose_c2s:fsm_res(). -handle_sasl_error(#{c2s_data := C2SData} = HookParams, SaslAcc, #{type := Type, text := Text}, _Retries) -> - Lang = mongoose_c2s:get_lang(C2SData), - El = mongoose_xmpp_errors:Type(Lang, Text), - mongoose_c2s:c2s_stream_error(C2SData, El), - HostType = mongoose_c2s:get_host_type(C2SData), - SaslAcc1 = mongoose_hooks:sasl2_error(HostType, SaslAcc, HookParams), - {stop, mongoose_c2s_acc:to_acc(SaslAcc1, hard_stop, sasl2_violation)}. - --spec success_stanza(jid:jid(), maybe_binary()) -> exml:element(). -success_stanza(AuthJid, undefined) -> - AuthorizationId = success_subelement(<<"authorization-identifier">>, jid:to_binary(AuthJid)), - sasl2_ns_stanza(<<"success">>, [AuthorizationId]); -success_stanza(AuthJid, CData) -> - AdditionalData = success_subelement(<<"additional-data">>, base64:encode(CData)), - AuthorizationId = success_subelement(<<"authorization-identifier">>, jid:to_binary(AuthJid)), - sasl2_ns_stanza(<<"success">>, [AdditionalData, AuthorizationId]). +process_sasl2_success(SaslAcc, C2SData, C2SState, Jid, MaybeServerOut) -> + SuccessStanza = success_stanza(SaslAcc, Jid, MaybeServerOut), + ToAcc = build_to_acc(SaslAcc, C2SData, SuccessStanza), + SaslAcc1 = mongoose_c2s_acc:to_acc_many(SaslAcc, ToAcc), + mongoose_c2s:handle_state_after_packet(C2SData, C2SState, SaslAcc1). + +%% After auth and inline requests we: +%% - return control to mongoose_c2s (pop_callback_module), +%% - ensure the answer to the sasl2 request is sent in the socket first, +%% - then decide depending on whether an inline request has taken control of the c2s_state if +%% - do nothing if control was taken +%% - put the statem in wait_for_feature_after_auth +-spec build_to_acc(mongoose_acc:t(), mongoose_c2s:data(), exml:element()) -> mongoose_c2s_acc:pairs(). +build_to_acc(SaslAcc, C2SData, SuccessStanza) -> + ModState = get_mod_state(C2SData), + ToAcc0 = [{actions, [pop_callback_module, mongoose_c2s:state_timeout(C2SData)]}, + {state_mod, {?MODULE, ModState#{authenticated := true}}}], + case is_new_c2s_state_requested(SaslAcc) of + true -> + [{socket_send_first, SuccessStanza} | ToAcc0]; + false -> + %% Unless specified by an inline feature, sasl2 would normally put the statem just before bind + StreamFeaturesStanza = mongoose_c2s_stanzas:stream_features_after_auth(C2SData), + [{socket_send_first, SuccessStanza}, + {socket_send, StreamFeaturesStanza}, + {c2s_state, {wait_for_feature_after_auth, ?BIND_RETRIES}} | ToAcc0] + end. + +-spec is_new_c2s_state_requested(mongoose_acc:t()) -> boolean(). +is_new_c2s_state_requested(SaslAcc) -> + #{c2s_state := NewState} = mongoose_c2s_acc:get_statem_result(SaslAcc), + undefined =/= NewState. + +-spec success_stanza(mongoose_acc:t(), jid:jid(), maybe_binary()) -> exml:element(). +success_stanza(SaslAcc, AuthJid, MaybeCData) -> + Inlines = get_inline_requests(SaslAcc), + InlineAnswers = get_inline_responses(Inlines), + case MaybeCData of + undefined -> + AuthorizationId = success_subelement(<<"authorization-identifier">>, jid:to_binary(AuthJid)), + sasl2_ns_stanza(<<"success">>, [AuthorizationId | InlineAnswers]); + CData -> + AdditionalData = success_subelement(<<"additional-data">>, base64:encode(CData)), + AuthorizationId = success_subelement(<<"authorization-identifier">>, jid:to_binary(AuthJid)), + sasl2_ns_stanza(<<"success">>, [AdditionalData, AuthorizationId | Inlines]) + end. + +-spec get_inline_responses([inline_request()]) -> [exml:element()]. +get_inline_responses(Inlines) -> + [ Response || {_Module, #{status := Status, response := Response}} <- Inlines, + pending =/= Status, + undefined =/= Response ]. -spec challenge_stanza(binary()) -> exml:element(). challenge_stanza(ServerOut) -> @@ -390,3 +393,26 @@ inlines(InlineFeatures) -> -spec feature_name() -> binary(). feature_name() -> <<"authentication">>. + +-spec get_inline_requests(mongoose_acc:t()) -> [{module(), inline_request()}]. +get_inline_requests(SaslAcc) -> + mongoose_acc:get(?MODULE, SaslAcc). + +-spec get_inline_request(mongoose_acc:t(), module()) -> undefined | inline_request(). +get_inline_request(SaslAcc, ModuleRequest) -> + mongoose_acc:get(?MODULE, ModuleRequest, undefined, SaslAcc). + +-spec put_inline_request(mongoose_acc:t(), module(), exml:element()) -> mongoose_acc:t(). +put_inline_request(SaslAcc, ModuleRequest, XmlRequest) -> + Request = #{request => XmlRequest, response => undefined, status => pending}, + mongoose_acc:set(?MODULE, ModuleRequest, Request, SaslAcc). + +-spec update_inline_request(mongoose_acc:t(), module(), exml:element(), status()) -> mongoose_acc:t(). +update_inline_request(SaslAcc, ModuleRequest, XmlResponse, Status) -> + case mongoose_acc:get(?MODULE, ModuleRequest, undefined, SaslAcc) of + undefined -> + SaslAcc; + Request -> + Request1 = Request#{response := XmlResponse, status := Status}, + mongoose_acc:set(?MODULE, ModuleRequest, Request1, SaslAcc) + end. diff --git a/src/stream_management/mod_stream_management.erl b/src/stream_management/mod_stream_management.erl index c92a973ddc6..54761797931 100644 --- a/src/stream_management/mod_stream_management.erl +++ b/src/stream_management/mod_stream_management.erl @@ -14,6 +14,9 @@ %% hooks handlers -export([c2s_stream_features/3, + sasl2_stream_features/3, + sasl2_start/3, + sasl2_success/3, session_cleanup/3, user_send_packet/3, user_receive_packet/3, @@ -59,6 +62,12 @@ -type buffer_max() :: pos_integer() | infinity | no_buffer. -type ack_freq() :: pos_integer() | never. +-type resume_return() :: {ok, #{resumed := exml:element(), + forward := [exml:element()]}, mongoose_acc:t()} + | {stream_mgmt_error, exml:element()} + | {error, exml:element()} + | {error, exml:element(), term()}. + %% Type base64:ascii_binary() is not exported -type smid() :: binary(). -type short() :: 0..?STREAM_MGMT_H_MAX. @@ -87,7 +96,10 @@ stop(HostType) -> -spec hooks(mongooseim:host_type()) -> gen_hook:hook_list(). hooks(HostType) -> [{c2s_stream_features, HostType, fun ?MODULE:c2s_stream_features/3, #{}, 50}, - {session_cleanup, HostType, fun ?MODULE:session_cleanup/3, #{}, 50} + {session_cleanup, HostType, fun ?MODULE:session_cleanup/3, #{}, 50}, + {sasl2_stream_features, HostType, fun ?MODULE:sasl2_stream_features/3, #{}, 50}, + {sasl2_start, HostType, fun ?MODULE:sasl2_start/3, #{}, 50}, + {sasl2_success, HostType, fun ?MODULE:sasl2_success/3, #{}, 30} | c2s_hooks(HostType) ]. -spec c2s_hooks(mongooseim:host_type()) -> gen_hook:hook_list(mongoose_c2s_hooks:fn()). @@ -401,7 +413,7 @@ handle_stream_mgmt(Acc, Params = #{c2s_state := C2SState}, El = #xmlel{name = << when ?IS_ALLOWED_STATE(C2SState) -> handle_enable(Acc, Params, El); handle_stream_mgmt(Acc, Params = #{c2s_state := {wait_for_feature_after_auth, _}}, El = #xmlel{name = <<"resume">>}) -> - handle_resume(Acc, Params, El); + handle_resume_request(Acc, Params, El); handle_stream_mgmt(Acc, #{c2s_data := StateData, c2s_state := C2SState}, _El) -> unexpected_sm_request(Acc, StateData, C2SState); handle_stream_mgmt(Acc, _Params, _El) -> @@ -509,23 +521,39 @@ do_handle_enable(Acc, StateData, true) -> ToAcc = [{state_mod, {?MODULE, SmState}}, {socket_send, Stanza}], {stop, mongoose_c2s_acc:to_acc_many(Acc, ToAcc)}. --spec handle_resume(mongoose_acc:t(), mongoose_c2s_hooks:params(), exml:element()) -> +-spec handle_resume_request(mongoose_acc:t(), mongoose_c2s_hooks:params(), exml:element()) -> mongoose_c2s_hooks:result(). -handle_resume(Acc, #{c2s_state := C2SState, c2s_data := StateData}, El) -> - case {get_previd(El), stream_mgmt_parse_h(El), get_mod_state(StateData)} of +handle_resume_request(Acc, #{c2s_state := C2SState, c2s_data := C2SData}, El) -> + case handle_resume(Acc, C2SData, C2SState, El) of + {stream_mgmt_error, ErrorStanza} -> + stream_mgmt_error(Acc, C2SData, C2SState, ErrorStanza); + {error, ErrorStanza, Reason} -> + mongoose_c2s:c2s_stream_error(C2SData, ErrorStanza), + {stop, mongoose_c2s_acc:to_acc(Acc, stop, Reason)}; + {error, ErrorStanza} -> + {stop, mongoose_c2s_acc:to_acc(Acc, socket_send, ErrorStanza)}; + {ok, #{resumed := Resumed, forward := ToForward}, Acc1} -> + ToAcc = [{socket_send, [Resumed | ToForward]}], + {ok, mongoose_c2s_acc:to_acc_many(Acc1, ToAcc)} + end. + +%% This runs on the new process +-spec handle_resume(mongoose_acc:t(), mongoose_c2s:data(), mongoose_c2s:state(), exml:element()) -> + resume_return(). +handle_resume(Acc, C2SData, C2SState, El) -> + case {get_previd(El), stream_mgmt_parse_h(El), get_mod_state(C2SData)} of {undefined, _, _} -> - bad_request(Acc); + {error, stream_mgmt_failed(<<"bad-request">>)}; {_, invalid_h_attribute, _} -> - bad_request(Acc); + {error, stream_mgmt_failed(<<"bad-request">>)}; {_, _, #sm_state{}} -> - bad_request(Acc); + {error, stream_mgmt_failed(<<"bad-request">>)}; {SMID, H, {error, not_found}} -> - HostType = mongoose_c2s:get_host_type(StateData), + HostType = mongoose_c2s:get_host_type(C2SData), FromSMID = get_session_from_smid(HostType, SMID), - do_handle_resume(Acc, StateData, C2SState, SMID, H, FromSMID) + do_handle_resume(Acc, C2SData, C2SState, SMID, H, FromSMID) end. -%% This runs on the new process -spec do_handle_resume(Acc, StateData, C2SState, SMID, H, FromSMID) -> HookResult when Acc :: mongoose_acc:t(), StateData :: mongoose_c2s:data(), @@ -533,52 +561,48 @@ handle_resume(Acc, #{c2s_state := C2SState, c2s_data := StateData}, El) -> SMID :: smid(), H :: non_neg_integer(), FromSMID :: maybe_smid(), - HookResult :: mongoose_c2s_hooks:result(). + HookResult :: resume_return(). do_handle_resume(Acc, StateData, _C2SState, SMID, H, {sid, {_TS, Pid}}) -> case get_peer_state(Pid, H) of {ok, OldStateData} -> NewState = mongoose_c2s:merge_states(OldStateData, StateData), do_resume(Acc, NewState, SMID); {error, ErrorStanza, Reason} -> - mongoose_c2s:c2s_stream_error(StateData, ErrorStanza), - {stop, mongoose_c2s_acc:to_acc(Acc, stop, Reason)}; + {error, ErrorStanza, Reason}; {exception, {C, R, S}} -> ?LOG_WARNING(#{what => resumption_error, text => <<"Resumption error because of invalid response">>, class => C, reason => R, stacktrace => S, pid => Pid, c2s_state => StateData}), - Err = stream_mgmt_failed(<<"item-not-found">>), - {stop, mongoose_c2s_acc:to_acc(Acc, socket_send, Err)} + {stream_mgmt_error, stream_mgmt_failed(<<"item-not-found">>)} end; -do_handle_resume(Acc, StateData, C2SState, SMID, _H, {stale_h, StaleH}) -> +do_handle_resume(_Acc, StateData, _C2SState, SMID, _H, {stale_h, StaleH}) -> ?LOG_WARNING(#{what => resumption_error, reason => session_resumption_timed_out, smid => SMID, stale_h => StaleH, c2s_state => StateData}), - Err = stream_mgmt_failed(<<"item-not-found">>, [{<<"h">>, integer_to_binary(StaleH)}]), - stream_mgmt_error(Acc, StateData, C2SState, Err); -do_handle_resume(Acc, StateData, C2SState, SMID, _H, {error, smid_not_found}) -> + {stream_mgmt_error, stream_mgmt_failed(<<"item-not-found">>, [{<<"h">>, integer_to_binary(StaleH)}])}; +do_handle_resume(_Acc, StateData, _C2SState, SMID, _H, {error, smid_not_found}) -> ?LOG_WARNING(#{what => resumption_error, reason => no_previous_session_for_smid, smid => SMID, c2s_state => StateData}), - Err = stream_mgmt_failed(<<"item-not-found">>), - stream_mgmt_error(Acc, StateData, C2SState, Err). + {stream_mgmt_error, stream_mgmt_failed(<<"item-not-found">>)}. %% This runs on the new process -spec do_resume(Acc, StateData, SMID) -> HookResult when Acc :: mongoose_acc:t(), StateData :: mongoose_c2s:data(), SMID :: smid(), - HookResult :: mongoose_c2s_hooks:result(). + HookResult :: resume_return(). do_resume(Acc, StateData, SMID) -> {_ReplacedPids, StateData2} = mongoose_c2s:open_session(StateData), ok = register_smid(StateData2, SMID), - Stanzas = get_all_stanzas_to_forward(StateData2, SMID), - ToAcc = [{c2s_state, session_established}, {c2s_data, StateData2}, {socket_send, Stanzas}], - {ok, mongoose_c2s_acc:to_acc_many(Acc, ToAcc)}. + {Resumed, ToForward} = get_all_stanzas_to_forward(StateData2, SMID), + ToAcc = [{c2s_state, session_established}, {c2s_data, StateData2}], + {ok, #{resumed => Resumed, forward => ToForward}, mongoose_c2s_acc:to_acc_many(Acc, ToAcc)}. register_smid(StateData, SMID) -> Sid = mongoose_c2s:get_sid(StateData), HostType = mongoose_c2s:get_host_type(StateData), ok = register_smid(HostType, SMID, Sid). --spec get_all_stanzas_to_forward(mongoose_c2s:data(), smid()) -> [exml:element()]. +-spec get_all_stanzas_to_forward(mongoose_c2s:data(), smid()) -> {exml:element(), [exml:element()]}. get_all_stanzas_to_forward(StateData, SMID) -> #sm_state{counter_in = Counter, buffer = Buffer} = get_mod_state(StateData), Resumed = stream_mgmt_resumed(SMID, Counter), @@ -591,7 +615,7 @@ get_all_stanzas_to_forward(StateData, SMID) -> StanzaType = mongoose_acc:stanza_type(Acc), maybe_add_timestamp(Packet, StanzaName, StanzaType, TS, FromServer) end || Acc <- lists:reverse(Buffer)], - [Resumed | ToForward]. + {Resumed, ToForward}. maybe_add_timestamp(Packet, <<"message">>, <<"error">>, _, _) -> Packet; @@ -644,12 +668,6 @@ is_conflict_receiver_sid(Acc, StateData) -> AccSid = mongoose_acc:get(c2s, receiver_sid, StateSid, Acc), StateSid =/= AccSid. --spec bad_request(mongoose_acc:t()) -> - {stop, mongoose_acc:t()}. -bad_request(Acc) -> - Err = stream_mgmt_failed(<<"bad-request">>), - {stop, mongoose_c2s_acc:to_acc(Acc, socket_send, Err)}. - -spec stream_error(mongoose_acc:t(), mongoose_c2s:data()) -> {stop, mongoose_acc:t()}. stream_error(Acc, StateData) -> @@ -715,6 +733,49 @@ get_previd(El) -> c2s_stream_features(Acc, _, _) -> {ok, lists:keystore(<<"sm">>, #xmlel.name, Acc, sm())}. +-spec sasl2_stream_features(Acc, #{c2s_data := mongoose_c2s:data()}, gen_hook:extra()) -> + {ok, Acc} when Acc :: [exml:element()]. +sasl2_stream_features(Acc, _, _) -> + Resume = #xmlel{name = <<"resume">>, attrs = [{<<"xmlns">>, ?NS_STREAM_MGNT_3}]}, + {ok, [Resume | Acc]}. + +-spec sasl2_start(SaslAcc, mongoose_c2s_hooks:params(), gen_hook:extra()) -> + {ok, SaslAcc} when SaslAcc :: mongoose_acc:t(). +sasl2_start(SaslAcc, #{event_content := #{stanza := El}}, _) -> + case exml_query:path(El, [{element_with_ns, <<"resume">>, ?NS_STREAM_MGNT_3}]) of + undefined -> + {ok, SaslAcc}; + SmRequest -> + {ok, mod_sasl2:put_inline_request(SaslAcc, ?MODULE, SmRequest)} + end. + +-spec sasl2_success(SaslAcc, mongoose_c2s_hooks:params(), gen_hook:extra()) -> + {ok, SaslAcc} when SaslAcc :: mongoose_acc:t(). +sasl2_success(SaslAcc, Params, _) -> + case mod_sasl2:get_inline_request(SaslAcc, ?MODULE) of + undefined -> + {ok, SaslAcc}; + SmRequest -> + handle_sasl2_resume(SaslAcc, Params, SmRequest) + end. + +-spec handle_sasl2_resume(SaslAcc, mongoose_c2s_hooks:params(), mod_sasl2:inline_request()) -> + {ok, SaslAcc} when SaslAcc :: mongoose_acc:t(). +handle_sasl2_resume(SaslAcc, #{c2s_state := C2SState, c2s_data := C2SData}, + #{request := El}) -> + case handle_resume(SaslAcc, C2SData, C2SState, El) of + {stream_mgmt_error, ErrorStanza} -> + {ok, mod_sasl2:update_inline_request(SaslAcc, ?MODULE, ErrorStanza, failure)}; + {error, _ErrorStanza, _Reason} -> %% This signifies a stream-error, but we discard those here + SimpleErrorStanza = stream_mgmt_failed(<<"bad-request">>), + {ok, mod_sasl2:update_inline_request(SaslAcc, ?MODULE, SimpleErrorStanza, failure)}; + {error, ErrorStanza} -> + {ok, mod_sasl2:update_inline_request(SaslAcc, ?MODULE, ErrorStanza, failure)}; + {ok, #{resumed := Resumed, forward := ToForward}, SaslAcc1} -> + SaslAcc2 = mod_sasl2:update_inline_request(SaslAcc1, ?MODULE, Resumed, success), + {ok, mongoose_c2s_acc:to_acc(SaslAcc2, socket_send, ToForward)} + end. + -spec sm() -> exml:element(). sm() -> #xmlel{name = <<"sm">>, diff --git a/test/common/config_parser_helper.erl b/test/common/config_parser_helper.erl index af3d57867ea..4bdc2b3731a 100644 --- a/test/common/config_parser_helper.erl +++ b/test/common/config_parser_helper.erl @@ -984,6 +984,8 @@ default_mod_config(mod_register) -> password_strength => 0, ip_access => []}; default_mod_config(mod_roster) -> #{iqdisc => one_queue, versioning => false, store_current_id => false, backend => mnesia}; +default_mod_config(mod_sasl) -> + #{}; default_mod_config(mod_shared_roster_ldap) -> #{pool_tag => default, deref => never, filter => <<"">>, groupattr => <<"cn">>, groupdesc => <<"cn">>, userdesc => <<"cn">>, useruid => <<"cn">>,