From c5704300be7bcc86ec66490c203d909f80f24bd6 Mon Sep 17 00:00:00 2001 From: Olivier Cervello Date: Fri, 12 Apr 2024 09:42:31 -0400 Subject: [PATCH] update --- secator/utils.py | 238 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 162 insertions(+), 76 deletions(-) diff --git a/secator/utils.py b/secator/utils.py index bedde9a5..50c3ab81 100644 --- a/secator/utils.py +++ b/secator/utils.py @@ -30,18 +30,22 @@ class TaskError(ValueError): + """ + A custom exception class for task-related errors. + """ pass def setup_logging(level): - """Setup logging. + """ + Setup logging. - Args: - level: logging level. + Args: + level (int): logging level. - Returns: - logging.Logger: logger. - """ + Returns: + logging.Logger: logger. + """ logger = logging.getLogger('secator') logger.setLevel(level) ch = logging.StreamHandler() @@ -53,17 +57,18 @@ def setup_logging(level): def expand_input(input): - """Expand user-provided input on the CLI: - - If input is a path, read the file and return the lines. - - If it's a comma-separated list, return the list. - - Otherwise, return the original input. + """ + Expand user-provided input on the CLI: + - If input is a path, read the file and return the lines. + - If it's a comma-separated list, return the list. + - Otherwise, return the original input. - Args: - input (str): Input. + Args: + input (str): Input. - Returns: - str: Input. - """ + Returns: + str: Input. + """ if input is None: # read from stdin console.print('Waiting for input on stdin ...', style='bold yellow') rlist, _, _ = select.select([sys.stdin], [], [], DEFAULT_STDIN_TIMEOUT) @@ -93,14 +98,15 @@ def expand_input(input): def sanitize_url(http_url): - """Removes HTTP(s) ports 80 and 443 from HTTP(s) URL because it's ugly. + """ + Removes HTTP(s) ports 80 and 443 from HTTP(s) URL because it's ugly. - Args: - http_url (str): Input HTTP URL. + Args: + http_url (str): Input HTTP URL. - Returns: - str: Stripped HTTP URL. - """ + Returns: + str: Stripped HTTP URL. + """ url = urlparse(http_url) if url.netloc.endswith(':80'): url = url._replace(netloc=url.netloc.replace(':80', '')) @@ -110,14 +116,15 @@ def sanitize_url(http_url): def deduplicate(array, attr=None): - """Deduplicate list of OutputType items. + """ + Deduplicate list of OutputType items. - Args: - array (list): Input list. + Args: + array (list): Input list. - Returns: - list: Deduplicated list. - """ + Returns: + list: Deduplicated list. + """ from secator.output_types import OUTPUT_TYPES if attr and len(array) > 0 and isinstance(array[0], tuple(OUTPUT_TYPES)): memo = set() @@ -131,7 +138,9 @@ def deduplicate(array, attr=None): def discover_internal_tasks(): - """Find internal secator tasks.""" + """ + Find internal secator tasks. + """ from secator.runners import Runner package_dir = Path(__file__).resolve().parent / 'tasks' task_classes = [] @@ -160,7 +169,9 @@ def discover_internal_tasks(): def discover_external_tasks(): - """Find external secator tasks.""" + """ + Find external secator tasks. + """ if not os.path.exists('config.secator'): return [] with open('config.secator', 'r') as f: @@ -176,7 +187,9 @@ def discover_external_tasks(): def discover_tasks(): - """Find all secator tasks (internal + external).""" + """ + Find all secator tasks (internal + external). + """ global _tasks if not _tasks: _tasks = discover_internal_tasks() + discover_external_tasks() @@ -184,15 +197,16 @@ def discover_tasks(): def import_dynamic(cls_path, cls_root='Command'): - """Import class dynamically from class path. + """ + Import class dynamically from class path. - Args: - cls_path (str): Class path. - cls_root (str): Root parent class. + Args: + cls_path (str): Class path. + cls_root (str): Root parent class. - Returns: - cls: Class object. - """ + Returns: + cls: Class object. + """ try: package, name = cls_path.rsplit(".", maxsplit=1) cls = getattr(importlib.import_module(package), name) @@ -206,14 +220,15 @@ def import_dynamic(cls_path, cls_root='Command'): def get_command_cls(cls_name): - """Get secator command by class name. + """ + Get secator command by class name. - Args: - cls_name (str): Class name to load. + Args: + cls_name (str): Class name to load. - Returns: - cls: Class. - """ + Returns: + cls: Class. + """ tasks_classes = discover_tasks() for task_cls in tasks_classes: if task_cls.__name__ == cls_name: @@ -222,28 +237,30 @@ def get_command_cls(cls_name): def get_command_category(command): - """Get the category of a command. + """ + Get the category of a command. - Args: - command (class): Command class. + Args: + command (class): Command class. - Returns: - str: Command category. - """ + Returns: + str: Command category. + """ base_cls = command.__bases__[0].__name__.replace('Command', '').replace('Runner', 'misc') category = re.sub(r'(? 0 and isinstance(array[0], list): return list(itertools.chain(*array)) return array def pluralize(word): - """Pluralize a word. + """ + Pluralize a word. - Args: - word (string): Word. + Args: + word (str): Word. - Returns: - string: Plural word. - """ + Returns: + str: Plural word. + """ if word.endswith('y'): return word.rstrip('y') + 'ies' else: @@ -282,6 +301,21 @@ def pluralize(word): def load_fixture(name, fixtures_dir, ext=None, only_path=False): + """ + Load a fixture file based on the name and extension. + + Parameters: + name (str): The name of the fixture file. + fixtures_dir (str): The directory where the fixture files are located. + ext (str, optional): The extension of the fixture file. If not provided, will try with default extensions ['.json', '.txt', '.xml', '.rc']. + only_path (bool, optional): If True, only the path to the file will be returned, without reading the content. + + Returns: + str or dict: The content of the fixture file if it's a text file, or the parsed content if it's a JSON or YAML file. + + Raises: + None + """ fixture_path = f'{fixtures_dir}/{name}' exts = ['.json', '.txt', '.xml', '.rc'] if ext: @@ -300,10 +334,28 @@ def load_fixture(name, fixtures_dir, ext=None, only_path=False): def get_file_timestamp(): + """ + Get the current timestamp in a specific format. + + Returns: + str: The timestamp string in the format 'YYYY_MM_DD-HH_MM_SS_ffffff_AM/PM'. + """ return datetime.now().strftime("%Y_%m_%d-%I_%M_%S_%f_%p") def detect_host(interface=None): + """ + Detect the IP address of the host machine on a specific network interface. + + Args: + interface (str, optional): The name of the network interface to detect the IP address on. Defaults to None. + + Returns: + str: The IP address of the host machine on the specified network interface, or None if not found. + + Raises: + N/A + """ adapters = ifaddr.get_adapters() for adapter in adapters: iface = adapter.name @@ -314,10 +366,37 @@ def detect_host(interface=None): def find_list_item(array, val, key='id', default=None): + """ + Find and return an item in a list based on a specified key and value. + + Args: + array (list): A list of dictionaries. + val: The value to search for in the 'key' field. + key (str): The key field to search for the value (default is 'id'). + default: The default value to return if item is not found (default is None). + + Returns: + dict: The dictionary item from the list that matches the key and value provided. + """ return next((item for item in array if item[key] == val), default) def print_results_table(results, title=None, exclude_fields=[], log=False): + """ + Print a results table with data from a list of results. + + Args: + results (list): A list of results to be included in the table. + title (str, optional): The title of the results table. Defaults to None. + exclude_fields (list, optional): A list of fields to exclude from the table. Defaults to []. + log (bool, optional): Whether to log the results or not. Defaults to False. + + Returns: + list: A list of tables containing the formatted results. + + Raises: + None + """ from secator.output_types import OUTPUT_TYPES from secator.rich import build_table _print = console.log if log else console.print @@ -349,7 +428,9 @@ def print_results_table(results, title=None, exclude_fields=[], log=False): def rich_to_ansi(text): - """Convert text formatted with rich markup to standard string.""" + """ + Convert text formatted with rich markup to standard string. + """ from rich.console import Console tmp_console = Console(file=None, highlight=False, color_system='truecolor') with tmp_console.capture() as capture: @@ -358,7 +439,9 @@ def rich_to_ansi(text): def debug(msg, sub='', id='', obj=None, obj_after=True, obj_breaklines=False, level=1): - """Print debug log if DEBUG >= level.""" + """ + Print debug log if DEBUG is greater than or equal to level. + """ debug_comp_empty = DEBUG_COMPONENT == [""] or not DEBUG_COMPONENT if not debug_comp_empty and not any(sub.startswith(s) for s in DEBUG_COMPONENT): return @@ -389,14 +472,15 @@ def debug(msg, sub='', id='', obj=None, obj_after=True, obj_breaklines=False, le def escape_mongodb_url(url): - """Escape username / password from MongoDB URL if any. + """ + Escape username / password from MongoDB URL if any. - Args: - url (str): Full MongoDB URL string. + Args: + url (str): Full MongoDB URL string. - Returns: - str: Escaped MongoDB URL string. - """ + Returns: + str: Escaped MongoDB URL string. + """ match = re.search('mongodb://(?P.*)@(?P.*)', url) if match: url = match.group('url') @@ -407,7 +491,9 @@ def escape_mongodb_url(url): def print_version(): - """Print secator version information.""" + """ + Print secator version information. + """ from secator.installer import get_version_info console.print(f'[bold gold3]Current version[/]: {VERSION}', highlight=False, end='') info = get_version_info('secator', github_handle='freelabz/secator', version=VERSION)