Skip to content

Commit

Permalink
feat(cli): Use SIGUSR2 to run sysctl commands
Browse files Browse the repository at this point in the history
  • Loading branch information
msbrogli committed Feb 9, 2024
1 parent 3ce172f commit 8449772
Showing 1 changed file with 101 additions and 9 deletions.
110 changes: 101 additions & 9 deletions hathor/cli/run_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

import os
import sys
import tempfile
from argparse import SUPPRESS, ArgumentParser, Namespace
from typing import TYPE_CHECKING, Any, Callable, Optional
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional

from pydantic import ValidationError
from structlog import get_logger
Expand All @@ -25,6 +27,21 @@

if TYPE_CHECKING:
from hathor.cli.run_node_args import RunNodeArgs
from hathor.sysctl.runner import SysctlRunner


@contextmanager
def temp_fifo(filename: str, tempdir: str | None) -> Iterator[str]:
"""Context Manager for creating named pipes."""
os.mkfifo(filename, mode=0o666)
fp = open(filename, 'r')
try:
yield fp
finally:
fp.close()
os.unlink(filename)
if tempdir is not None:
os.rmdir(tempdir)


class RunNode:
Expand Down Expand Up @@ -232,12 +249,81 @@ def register_signal_handlers(self) -> None:
if sigusr1 is not None:
# USR1 is available in this OS.
signal.signal(sigusr1, self.signal_usr1_handler)
sigusr2 = getattr(signal, 'SIGUSR2', None)
if sigusr2 is not None:
# USR1 is available in this OS.
signal.signal(sigusr2, self.signal_usr2_handler)

def signal_usr1_handler(self, sig: int, frame: Any) -> None:
"""Called when USR1 signal is received."""
self.log.warn('USR1 received. Killing all connections...')
if self.manager and self.manager.connections:
self.manager.connections.disconnect_all_peers(force=True)
try:
self.log.warn('USR1 received. Killing all connections...')
if self.manager and self.manager.connections:
self.manager.connections.disconnect_all_peers(force=True)
except Exception:
# see: https://docs.python.org/3/library/signal.html#note-on-signal-handlers-and-exceptions
self.log.error('prevented exception from escaping the signal handler', exc_info=True)

def signal_usr2_handler(self, sig: int, frame: Any) -> None:
"""Called when USR2 signal is received."""
try:
self.log.warn('USR2 received.')
self.run_sysctl_from_signal()
except Exception:
# see: https://docs.python.org/3/library/signal.html#note-on-signal-handlers-and-exceptions
self.log.error('prevented exception from escaping the signal handler', exc_info=True)

def run_sysctl_from_signal(self) -> None:
"""Block the main loop, get commands from a named pipe and execute then using sysctl."""
from hathor.sysctl.exception import (
SysctlEntryNotFound,
SysctlException,
SysctlReadOnlyEntry,
SysctlRunnerException,
SysctlWriteOnlyEntry,
)

runner = self.get_sysctl_runner()

if self._args.data is not None:
basedir = self._args.data
tempdir = None
else:
basedir = tempfile.mkdtemp()
tempdir = basedir

filename = os.path.join(basedir, f'SIGUSR2-{os.getpid()}.pipe')
if os.path.exists(filename):
self.log.warn('[USR2] Pipe already exists.', pipe=filename)
return

with temp_fifo(filename, tempdir) as fp:
self.log.warn('[USR2] Waiting for commands...', pipe=filename)
fp = open(filename, 'r')
lines = fp.readlines()
fp.close()
for cmd in lines:
cmd = cmd.strip()
self.log.warn('[USR2] Command received ', cmd=cmd)

try:
output = runner.run(cmd)
self.log.warn('[USR2] Output', output=output)
except SysctlEntryNotFound:
path, _, _ = runner.get_line_parts(cmd)
self.log.warn('[USR2] Error', errmsg=f'{path} not found')
except SysctlReadOnlyEntry:
path, _, _ = runner.get_line_parts(cmd)
self.log.warn('[USR2] Error', errmsg=f'cannot write to {path}')
except SysctlWriteOnlyEntry:
path, _, _ = runner.get_line_parts(cmd)
self.log.warn('[USR2] Error', errmsg=f'cannot read from {path}')
except SysctlException as e:
self.log.warn('[USR2] Error', errmsg=str(e))
except ValidationError as e:
self.log.warn('[USR2] Error', errmsg=str(e))
except SysctlRunnerException as e:
self.log.warn('[USR2] Error', errmsg=str(e))

def check_unsafe_arguments(self) -> None:
unsafe_args_found = []
Expand Down Expand Up @@ -386,6 +472,16 @@ def __init__(self, *, argv=None):
if self._args.sysctl:
self.init_sysctl(self._args.sysctl, self._args.sysctl_init_file)

def get_sysctl_runner(self) -> 'SysctlRunner':
"""Create and return a SysctlRunner."""
from hathor.builder.sysctl_builder import SysctlBuilder
from hathor.sysctl.runner import SysctlRunner

builder = SysctlBuilder(self.artifacts)
root = builder.build()
runner = SysctlRunner(root)
return runner

def init_sysctl(self, description: str, sysctl_init_file: Optional[str] = None) -> None:
"""Initialize sysctl, listen for connections and apply settings from config file if required.
Expand All @@ -400,14 +496,10 @@ def init_sysctl(self, description: str, sysctl_init_file: Optional[str] = None)
"""
from twisted.internet.endpoints import serverFromString

from hathor.builder.sysctl_builder import SysctlBuilder
from hathor.sysctl.factory import SysctlFactory
from hathor.sysctl.init_file_loader import SysctlInitFileLoader
from hathor.sysctl.runner import SysctlRunner

builder = SysctlBuilder(self.artifacts)
root = builder.build()
runner = SysctlRunner(root)
runner = self.get_sysctl_runner()

if sysctl_init_file:
init_file_loader = SysctlInitFileLoader(runner, sysctl_init_file)
Expand Down

0 comments on commit 8449772

Please sign in to comment.