Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion deps/rabbitmq_mqtt/include/mqtt_machine.hrl
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,8 @@
%%

-record(machine_state, {client_ids = #{},
pids = #{}}).
pids = #{},
%% add acouple of fields for future extensibility
reserved_1,
reserved_2}).

8 changes: 8 additions & 0 deletions deps/rabbitmq_mqtt/include/mqtt_machine_v0.hrl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
%% This Source Code Form is subject to the terms of the Mozilla Public
%% License, v. 2.0. If a copy of the MPL was not distributed with this
%% file, You can obtain one at https://mozilla.org/MPL/2.0/.
%%
%% Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
%%

-record(machine_state, {client_ids = #{}}).
43 changes: 35 additions & 8 deletions deps/rabbitmq_mqtt/src/mqtt_machine.erl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

-include("mqtt_machine.hrl").

-export([init/1,
-export([version/0,
which_module/1,
init/1,
apply/3,
state_enter/2,
notify_connection/2]).
Expand All @@ -24,6 +26,10 @@
-type command() :: {register, client_id(), pid()} |
{unregister, client_id(), pid()} |
list.
version() -> 1.

which_module(1) -> ?MODULE;
which_module(0) -> mqtt_machine_v0.

-spec init(config()) -> state().
init(_Conf) ->
Expand All @@ -41,12 +47,25 @@ apply(_Meta, {register, ClientId, Pid},
{monitor, process, Pid},
{mod_call, ?MODULE, notify_connection,
[OldPid, duplicate_id]}],
{Effects0, maps:remove(ClientId, Ids), Pids0};
_ ->
Pids1 = maps:update_with(Pid, fun(CIds) -> [ClientId | CIds] end,
Pids2 = case maps:take(OldPid, Pids0) of
error ->
Pids0;
{[ClientId], Pids1} ->
Pids1;
{ClientIds, Pids1} ->
Pids1#{ClientId => lists:delete(ClientId, ClientIds)}
end,
Pids3 = maps:update_with(Pid, fun(CIds) -> [ClientId | CIds] end,
[ClientId], Pids2),
{Effects0, maps:remove(ClientId, Ids), Pids3};

{ok, Pid} ->
{[], Ids, Pids0};
error ->
Pids1 = maps:update_with(Pid, fun(CIds) -> [ClientId | CIds] end,
[ClientId], Pids0),
Effects0 = [{monitor, process, Pid}],
{Effects0, Ids, Pids1}
Effects0 = [{monitor, process, Pid}],
{Effects0, Ids, Pids1}
end,
State = State0#machine_state{client_ids = maps:put(ClientId, Pid, Ids1),
pids = Pids},
Expand Down Expand Up @@ -139,9 +158,17 @@ apply(Meta, {leave, Node}, #machine_state{client_ids = Ids,
State = State0#machine_state{client_ids = Keep,
pids = maps:without(maps:keys(Remove), Pids0)},
{State, ok, Effects ++ snapshot_effects(Meta, State)};

apply(_Meta, {machine_version, 0, 1}, {machine_state, Ids}) ->
Pids = maps:fold(
fun(Id, Pid, Acc) ->
maps:update_with(Pid,
fun(CIds) -> [Id | CIds] end,
[Id], Acc)
end, #{}, Ids),
{#machine_state{client_ids = Ids,
pids = Pids}, ok, []};
apply(_Meta, Unknown, State) ->
error_logger:error_msg("MQTT Raft state machine received unknown command ~p~n", [Unknown]),
error_logger:error_msg("MQTT Raft state machine v1 received unknown command ~p~n", [Unknown]),
{State, {error, {unknown_command, Unknown}}, []}.

state_enter(leader, State) ->
Expand Down
134 changes: 134 additions & 0 deletions deps/rabbitmq_mqtt/src/mqtt_machine_v0.erl
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
%% This Source Code Form is subject to the terms of the Mozilla Public
%% License, v. 2.0. If a copy of the MPL was not distributed with this
%% file, You can obtain one at https://mozilla.org/MPL/2.0/.
%%
%% Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
%%
-module(mqtt_machine_v0).
-behaviour(ra_machine).

-include("mqtt_machine_v0.hrl").

-export([init/1,
apply/3,
state_enter/2,
notify_connection/2]).

-type state() :: #machine_state{}.

-type config() :: map().

-type reply() :: {ok, term()} | {error, term()}.
-type client_id() :: term().

-type command() :: {register, client_id(), pid()} |
{unregister, client_id(), pid()} |
list.

-spec init(config()) -> state().
init(_Conf) ->
#machine_state{}.

-spec apply(map(), command(), state()) ->
{state(), reply(), ra_machine:effects()}.
apply(_Meta, {register, ClientId, Pid}, #machine_state{client_ids = Ids} = State0) ->
{Effects, Ids1} =
case maps:find(ClientId, Ids) of
{ok, OldPid} when Pid =/= OldPid ->
Effects0 = [{demonitor, process, OldPid},
{monitor, process, Pid},
{mod_call, ?MODULE, notify_connection, [OldPid, duplicate_id]}],
{Effects0, maps:remove(ClientId, Ids)};
_ ->
Effects0 = [{monitor, process, Pid}],
{Effects0, Ids}
end,
State = State0#machine_state{client_ids = maps:put(ClientId, Pid, Ids1)},
{State, ok, Effects};

apply(Meta, {unregister, ClientId, Pid}, #machine_state{client_ids = Ids} = State0) ->
State = case maps:find(ClientId, Ids) of
{ok, Pid} -> State0#machine_state{client_ids = maps:remove(ClientId, Ids)};
%% don't delete client id that might belong to a newer connection
%% that kicked the one with Pid out
{ok, _AnotherPid} -> State0;
error -> State0
end,
Effects0 = [{demonitor, process, Pid}],
%% snapshot only when the map has changed
Effects = case State of
State0 -> Effects0;
_ -> Effects0 ++ snapshot_effects(Meta, State)
end,
{State, ok, Effects};

apply(_Meta, {down, DownPid, noconnection}, State) ->
%% Monitor the node the pid is on (see {nodeup, Node} below)
%% so that we can detect when the node is re-connected and discover the
%% actual fate of the connection processes on it
Effect = {monitor, node, node(DownPid)},
{State, ok, Effect};

apply(Meta, {down, DownPid, _}, #machine_state{client_ids = Ids} = State0) ->
Ids1 = maps:filter(fun (_ClientId, Pid) when Pid =:= DownPid ->
false;
(_, _) ->
true
end, Ids),
State = State0#machine_state{client_ids = Ids1},
Delta = maps:keys(Ids) -- maps:keys(Ids1),
Effects = lists:map(fun(Id) ->
[{mod_call, rabbit_log, debug,
["MQTT connection with client id '~s' failed", [Id]]}] end, Delta),
{State, ok, Effects ++ snapshot_effects(Meta, State)};

apply(_Meta, {nodeup, Node}, State) ->
%% Work out if any pids that were disconnected are still
%% alive.
%% Re-request the monitor for the pids on the now-back node.
Effects = [{monitor, process, Pid} || Pid <- all_pids(State), node(Pid) == Node],
{State, ok, Effects};
apply(_Meta, {nodedown, _Node}, State) ->
{State, ok};

apply(Meta, {leave, Node}, #machine_state{client_ids = Ids} = State0) ->
Ids1 = maps:filter(fun (_ClientId, Pid) -> node(Pid) =/= Node end, Ids),
Delta = maps:keys(Ids) -- maps:keys(Ids1),

Effects = lists:foldl(fun (ClientId, Acc) ->
Pid = maps:get(ClientId, Ids),
[
{demonitor, process, Pid},
{mod_call, ?MODULE, notify_connection, [Pid, decommission_node]},
{mod_call, rabbit_log, debug,
["MQTT will remove client ID '~s' from known "
"as its node has been decommissioned", [ClientId]]}
] ++ Acc
end, [], Delta),

State = State0#machine_state{client_ids = Ids1},
{State, ok, Effects ++ snapshot_effects(Meta, State)};

apply(_Meta, Unknown, State) ->
error_logger:error_msg("MQTT Raft state machine received unknown command ~p~n", [Unknown]),
{State, {error, {unknown_command, Unknown}}, []}.

state_enter(leader, State) ->
%% re-request monitors for all known pids, this would clean up
%% records for all connections are no longer around, e.g. right after node restart
[{monitor, process, Pid} || Pid <- all_pids(State)];
state_enter(_, _) ->
[].

%% ==========================

%% Avoids blocking the Raft leader.
notify_connection(Pid, Reason) ->
spawn(fun() -> gen_server2:cast(Pid, Reason) end).

-spec snapshot_effects(map(), state()) -> ra_machine:effects().
snapshot_effects(#{index := RaftIdx}, State) ->
[{release_cursor, RaftIdx, State}].

all_pids(#machine_state{client_ids = Ids}) ->
maps:values(Ids).
28 changes: 24 additions & 4 deletions deps/rabbitmq_mqtt/test/mqtt_machine_SUITE.erl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ all() ->
all_tests() ->
[
basics,
machine_upgrade,
many_downs
].

Expand Down Expand Up @@ -55,18 +56,37 @@ end_per_testcase(_TestCase, _Config) ->
basics(_Config) ->
S0 = mqtt_machine:init(#{}),
ClientId = <<"id1">>,
OthPid = spawn(fun () -> ok end),
{S1, ok, _} = mqtt_machine:apply(meta(1), {register, ClientId, self()}, S0),
?assertMatch(#machine_state{client_ids = Ids} when map_size(Ids) == 1, S1),
?assertMatch(#machine_state{pids = Pids} when map_size(Pids) == 1, S1),
{S2, ok, _} = mqtt_machine:apply(meta(2), {register, ClientId, self()}, S1),
?assertMatch(#machine_state{client_ids = Ids} when map_size(Ids) == 1, S2),
{S3, ok, _} = mqtt_machine:apply(meta(3), {down, self(), noproc}, S2),
{S2, ok, _} = mqtt_machine:apply(meta(2), {register, ClientId, OthPid}, S1),
?assertMatch(#machine_state{client_ids = #{ClientId := OthPid} = Ids}
when map_size(Ids) == 1, S2),
{S3, ok, _} = mqtt_machine:apply(meta(3), {down, OthPid, noproc}, S2),
?assertMatch(#machine_state{client_ids = Ids} when map_size(Ids) == 0, S3),
{S4, ok, _} = mqtt_machine:apply(meta(3), {unregister, ClientId, self()}, S2),
{S4, ok, _} = mqtt_machine:apply(meta(3), {unregister, ClientId, OthPid}, S2),
?assertMatch(#machine_state{client_ids = Ids} when map_size(Ids) == 0, S4),

ok.

machine_upgrade(_Config) ->
S0 = mqtt_machine_v0:init(#{}),
ClientId = <<"id1">>,
Self = self(),
{S1, ok, _} = mqtt_machine_v0:apply(meta(1), {register, ClientId, self()}, S0),
?assertMatch({machine_state, Ids} when map_size(Ids) == 1, S1),
{S2, ok, _} = mqtt_machine:apply(meta(2), {machine_version, 0, 1}, S1),
?assertMatch(#machine_state{client_ids = #{ClientId := Self},
pids = #{Self := [ClientId]} = Pids}
when map_size(Pids) == 1, S2),
{S3, ok, _} = mqtt_machine:apply(meta(3), {down, self(), noproc}, S2),
?assertMatch(#machine_state{client_ids = Ids,
pids = Pids}
when map_size(Ids) == 0 andalso map_size(Pids) == 0, S3),

ok.

many_downs(_Config) ->
S0 = mqtt_machine:init(#{}),
Clients = [{list_to_binary(integer_to_list(I)), spawn(fun() -> ok end)}
Expand Down