Skip to content

Commit 9653d50

Browse files
Merge pull request #2692 from rabbitmq/mqtt-machine-opt
Optimise MQTT state machine (cherry picked from commit de02be2)
1 parent 72fabc8 commit 9653d50

File tree

3 files changed

+99
-43
lines changed

3 files changed

+99
-43
lines changed

deps/rabbitmq_mqtt/include/mqtt_machine.hrl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55
%% Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
66
%%
77

8-
-record(machine_state, {client_ids = #{}}).
8+
-record(machine_state, {client_ids = #{},
9+
pids = #{}}).

deps/rabbitmq_mqtt/src/mqtt_machine.erl

Lines changed: 72 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -31,29 +31,49 @@ init(_Conf) ->
3131

3232
-spec apply(map(), command(), state()) ->
3333
{state(), reply(), ra_machine:effects()}.
34-
apply(_Meta, {register, ClientId, Pid}, #machine_state{client_ids = Ids} = State0) ->
35-
{Effects, Ids1} =
34+
apply(_Meta, {register, ClientId, Pid},
35+
#machine_state{client_ids = Ids,
36+
pids = Pids0} = State0) ->
37+
{Effects, Ids1, Pids} =
3638
case maps:find(ClientId, Ids) of
3739
{ok, OldPid} when Pid =/= OldPid ->
3840
Effects0 = [{demonitor, process, OldPid},
3941
{monitor, process, Pid},
40-
{mod_call, ?MODULE, notify_connection, [OldPid, duplicate_id]}],
41-
{Effects0, maps:remove(ClientId, Ids)};
42+
{mod_call, ?MODULE, notify_connection,
43+
[OldPid, duplicate_id]}],
44+
{Effects0, maps:remove(ClientId, Ids), Pids0};
4245
_ ->
46+
Pids1 = maps:update_with(Pid, fun(CIds) -> [ClientId | CIds] end,
47+
[ClientId], Pids0),
4348
Effects0 = [{monitor, process, Pid}],
44-
{Effects0, Ids}
49+
{Effects0, Ids, Pids1}
4550
end,
46-
State = State0#machine_state{client_ids = maps:put(ClientId, Pid, Ids1)},
51+
State = State0#machine_state{client_ids = maps:put(ClientId, Pid, Ids1),
52+
pids = Pids},
4753
{State, ok, Effects};
4854

49-
apply(Meta, {unregister, ClientId, Pid}, #machine_state{client_ids = Ids} = State0) ->
55+
apply(Meta, {unregister, ClientId, Pid}, #machine_state{client_ids = Ids,
56+
pids = Pids0} = State0) ->
5057
State = case maps:find(ClientId, Ids) of
51-
{ok, Pid} -> State0#machine_state{client_ids = maps:remove(ClientId, Ids)};
52-
%% don't delete client id that might belong to a newer connection
53-
%% that kicked the one with Pid out
54-
{ok, _AnotherPid} -> State0;
55-
error -> State0
56-
end,
58+
{ok, Pid} ->
59+
Pids = case maps:get(Pid, Pids0, undefined) of
60+
undefined ->
61+
Pids0;
62+
[ClientId] ->
63+
maps:remove(Pid, Pids0);
64+
Cids ->
65+
Pids0#{Pid => lists:delete(ClientId, Cids)}
66+
end,
67+
68+
State0#machine_state{client_ids = maps:remove(ClientId, Ids),
69+
pids = Pids};
70+
%% don't delete client id that might belong to a newer connection
71+
%% that kicked the one with Pid out
72+
{ok, _AnotherPid} ->
73+
State0;
74+
error ->
75+
State0
76+
end,
5777
Effects0 = [{demonitor, process, Pid}],
5878
%% snapshot only when the map has changed
5979
Effects = case State of
@@ -69,18 +89,21 @@ apply(_Meta, {down, DownPid, noconnection}, State) ->
6989
Effect = {monitor, node, node(DownPid)},
7090
{State, ok, Effect};
7191

72-
apply(Meta, {down, DownPid, _}, #machine_state{client_ids = Ids} = State0) ->
73-
Ids1 = maps:filter(fun (_ClientId, Pid) when Pid =:= DownPid ->
74-
false;
75-
(_, _) ->
76-
true
77-
end, Ids),
78-
State = State0#machine_state{client_ids = Ids1},
79-
Delta = maps:keys(Ids) -- maps:keys(Ids1),
80-
Effects = lists:map(fun(Id) ->
81-
[{mod_call, rabbit_log, debug,
82-
["MQTT connection with client id '~s' failed", [Id]]}] end, Delta),
83-
{State, ok, Effects ++ snapshot_effects(Meta, State)};
92+
apply(Meta, {down, DownPid, _}, #machine_state{client_ids = Ids,
93+
pids = Pids0} = State0) ->
94+
case maps:get(DownPid, Pids0, undefined) of
95+
undefined ->
96+
{State0, ok, []};
97+
ClientIds ->
98+
Ids1 = maps:without(ClientIds, Ids),
99+
State = State0#machine_state{client_ids = Ids1,
100+
pids = maps:remove(DownPid, Pids0)},
101+
Effects = lists:map(fun(Id) ->
102+
[{mod_call, rabbit_log, debug,
103+
["MQTT connection with client id '~s' failed", [Id]]}]
104+
end, ClientIds),
105+
{State, ok, Effects ++ snapshot_effects(Meta, State)}
106+
end;
84107

