Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion deps/rabbitmq_mqtt/include/mqtt_machine.hrl
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
%% Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
%%

-record(machine_state, {client_ids = #{}}).
-record(machine_state, {client_ids = #{},
pids = #{}}).
113 changes: 72 additions & 41 deletions deps/rabbitmq_mqtt/src/mqtt_machine.erl
Original file line number Diff line number Diff line change
Expand Up @@ -31,29 +31,49 @@ init(_Conf) ->

-spec apply(map(), command(), state()) ->
{state(), reply(), ra_machine:effects()}.
apply(_Meta, {register, ClientId, Pid}, #machine_state{client_ids = Ids} = State0) ->
{Effects, Ids1} =
apply(_Meta, {register, ClientId, Pid},
#machine_state{client_ids = Ids,
pids = Pids0} = State0) ->
{Effects, Ids1, Pids} =
case maps:find(ClientId, Ids) of
{ok, OldPid} when Pid =/= OldPid ->
Effects0 = [{demonitor, process, OldPid},
{monitor, process, Pid},
{mod_call, ?MODULE, notify_connection, [OldPid, duplicate_id]}],
{Effects0, maps:remove(ClientId, Ids)};
{mod_call, ?MODULE, notify_connection,
[OldPid, duplicate_id]}],
{Effects0, maps:remove(ClientId, Ids), Pids0};
_ ->
Pids1 = maps:update_with(Pid, fun(CIds) -> [ClientId | CIds] end,
[ClientId], Pids0),
Effects0 = [{monitor, process, Pid}],
{Effects0, Ids}
{Effects0, Ids, Pids1}
end,
State = State0#machine_state{client_ids = maps:put(ClientId, Pid, Ids1)},
State = State0#machine_state{client_ids = maps:put(ClientId, Pid, Ids1),
pids = Pids},
{State, ok, Effects};

apply(Meta, {unregister, ClientId, Pid}, #machine_state{client_ids = Ids} = State0) ->
apply(Meta, {unregister, ClientId, Pid}, #machine_state{client_ids = Ids,
pids = Pids0} = State0) ->
State = case maps:find(ClientId, Ids) of
{ok, Pid} -> State0#machine_state{client_ids = maps:remove(ClientId, Ids)};
%% don't delete client id that might belong to a newer connection
%% that kicked the one with Pid out
{ok, _AnotherPid} -> State0;
error -> State0
end,
{ok, Pid} ->
Pids = case maps:get(Pid, Pids0, undefined) of
undefined ->
Pids0;
[ClientId] ->
maps:remove(Pid, Pids0);
Cids ->
Pids0#{Pid => lists:delete(ClientId, Cids)}
end,

State0#machine_state{client_ids = maps:remove(ClientId, Ids),
pids = Pids};
%% don't delete client id that might belong to a newer connection
%% that kicked the one with Pid out
{ok, _AnotherPid} ->
State0;
error ->
State0
end,
Effects0 = [{demonitor, process, Pid}],
%% snapshot only when the map has changed
Effects = case State of
Expand All @@ -69,18 +89,21 @@ apply(_Meta, {down, DownPid, noconnection}, State) ->
Effect = {monitor, node, node(DownPid)},
{State, ok, Effect};

apply(Meta, {down, DownPid, _}, #machine_state{client_ids = Ids} = State0) ->
Ids1 = maps:filter(fun (_ClientId, Pid) when Pid =:= DownPid ->
false;
(_, _) ->
true
end, Ids),
State = State0#machine_state{client_ids = Ids1},
Delta = maps:keys(Ids) -- maps:keys(Ids1),
Effects = lists:map(fun(Id) ->
[{mod_call, rabbit_log, debug,
["MQTT connection with client id '~s' failed", [Id]]}] end, Delta),
{State, ok, Effects ++ snapshot_effects(Meta, State)};
apply(Meta, {down, DownPid, _}, #machine_state{client_ids = Ids,
pids = Pids0} = State0) ->
case maps:get(DownPid, Pids0, undefined) of
undefined ->
{State0, ok, []};
ClientIds ->
Ids1 = maps:without(ClientIds, Ids),
State = State0#machine_state{client_ids = Ids1,
pids = maps:remove(DownPid, Pids0)},
Effects = lists:map(fun(Id) ->
[{mod_call, rabbit_log, debug,
["MQTT connection with client id '~s' failed", [Id]]}]
end, ClientIds),
{State, ok, Effects ++ snapshot_effects(Meta, State)}
end;

