Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(cli): Use SIGUSR2 to run sysctl commands
Browse files Browse the repository at this point in the history
msbrogli committed Feb 10, 2024
1 parent f4d6a28 commit 565377f
Showing 7 changed files with 182 additions and 66 deletions.
112 changes: 103 additions & 9 deletions hathor/cli/run_node.py
Original file line number Diff line number Diff line change
@@ -14,8 +14,10 @@

import os
import sys
import tempfile
from argparse import SUPPRESS, ArgumentParser, Namespace
from typing import TYPE_CHECKING, Any, Callable, Optional
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, TextIO

from pydantic import ValidationError
from structlog import get_logger
@@ -25,6 +27,25 @@

if TYPE_CHECKING:
from hathor.cli.run_node_args import RunNodeArgs
from hathor.sysctl.runner import SysctlRunner


@contextmanager
def temp_fifo(filename: str, tempdir: str | None) -> Iterator[TextIO]:
"""Context Manager for creating named pipes."""
mkfifo = getattr(os, 'mkfifo', None)
if mkfifo is None:
raise AttributeError('mkfifo is not available')

mkfifo(filename, mode=0o666)
fp = open(filename, 'r')
try:
yield fp
finally:
fp.close()
os.unlink(filename)
if tempdir is not None:
os.rmdir(tempdir)


class RunNode:
@@ -236,12 +257,79 @@ def register_signal_handlers(self) -> None:
if sigusr1 is not None:
# USR1 is available in this OS.
signal.signal(sigusr1, self.signal_usr1_handler)
sigusr2 = getattr(signal, 'SIGUSR2', None)
if sigusr2 is not None:
# USR1 is available in this OS.
signal.signal(sigusr2, self.signal_usr2_handler)

def signal_usr1_handler(self, sig: int, frame: Any) -> None:
"""Called when USR1 signal is received."""
self.log.warn('USR1 received. Killing all connections...')
if self.manager and self.manager.connections:
self.manager.connections.disconnect_all_peers(force=True)
try:
self.log.warn('USR1 received. Killing all connections...')
if self.manager and self.manager.connections:
self.manager.connections.disconnect_all_peers(force=True)
except Exception:
# see: https://docs.python.org/3/library/signal.html#note-on-signal-handlers-and-exceptions
self.log.error('prevented exception from escaping the signal handler', exc_info=True)

def signal_usr2_handler(self, sig: int, frame: Any) -> None:
"""Called when USR2 signal is received."""
try:
self.log.warn('USR2 received.')
self.run_sysctl_from_signal()
except Exception:
# see: https://docs.python.org/3/library/signal.html#note-on-signal-handlers-and-exceptions
self.log.error('prevented exception from escaping the signal handler', exc_info=True)

def run_sysctl_from_signal(self) -> None:
"""Block the main loop, get commands from a named pipe and execute then using sysctl."""
from hathor.sysctl.exception import (
SysctlEntryNotFound,
SysctlException,
SysctlReadOnlyEntry,
SysctlRunnerException,
SysctlWriteOnlyEntry,
)

runner = self.get_sysctl_runner()

if self._args.data is not None:
basedir = self._args.data
tempdir = None
else:
basedir = tempfile.mkdtemp()
tempdir = basedir

filename = os.path.join(basedir, f'SIGUSR2-{os.getpid()}.pipe')
if os.path.exists(filename):
self.log.warn('[USR2] Pipe already exists.', pipe=filename)
return

with temp_fifo(filename, tempdir) as fp:
self.log.warn('[USR2] Main loop paused, awaiting command to proceed.', pipe=filename)
lines = fp.readlines()
for cmd in lines:
cmd = cmd.strip()
self.log.warn('[USR2] Command received ', cmd=cmd)

try:
output = runner.run(cmd, require_signal_handler_safe=True)
self.log.warn('[USR2] Output', output=output)
except SysctlEntryNotFound:
path, _, _ = runner.get_line_parts(cmd)
self.log.warn('[USR2] Error', errmsg=f'{path} not found')
except SysctlReadOnlyEntry:
path, _, _ = runner.get_line_parts(cmd)
self.log.warn('[USR2] Error', errmsg=f'cannot write to {path}')
except SysctlWriteOnlyEntry:
path, _, _ = runner.get_line_parts(cmd)
self.log.warn('[USR2] Error', errmsg=f'cannot read from {path}')
except SysctlException as e:
self.log.warn('[USR2] Error', errmsg=str(e))
except ValidationError as e:
self.log.warn('[USR2] Error', errmsg=str(e))
except SysctlRunnerException as e:
self.log.warn('[USR2] Error', errmsg=str(e))

def check_unsafe_arguments(self) -> None:
unsafe_args_found = []
@@ -392,6 +480,16 @@ def __init__(self, *, argv=None):
if self._args.sysctl:
self.init_sysctl(self._args.sysctl, self._args.sysctl_init_file)

def get_sysctl_runner(self) -> 'SysctlRunner':
"""Create and return a SysctlRunner."""
from hathor.builder.sysctl_builder import SysctlBuilder
from hathor.sysctl.runner import SysctlRunner

builder = SysctlBuilder(self.artifacts)
root = builder.build()
runner = SysctlRunner(root)
return runner

def init_sysctl(self, description: str, sysctl_init_file: Optional[str] = None) -> None:
"""Initialize sysctl, listen for connections and apply settings from config file if required.
@@ -406,14 +504,10 @@ def init_sysctl(self, description: str, sysctl_init_file: Optional[str] = None)
"""
from twisted.internet.endpoints import serverFromString