85108
apply(_Meta, {nodeup, Node}, State) ->
86109
%% Work out if any pids that were disconnected are still
@@ -91,22 +114,30 @@ apply(_Meta, {nodeup, Node}, State) ->
91114
apply(_Meta, {nodedown, _Node}, State) ->
92115
{State, ok};
93116

94-
apply(Meta, {leave, Node}, #machine_state{client_ids = Ids} = State0) ->
95-
Ids1 = maps:filter(fun (_ClientId, Pid) -> node(Pid) =/= Node end, Ids),
96-
Delta = maps:keys(Ids) -- maps:keys(Ids1),
97-
98-
Effects = lists:foldl(fun (ClientId, Acc) ->
99-
Pid = maps:get(ClientId, Ids),
100-
[
101-
{demonitor, process, Pid},
102-
{mod_call, ?MODULE, notify_connection, [Pid, decommission_node]},
103-
{mod_call, rabbit_log, debug,
104-
["MQTT will remove client ID '~s' from known "
105-
"as its node has been decommissioned", [ClientId]]}
106-
] ++ Acc
107-
end, [], Delta),
108-
109-
State = State0#machine_state{client_ids = Ids1},
117+
apply(Meta, {leave, Node}, #machine_state{client_ids = Ids,
118+
pids = Pids0} = State0) ->
119+
{Keep, Remove} = maps:fold(
120+
fun (ClientId, Pid, {In, Out}) ->
121+
case node(Pid) =/= Node of
122+
true ->
123+
{In#{ClientId => Pid}, Out};
124+
false ->
125+
{In, Out#{ClientId => Pid}}
126+
end
127+
end, {#{}, #{}}, Ids),
128+
Effects = maps:fold(fun (ClientId, _Pid, Acc) ->
129+
Pid = maps:get(ClientId, Ids),
130+
[
131+
{demonitor, process, Pid},
132+
{mod_call, ?MODULE, notify_connection, [Pid, decommission_node]},
133+
{mod_call, rabbit_log, debug,
134+
["MQTT will remove client ID '~s' from known "
135+
"as its node has been decommissioned", [ClientId]]}
136+
] ++ Acc
137+
end, [], Remove),
138+
139+
State = State0#machine_state{client_ids = Keep,
140+
pids = maps:without(maps:keys(Remove), Pids0)},
110141
{State, ok, Effects ++ snapshot_effects(Meta, State)};
111142

112143
apply(_Meta, Unknown, State) ->

deps/rabbitmq_mqtt/test/mqtt_machine_SUITE.erl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ all() ->
2121

2222
all_tests() ->
2323
[
24-
basics
24+
basics,
25+
many_downs
2526
].
2627

2728
groups() ->
@@ -56,6 +57,7 @@ basics(_Config) ->
5657
ClientId = <<"id1">>,
5758
{S1, ok, _} = mqtt_machine:apply(meta(1), {register, ClientId, self()}, S0),
5859
?assertMatch(#machine_state{client_ids = Ids} when map_size(Ids) == 1, S1),
60+
?assertMatch(#machine_state{pids = Pids} when map_size(Pids) == 1, S1),
5961
{S2, ok, _} = mqtt_machine:apply(meta(2), {register, ClientId, self()}, S1),
6062
?assertMatch(#machine_state{client_ids = Ids} when map_size(Ids) == 1, S2),
6163
{S3, ok, _} = mqtt_machine:apply(meta(3), {down, self(), noproc}, S2),
@@ -65,6 +67,28 @@ basics(_Config) ->
6567

6668
ok.
6769

70+
many_downs(_Config) ->
71+
S0 = mqtt_machine:init(#{}),
72+
Clients = [{list_to_binary(integer_to_list(I)), spawn(fun() -> ok end)}
73+
|| I <- lists:seq(1, 10000)],
74+
S1 = lists:foldl(
75+
fun ({ClientId, Pid}, Acc0) ->
76+
{Acc, ok, _} = mqtt_machine:apply(meta(1), {register, ClientId, Pid}, Acc0),
77+
Acc
78+
end, S0, Clients),
79+
_ = lists:foldl(
80+
fun ({_ClientId, Pid}, Acc0) ->
81+
{Acc, ok, _} = mqtt_machine:apply(meta(1), {down, Pid, noproc}, Acc0),
82+
Acc
83+
end, S1, Clients),
84+
_ = lists:foldl(
85+
fun ({ClientId, Pid}, Acc0) ->
86+
{Acc, ok, _} = mqtt_machine:apply(meta(1), {unregister, ClientId,
87+
Pid}, Acc0),
88+
Acc
89+
end, S0, Clients),
90+
91+
ok.
6892
%% Utility
6993

7094
meta(Idx) ->

0 commit comments

Comments
 (0)