apply(_Meta, {nodeup, Node}, State) ->
%% Work out if any pids that were disconnected are still
Expand All @@ -91,22 +114,30 @@ apply(_Meta, {nodeup, Node}, State) ->
apply(_Meta, {nodedown, _Node}, State) ->
{State, ok};

apply(Meta, {leave, Node}, #machine_state{client_ids = Ids} = State0) ->
Ids1 = maps:filter(fun (_ClientId, Pid) -> node(Pid) =/= Node end, Ids),
Delta = maps:keys(Ids) -- maps:keys(Ids1),

Effects = lists:foldl(fun (ClientId, Acc) ->
Pid = maps:get(ClientId, Ids),
[
{demonitor, process, Pid},
{mod_call, ?MODULE, notify_connection, [Pid, decommission_node]},
{mod_call, rabbit_log, debug,
["MQTT will remove client ID '~s' from known "
"as its node has been decommissioned", [ClientId]]}
] ++ Acc
end, [], Delta),

State = State0#machine_state{client_ids = Ids1},
apply(Meta, {leave, Node}, #machine_state{client_ids = Ids,
pids = Pids0} = State0) ->
{Keep, Remove} = maps:fold(
fun (ClientId, Pid, {In, Out}) ->
case node(Pid) =/= Node of
true ->
{In#{ClientId => Pid}, Out};
false ->
{In, Out#{ClientId => Pid}}
end
end, {#{}, #{}}, Ids),
Effects = maps:fold(fun (ClientId, _Pid, Acc) ->
Pid = maps:get(ClientId, Ids),
[
{demonitor, process, Pid},
{mod_call, ?MODULE, notify_connection, [Pid, decommission_node]},
{mod_call, rabbit_log, debug,
["MQTT will remove client ID '~s' from known "
"as its node has been decommissioned", [ClientId]]}
] ++ Acc
end, [], Remove),

State = State0#machine_state{client_ids = Keep,
pids = maps:without(maps:keys(Remove), Pids0)},
{State, ok, Effects ++ snapshot_effects(Meta, State)};

apply(_Meta, Unknown, State) ->
Expand Down
26 changes: 25 additions & 1 deletion deps/rabbitmq_mqtt/test/mqtt_machine_SUITE.erl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ all() ->

all_tests() ->
[
basics
basics,
many_downs
].

groups() ->
Expand Down Expand Up @@ -56,6 +57,7 @@ basics(_Config) ->
ClientId = <<"id1">>,
{S1, ok, _} = mqtt_machine:apply(meta(1), {register, ClientId, self()}, S0),
?assertMatch(#machine_state{client_ids = Ids} when map_size(Ids) == 1, S1),
?assertMatch(#machine_state{pids = Pids} when map_size(Pids) == 1, S1),
{S2, ok, _} = mqtt_machine:apply(meta(2), {register, ClientId, self()}, S1),
?assertMatch(#machine_state{client_ids = Ids} when map_size(Ids) == 1, S2),
{S3, ok, _} = mqtt_machine:apply(meta(3), {down, self(), noproc}, S2),
Expand All @@ -65,6 +67,28 @@ basics(_Config) ->

ok.

many_downs(_Config) ->
S0 = mqtt_machine:init(#{}),
Clients = [{list_to_binary(integer_to_list(I)), spawn(fun() -> ok end)}
|| I <- lists:seq(1, 10000)],
S1 = lists:foldl(
fun ({ClientId, Pid}, Acc0) ->
{Acc, ok, _} = mqtt_machine:apply(meta(1), {register, ClientId, Pid}, Acc0),
Acc
end, S0, Clients),
_ = lists:foldl(
fun ({_ClientId, Pid}, Acc0) ->
{Acc, ok, _} = mqtt_machine:apply(meta(1), {down, Pid, noproc}, Acc0),
Acc
end, S1, Clients),
_ = lists:foldl(
fun ({ClientId, Pid}, Acc0) ->
{Acc, ok, _} = mqtt_machine:apply(meta(1), {unregister, ClientId,
Pid}, Acc0),
Acc
end, S0, Clients),

ok.
%% Utility

meta(Idx) ->
Expand Down