from hathor.builder.sysctl_builder import SysctlBuilder
from hathor.sysctl.factory import SysctlFactory
from hathor.sysctl.init_file_loader import SysctlInitFileLoader
from hathor.sysctl.runner import SysctlRunner

builder = SysctlBuilder(self.artifacts)
root = builder.build()
runner = SysctlRunner(root)
runner = self.get_sysctl_runner()

if sysctl_init_file:
init_file_loader = SysctlInitFileLoader(runner, sysctl_init_file)
5 changes: 4 additions & 1 deletion hathor/sysctl/p2p/manager.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@
from hathor.p2p.manager import ConnectionsManager
from hathor.p2p.sync_version import SyncVersion
from hathor.sysctl.exception import SysctlException
from hathor.sysctl.sysctl import Sysctl
from hathor.sysctl.sysctl import Sysctl, signal_handler_safe


def parse_text(text: str) -> list[str]:
@@ -162,6 +162,7 @@ def get_max_enabled_sync(self) -> int:
"""Return the maximum number of peers running sync simultaneously."""
return self.connections.MAX_ENABLED_SYNC

@signal_handler_safe
def set_max_enabled_sync(self, value: int) -> None:
"""Change the maximum number of peers running sync simultaneously."""
if value < 0:
@@ -179,6 +180,7 @@ def get_enabled_sync_versions(self) -> list[str]:
"""Return the list of ENABLED sync versions."""
return sorted(map(pretty_sync_version, self.connections.get_enabled_sync_versions()))

@signal_handler_safe
def set_enabled_sync_versions(self, sync_versions: list[str]) -> None:
"""Set the list of ENABLED sync versions."""
new_sync_versions = set(map(parse_sync_version, sync_versions))
@@ -202,6 +204,7 @@ def _disable_sync_version(self, sync_version: SyncVersion) -> None:
"""Disable the given sync version."""
self.connections.disable_sync_version(sync_version)

@signal_handler_safe
def set_kill_connection(self, peer_id: str, force: bool = False) -> None:
"""Kill connection with peer_id or kill all connections if peer_id == '*'."""
if peer_id == '*':
19 changes: 14 additions & 5 deletions hathor/sysctl/runner.py
Original file line number Diff line number Diff line change
@@ -28,7 +28,7 @@ class SysctlRunner:
def __init__(self, root: 'Sysctl') -> None:
self.root = root

def run(self, line: str) -> bytes:
def run(self, line: str, *, require_signal_handler_safe: bool = False) -> bytes:
"""Receives a string line, parses, interprets, acts over the Sysctl,
and returns an UTF-8 encoding data as feedback.
"""
@@ -37,23 +37,32 @@ def run(self, line: str) -> bytes:

head, separator, tail = self.get_line_parts(line)
if separator == '=':
return self._set(head, tail)
return self._set(head, tail, require_signal_handler_safe=require_signal_handler_safe)
else:
return self._get(head)

def _set(self, path: str, value_str: str) -> bytes:
def _set(self, path: str, value_str: str, *, require_signal_handler_safe: bool) -> bytes:
"""Run a `set` command in sysctl, and return and empty feedback."""
try:
value = self.deserialize(value_str)
except json.JSONDecodeError:
raise SysctlRunnerException('value: wrong format')

self.root.set(path, value)
setter = self.root.get_setter(path)
if require_signal_handler_safe:
if not hasattr(setter, '_signal_handler_safe'):
raise SysctlRunnerException('setter: not safe for signal handling')

if isinstance(value, tuple):
setter(*value)
else:
setter(value)
return b''

def _get(self, path: str) -> bytes:
"""Run a `get` command in sysctl."""
value = self.root.get(path)
getter = self.root.get_getter(path)
value = getter()
return self.serialize(value).encode('utf-8')

def get_line_parts(self, line: str) -> tuple[str, ...]:
22 changes: 16 additions & 6 deletions hathor/sysctl/sysctl.py
Original file line number Diff line number Diff line change
@@ -25,6 +25,15 @@
logger = get_logger()


def signal_handler_safe(f):
"""Decorator to mark methods as signal handler safe.
It should only be used if that method can be executed during a signal handling.
Notice that a signal handling can pause the code execution at any point and the execution will resume after."""
f._signal_handler_safe = True
return f


class SysctlCommand(NamedTuple):
getter: Optional[Getter]
setter: Optional[Setter]
@@ -64,14 +73,14 @@ def get_command(self, path: str) -> SysctlCommand:
return child.get_command(tail)
raise SysctlEntryNotFound(path)

def _get_getter(self, path: str) -> Getter:
def get_getter(self, path: str) -> Getter:
"""Return the getter method of a path."""
cmd = self.get_command(path)
if cmd.getter is None:
raise SysctlWriteOnlyEntry(path)
return cmd.getter

def _get_setter(self, path: str) -> Setter:
def get_setter(self, path: str) -> Setter:
"""Return the setter method of a path."""
cmd = self.get_command(path)
if cmd.setter is None:
@@ -80,12 +89,13 @@ def _get_setter(self, path: str) -> Setter:

def get(self, path: str) -> Any:
"""Run a get in sysctl."""
getter = self._get_getter(path)
getter = self.get_getter(path)
return getter()

def set(self, path: str, value: Any) -> None:
"""Run a set in sysctl."""
setter = self._get_setter(path)
def unsafe_set(self, path: str, value: Any) -> None:
"""Run a set in sysctl. You should use a runner instead of calling this method directly.
Should not be called unless you know it's safe."""
setter = self.get_setter(path)
if isinstance(value, tuple):
setter(*value)
else:
Loading

0 comments on commit 565377f

Please sign in to comment.