diff --git a/src/consensus/raft/raft.h b/src/consensus/raft/raft.h
index 53865869b91e..2b89c0af57d3 100644
--- a/src/consensus/raft/raft.h
+++ b/src/consensus/raft/raft.h
@@ -768,8 +768,8 @@ namespace raft
       RequestVote rv = {raft_request_vote,
                         local_id,
                         current_term,
-                        last_idx,
-                        get_term_internal(last_idx)};
+                        commit_idx,
+                        get_term_internal(commit_idx)};
 
       channels->send_authenticated(ccf::NodeMsgType::consensus_msg, to, rv);
     }
@@ -826,10 +826,11 @@ namespace raft
       }
 
       // If the candidate's log is at least as up-to-date as ours, vote yes
-      auto last_term = get_term_internal(last_idx);
+      auto last_commit_term = get_term_internal(commit_idx);
 
-      auto answer = (r.last_log_term > last_term) ||
-        ((r.last_log_term == last_term) && (r.last_log_idx >= last_idx));
+      auto answer = (r.last_commit_term > last_commit_term) ||
+        ((r.last_commit_term == last_commit_term) &&
+         (r.last_commit_idx >= commit_idx));
 
       if (answer)
       {
@@ -1116,6 +1117,8 @@ namespace raft
     void rollback(Index idx)
     {
       store->rollback(idx);
+      ledger->truncate(idx);
+      last_idx = idx;
       LOG_DEBUG_FMT("Rolled back at {}", idx);
 
       while (!committable_indices.empty() && (committable_indices.back() > idx))
diff --git a/src/consensus/raft/rafttypes.h b/src/consensus/raft/rafttypes.h
index ec163cb8b090..e4ea6e002de3 100644
--- a/src/consensus/raft/rafttypes.h
+++ b/src/consensus/raft/rafttypes.h
@@ -109,8 +109,12 @@ namespace raft
   struct RequestVote : RaftHeader
   {
     Term term;
-    Index last_log_idx;
-    Term last_log_term;
+    // last_log_idx in vanilla raft but last_commit_idx here to preserve
+    // verifiability
+    Index last_commit_idx;
+    // last_log_term in vanilla raft but last_commit_term here to preserve
+    // verifiability
+    Term last_commit_term;
   };
 
   struct RequestVoteResponse : RaftHeader
diff --git a/src/consensus/raft/test/driver.h b/src/consensus/raft/test/driver.h
index 4dad3a502076..20137e099a4a 100644
--- a/src/consensus/raft/test/driver.h
+++ b/src/consensus/raft/test/driver.h
@@ -76,8 +76,8 @@ class RaftDriver
     raft::NodeId node_id, raft::NodeId tgt_node_id, raft::RequestVote rv)
   {
     std::ostringstream s;
-    s << "request_vote t: " << rv.term << ", lli: " << rv.last_log_idx
-      << ", llt: " << rv.last_log_term;
+    s << "request_vote t: " << rv.term << ", lli: " << rv.last_commit_idx
+      << ", llt: " << rv.last_commit_term;
     log(node_id, tgt_node_id, s.str());
   }
 
diff --git a/src/consensus/raft/test/main.cpp b/src/consensus/raft/test/main.cpp
index 587f82e8d386..3adb9cb1a29a 100644
--- a/src/consensus/raft/test/main.cpp
+++ b/src/consensus/raft/test/main.cpp
@@ -139,8 +139,8 @@ TEST_CASE(
   REQUIRE(get<0>(rv) == node_id1);
   auto rvc = get<1>(rv);
   REQUIRE(rvc.term == 1);
-  REQUIRE(rvc.last_log_idx == 0);
-  REQUIRE(rvc.last_log_term == 0);
+  REQUIRE(rvc.last_commit_idx == 0);
+  REQUIRE(rvc.last_commit_term == 0);
 
   r1.recv_message(reinterpret_cast(&rvc), sizeof(rvc));
 
@@ -151,8 +151,8 @@ TEST_CASE(
   REQUIRE(get<0>(rv) == node_id2);
   rvc = get<1>(rv);
   REQUIRE(rvc.term == 1);
-  REQUIRE(rvc.last_log_idx == 0);
-  REQUIRE(rvc.last_log_term == 0);
+  REQUIRE(rvc.last_commit_idx == 0);
+  REQUIRE(rvc.last_commit_term == 0);
 
   r2.recv_message(reinterpret_cast(&rvc), sizeof(rvc));
 
diff --git a/tests/infra/ccf.py b/tests/infra/ccf.py
index 16f090872a4c..33f0bc21a3dd 100644
--- a/tests/infra/ccf.py
+++ b/tests/infra/ccf.py
@@ -826,6 +826,14 @@ def start(self, lib_name, enclave_type, workspace, label, members_certs, **kwarg
         )
         self.network_state = NodeNetworkState.joined
 
+    def suspend(self):
+        """Suspend this node by delegating to its remote."""
+        self.remote.suspend()
+
+    def resume(self):
+        """Resume this node by delegating to its remote."""
+        self.remote.resume()
+
     def join(
         self, lib_name, enclave_type, workspace, label, target_rpc_address, **kwargs
     ):
diff --git a/tests/infra/remote.py b/tests/infra/remote.py
index e27be84e9a2a..c9353e0e5f8f 100644
--- a/tests/infra/remote.py
+++ b/tests/infra/remote.py
@@ -135,12 +135,18 @@ def __init__(
         self.files += data_files
         self.cmd = cmd
         self.client = paramiko.SSHClient()
+        # this client (proc_client) is used to execute commands on the remote host since the main client uses pty
+        self.proc_client = paramiko.SSHClient()
         self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+        self.proc_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
         self.root = os.path.join(workspace, label + "_" + name)
         self.name = name
         self.env = env or {}
         self.out = os.path.join(self.root, "out")
         self.err = os.path.join(self.root, "err")
+        self.suspension_proc = None
+        # PID string of the remote process, filled in by start()
+        self.pid = None
 
     def _rc(self, cmd):
         LOG.info("[{}] {}".format(self.hostname, cmd))
@@ -150,6 +156,7 @@ def _rc(self, cmd):
     def _connect(self):
         LOG.debug("[{}] connect".format(self.hostname))
         self.client.connect(self.hostname)
+        self.proc_client.connect(self.hostname)
 
     def _setup_files(self):
         assert self._rc("rm -rf {}".format(self.root)) == 0
@@ -231,6 +238,34 @@ def start(self):
         cmd = self._cmd()
         LOG.info("[{}] {}".format(self.hostname, cmd))
         stdin, stdout, stderr = self.client.exec_command(cmd, get_pty=True)
+        # Look up the PID over the secondary connection so that the process can
+        # be signalled later; pgrep scans the process table itself, so nothing
+        # needs to be piped into it
+        _, stdout_, _ = self.proc_client.exec_command(f'pgrep -f "{cmd}"')
+        # readline() keeps the trailing newline; an empty string means no match
+        self.pid = stdout_.readline().strip()
+
+    def suspend(self):
+        """
+        Suspend the node by sending SIGSTOP to the PID recorded by start().
+        """
+        if not self.pid:
+            raise RuntimeError(f"Cannot suspend node {self.name}: no PID recorded")
+        _, stdout, _ = self.proc_client.exec_command(f"kill -STOP {self.pid}")
+        if stdout.channel.recv_exit_status() != 0:
+            raise RuntimeError(f"Node {self.name} could not be suspended")
+        LOG.info(f"Node {self.name} suspended...")
+
+    def resume(self):
+        """
+        Resume a suspended node by sending SIGCONT to the PID recorded by start().
+        """
+        if not self.pid:
+            raise RuntimeError(f"Cannot resume node {self.name}: no PID recorded")
+        _, stdout, _ = self.proc_client.exec_command(f"kill -CONT {self.pid}")
+        if stdout.channel.recv_exit_status() != 0:
+            raise RuntimeError(f"Could not resume node {self.name} from suspension!")
+        LOG.info(f"Node {self.name} resuming from suspension...")
 
     def stop(self):
         """
@@ -243,6 +278,7 @@ def stop(self):
                 "{}_err_{}".format(self.hostname, self.name),
             )
         self.client.close()
+        self.proc_client.close()
 
     def setup(self):
         """
@@ -392,6 +428,20 @@ def start(self, timeout=10):
             env=self.env,
         )
 
+    def suspend(self):
+        """
+        Suspend the process held in self.proc by sending it SIGSTOP.
+        """
+        self.proc.send_signal(signal.SIGSTOP)
+        LOG.info(f"Node {self.name} suspended...")
+
+    def resume(self):
+        """
+        Resume the process held in self.proc by sending it SIGCONT.
+        """
+        self.proc.send_signal(signal.SIGCONT)
+        LOG.info(f"Node {self.name} resuming from suspension...")
+
     def stop(self):
         """
         Disconnect the client, and therefore shut down the command as well.
@@ -617,6 +667,18 @@ def setup(self):
     def start(self):
         self.remote.start()
 
+    def suspend(self):
+        """
+        Suspend the node by delegating to the underlying remote.
+        """
+        self.remote.suspend()
+
+    def resume(self):
+        """
+        Resume the node by delegating to the underlying remote.
+        """
+        self.remote.resume()
+
     def get_startup_files(self):
         self.remote.get(self.pem)
         if self.start_type in {StartType.new, StartType.recover}:
diff --git a/tests/node_suspension.py b/tests/node_suspension.py
new file mode 100644
index 000000000000..84c722ba6bdf
--- /dev/null
+++ b/tests/node_suspension.py
@@ -0,0 +1,142 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the Apache 2.0 License.
+import os
+import getpass
+import time
+import logging
+import multiprocessing
+import shutil
+from random import seed
+import infra.ccf
+import infra.proc
+import infra.jsonrpc
+import infra.notification
+import infra.net
+import e2e_args
+from threading import Timer
+import random
+import contextlib
+
+from loguru import logger as LOG
+
+# 256 is the number of most recent messages that PBFT keeps in memory before needing to replay the ledger
+TOTAL_REQUESTS = 256
+
+
+def timeout(node, suspend, election_timeout):
+    """
+    Timer callback that either suspends or resumes a node.
+
+    When suspend is true the node is suspended and a second timer is started
+    to resume it after a random delay of between two and three times
+    election_timeout (in seconds); otherwise the node is resumed.
+    """
+    if suspend:
+        # We want to suspend the nodes' process so we need to initiate a new timer to wake it up eventually
+        node.suspend()
+        next_timeout = random.uniform(2 * election_timeout, 3 * election_timeout)
+        LOG.info(f"New timer set for node {node.node_id} is {next_timeout} seconds")
+        t = Timer(next_timeout, timeout, args=[node, False, 0])
+        t.start()
+    else:
+        node.resume()
+
+
+def run(args):
+    """
+    Start a three-node network, suspend and resume every node on random timers
+    while writing messages to the nodes in round-robin order, then check that
+    the final message is committed, that a newly added node can read it, and
+    that more than one term was observed.
+    """
+    hosts = ["localhost", "localhost", "localhost"]
+
+    with infra.ccf.network(
+        hosts, args.build_dir, args.debug_nodes, args.perf_nodes, pdb=args.pdb
+    ) as network:
+        first_node, (backups) = network.start_and_join(args)
+
+        term_info = {}
+        long_msg = "X" * (2 ** 14)
+
+        # first timer determines after how many seconds each node will be suspended
+        timeouts = []
+        t = random.uniform(1, 10)
+        LOG.info(f"Initial timer for node {first_node.node_id} is {t} seconds...")
+        timeouts.append((t, first_node))
+        for backup in backups:
+            t = random.uniform(1, 10)
+            LOG.info(f"Initial timer for node {backup.node_id} is {t} seconds...")
+            timeouts.append((t, backup))
+
+        for t, node in timeouts:
+            tm = Timer(t, timeout, args=[node, True, args.election_timeout / 1000],)
+            tm.start()
+
+        with first_node.node_client() as mc:
+            check_commit = infra.ccf.Checker(mc)
+            check = infra.ccf.Checker()
+
+            clients = []
+            with contextlib.ExitStack() as es:
+                LOG.info("Write messages to nodes using round robin")
+                clients.append(es.enter_context(first_node.user_client(format="json")))
+                for backup in backups:
+                    clients.append(es.enter_context(backup.user_client(format="json")))
+                node_id = 0
+                for msg_id in range(1, TOTAL_REQUESTS):
+                    node_id += 1
+                    c = clients[node_id % len(clients)]
+                    try:
+                        c.rpc("LOG_record", {"id": msg_id, "msg": long_msg})
+                    except Exception:
+                        LOG.info("Trying to access a suspended node")
+                    try:
+                        cur_primary, cur_term = network.find_primary()
+                        term_info[cur_term] = cur_primary.node_id
+                    except Exception:
+                        LOG.info("Trying to access a suspended node")
+
+                # wait for the last request to commit
+                final_msg = "Hello world!"
+                check_commit(
+                    c.rpc("LOG_record", {"id": 1000, "msg": final_msg}), result=True,
+                )
+                check(
+                    c.rpc("LOG_get", {"id": 1000}), result={"msg": final_msg},
+                )
+
+                # check that a new node can catch up after all the requests
+                new_node = network.create_and_trust_node(
+                    lib_name=args.package, host="localhost", args=args,
+                )
+                assert new_node
+
+                # give new_node a second to catch up
+                time.sleep(1)
+
+                with new_node.user_client(format="json") as c:
+                    check(
+                        c.rpc("LOG_get", {"id": 1000}), result={"msg": final_msg},
+                    )
+
+                # assert that view changes actually did occur
+                assert len(term_info) > 1
+
+                LOG.success("----------- terms and primaries recorded -----------")
+                for term, primary in term_info.items():
+                    LOG.success(f"term {term} - primary {primary}")
+
+
+if __name__ == "__main__":
+
+    args = e2e_args.cli_args()
+    args.package = args.app_script and "libluagenericenc" or "libloggingenc"
+
+    notify_server_host = "localhost"
+    args.notify_server = (
+        notify_server_host
+        + ":"
+        + str(infra.net.probably_free_local_port(notify_server_host))
+    )
+    run(args)