Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding node suspension test and small raft fixes #503

Merged
merged 10 commits into from
Nov 6, 2019
13 changes: 8 additions & 5 deletions src/consensus/raft/raft.h
Original file line number Diff line number Diff line change
Expand Up @@ -768,8 +768,8 @@ namespace raft
RequestVote rv = {raft_request_vote,
local_id,
current_term,
last_idx,
get_term_internal(last_idx)};
commit_idx,
get_term_internal(commit_idx)};

channels->send_authenticated(ccf::NodeMsgType::consensus_msg, to, rv);
}
Expand Down Expand Up @@ -826,10 +826,11 @@ namespace raft
}

// If the candidate's log is at least as up-to-date as ours, vote yes
auto last_term = get_term_internal(last_idx);
auto last_commit_term = get_term_internal(commit_idx);

auto answer = (r.last_log_term > last_term) ||
((r.last_log_term == last_term) && (r.last_log_idx >= last_idx));
auto answer = (r.last_commit_term > last_commit_term) ||
((r.last_commit_term == last_commit_term) &&
(r.last_commit_idx >= commit_idx));

if (answer)
{
Expand Down Expand Up @@ -1116,6 +1117,8 @@ namespace raft
void rollback(Index idx)
{
store->rollback(idx);
ledger->truncate(idx);
last_idx = idx;
LOG_DEBUG_FMT("Rolled back at {}", idx);

while (!committable_indices.empty() && (committable_indices.back() > idx))
Expand Down
8 changes: 6 additions & 2 deletions src/consensus/raft/rafttypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,12 @@ namespace raft
struct RequestVote : RaftHeader
{
Term term;
Index last_log_idx;
Term last_log_term;
olgavrou marked this conversation as resolved.
Show resolved Hide resolved
// last_log_idx in vanilla raft but last_commit_idx here to preserve
// verifiability
Index last_commit_idx;
// last_log_term in vanilla raft but last_commit_term here to preserve
// verifiability
Term last_commit_term;
};

struct RequestVoteResponse : RaftHeader
Expand Down
4 changes: 2 additions & 2 deletions src/consensus/raft/test/driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ class RaftDriver
raft::NodeId node_id, raft::NodeId tgt_node_id, raft::RequestVote rv)
{
std::ostringstream s;
s << "request_vote t: " << rv.term << ", lli: " << rv.last_log_idx
<< ", llt: " << rv.last_log_term;
s << "request_vote t: " << rv.term << ", lli: " << rv.last_commit_idx
<< ", llt: " << rv.last_commit_term;
log(node_id, tgt_node_id, s.str());
}

Expand Down
8 changes: 4 additions & 4 deletions src/consensus/raft/test/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ TEST_CASE(
REQUIRE(get<0>(rv) == node_id1);
auto rvc = get<1>(rv);
REQUIRE(rvc.term == 1);
REQUIRE(rvc.last_log_idx == 0);
REQUIRE(rvc.last_log_term == 0);
REQUIRE(rvc.last_commit_idx == 0);
REQUIRE(rvc.last_commit_term == 0);

r1.recv_message(reinterpret_cast<uint8_t*>(&rvc), sizeof(rvc));

Expand All @@ -151,8 +151,8 @@ TEST_CASE(
REQUIRE(get<0>(rv) == node_id2);
rvc = get<1>(rv);
REQUIRE(rvc.term == 1);
REQUIRE(rvc.last_log_idx == 0);
REQUIRE(rvc.last_log_term == 0);
REQUIRE(rvc.last_commit_idx == 0);
REQUIRE(rvc.last_commit_term == 0);

r2.recv_message(reinterpret_cast<uint8_t*>(&rvc), sizeof(rvc));

Expand Down
6 changes: 6 additions & 0 deletions tests/infra/ccf.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,12 @@ def start(self, lib_name, enclave_type, workspace, label, members_certs, **kwarg
)
self.network_state = NodeNetworkState.joined

def suspend(self):
self.remote.suspend()

def resume(self):
self.remote.resume()

def join(
self, lib_name, enclave_type, workspace, label, target_rpc_address, **kwargs
):
Expand Down
35 changes: 35 additions & 0 deletions tests/infra/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,16 @@ def __init__(
self.files += data_files
self.cmd = cmd
self.client = paramiko.SSHClient()
# this client (proc_client) is used to execute commands on the remote host since the main client uses pty
self.proc_client = paramiko.SSHClient()
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.proc_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.root = os.path.join(workspace, label + "_" + name)
self.name = name
self.env = env or {}
self.out = os.path.join(self.root, "out")
self.err = os.path.join(self.root, "err")
self.suspension_proc = None

def _rc(self, cmd):
LOG.info("[{}] {}".format(self.hostname, cmd))
Expand All @@ -150,6 +154,7 @@ def _rc(self, cmd):
def _connect(self):
LOG.debug("[{}] connect".format(self.hostname))
self.client.connect(self.hostname)
self.proc_client.connect(self.hostname)

def _setup_files(self):
assert self._rc("rm -rf {}".format(self.root)) == 0
Expand Down Expand Up @@ -231,6 +236,21 @@ def start(self):
cmd = self._cmd()
LOG.info("[{}] {}".format(self.hostname, cmd))
stdin, stdout, stderr = self.client.exec_command(cmd, get_pty=True)
_, stdout_, _ = self.proc_client.exec_command(f'ps -ef | pgrep -f "{cmd}"')
olgavrou marked this conversation as resolved.
Show resolved Hide resolved
self.pid = stdout_.readline()

def suspend(self):
_, stdout, _ = self.proc_client.exec_command(f"kill -STOP {self.pid}")
if stdout.channel.recv_exit_status() == 0:
LOG.info(f"Node {self.name} suspended...")
else:
raise RuntimeError(f"Node {self.name} could not be suspended")

def resume(self):
_, stdout, _ = self.proc_client.exec_command(f"kill -CONT {self.pid}")
if stdout.channel.recv_exit_status() != 0:
raise RuntimeError(f"Could not resume node {self.name} from suspension!")
LOG.info(f"Node {self.name} resuming from suspension...")
olgavrou marked this conversation as resolved.
Show resolved Hide resolved

def stop(self):
"""
Expand All @@ -243,6 +263,7 @@ def stop(self):
"{}_err_{}".format(self.hostname, self.name),
)
self.client.close()
self.proc_client.close()

def setup(self):
"""
Expand Down Expand Up @@ -392,6 +413,14 @@ def start(self, timeout=10):
env=self.env,
)

def suspend(self):
self.proc.send_signal(signal.SIGSTOP)
LOG.info(f"Node {self.name} suspended...")

def resume(self):
self.proc.send_signal(signal.SIGCONT)
LOG.info(f"Node {self.name} resuming from suspension...")
olgavrou marked this conversation as resolved.
Show resolved Hide resolved

def stop(self):
"""
Disconnect the client, and therefore shut down the command as well.
Expand Down Expand Up @@ -617,6 +646,12 @@ def setup(self):
def start(self):
self.remote.start()

def suspend(self):
return self.remote.suspend()

def resume(self):
self.remote.resume()

def get_startup_files(self):
self.remote.get(self.pem)
if self.start_type in {StartType.new, StartType.recover}:
Expand Down
130 changes: 130 additions & 0 deletions tests/node_suspension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
olgavrou marked this conversation as resolved.
Show resolved Hide resolved
# Licensed under the Apache 2.0 License.
import os
import getpass
import time
import logging
import multiprocessing
import shutil
from random import seed
import infra.ccf
import infra.proc
import infra.jsonrpc
import infra.notification
import infra.net
import e2e_args
from threading import Timer
import random
import contextlib

from loguru import logger as LOG

# 256 is the number of most recent messages that PBFT keeps in memory before needing to replay the ledger
TOTAL_REQUESTS = 256


def timeout(node, suspend, election_timeout):
if suspend:
# We want to suspend the nodes' process so we need to initiate a new timer to wake it up eventually
node.suspend()
next_timeout = random.uniform(2 * election_timeout, 3 * election_timeout)
LOG.info(f"New timer set for node {node.node_id} is {next_timeout} seconds")
t = Timer(next_timeout, timeout, args=[node, False, 0])
t.start()
else:
node.resume()


def run(args):
hosts = ["localhost", "localhost", "localhost"]

with infra.ccf.network(
hosts, args.build_dir, args.debug_nodes, args.perf_nodes, pdb=args.pdb
) as network:
first_node, (backups) = network.start_and_join(args)

term_info = {}
long_msg = "X" * (2 ** 14)

# first timer determines after how many seconds each node will be suspended
timeouts = []
t = random.uniform(1, 10)
LOG.info(f"Initial timer for node {first_node.node_id} is {t} seconds...")
timeouts.append((t, first_node))
for backup in backups:
t = random.uniform(1, 10)
LOG.info(f"Initial timer for node {backup.node_id} is {t} seconds...")
timeouts.append((t, backup))

for t, node in timeouts:
tm = Timer(t, timeout, args=[node, True, args.election_timeout / 1000],)
tm.start()

with first_node.node_client() as mc:
check_commit = infra.ccf.Checker(mc)
check = infra.ccf.Checker()

clients = []
with contextlib.ExitStack() as es:
LOG.info("Write messages to nodes using round robin")
clients.append(es.enter_context(first_node.user_client(format="json")))
for backup in backups:
clients.append(es.enter_context(backup.user_client(format="json")))
node_id = 0
for id in range(1, TOTAL_REQUESTS):
node_id += 1
c = clients[node_id % len(clients)]
try:
resp = c.rpc("LOG_record", {"id": id, "msg": long_msg})
except Exception:
LOG.info("Trying to access a suspended node")
try:
cur_primary, cur_term = network.find_primary()
term_info[cur_term] = cur_primary.node_id
except Exception:
LOG.info("Trying to access a suspended node")
id += 1

# wait for the last request to commit
final_msg = "Hello world!"
check_commit(
c.rpc("LOG_record", {"id": 1000, "msg": final_msg}), result=True,
)
check(
c.rpc("LOG_get", {"id": 1000}), result={"msg": final_msg},
)

# check that a new node can catch up after all the requests
new_node = network.create_and_trust_node(
lib_name=args.package, host="localhost", args=args,
)
assert new_node

# give new_node a second to catch up
time.sleep(1)

with new_node.user_client(format="json") as c:
check(
c.rpc("LOG_get", {"id": 1000}), result={"msg": final_msg},
)

# assert that view changes actually did occur
assert len(term_info) > 1

LOG.success("----------- terms and primaries recorded -----------")
for term, primary in term_info.items():
LOG.success(f"term {term} - primary {primary}")


if __name__ == "__main__":

args = e2e_args.cli_args()
args.package = args.app_script and "libluagenericenc" or "libloggingenc"

notify_server_host = "localhost"
args.notify_server = (
notify_server_host
+ ":"
+ str(infra.net.probably_free_local_port(notify_server_host))
)
run(args)