diff --git a/.circleci/config.yml b/.circleci/config.yml index 1cc7265a3..520d70450 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -351,10 +351,16 @@ jobs: - install_dependencies - run: - name: run flake8 + name: run ruff command: | . venv/bin/activate - flake8 irrd + ruff irrd + + - run: + name: run isort + command: | + . venv/bin/activate + isort --check --diff irrd - run: name: run mypy @@ -385,6 +391,15 @@ jobs: - restore_cache: keys: v1-docs-cache + # Sphinx_immaterial does parallel downloads that don't work + # well on CircleCI. There is no setting, so we just edit + # the max number of workers. Bit gross, but this isn't our + # production doc build anyways, that happens on RTD without hacks. + - run: + name: hack sphinx_immaterial + command: | + sed -i 's/max_workers=32/max_workers=1/' /mnt/ramdisk/venv/lib/py*/site-packages/sphinx_immaterial/google_fonts.py + - run: name: build docs command: | diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 000000000..c9d083106 --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,3 @@ +[settings] +profile=black +py_version=38 diff --git a/irrd/__init__.py b/irrd/__init__.py index 5b4927b89..35f089f1e 100644 --- a/irrd/__init__.py +++ b/irrd/__init__.py @@ -1,2 +1,2 @@ -__version__ = '4.4-dev' -ENV_MAIN_PROCESS_PID = 'IRRD_MAIN_PROCESS_PID' +__version__ = "4.4-dev" +ENV_MAIN_PROCESS_PID = "IRRD_MAIN_PROCESS_PID" diff --git a/irrd/conf/__init__.py b/irrd/conf/__init__.py index cfab8b308..36ae4b87c 100644 --- a/irrd/conf/__init__.py +++ b/irrd/conf/__init__.py @@ -13,52 +13,47 @@ from irrd.vendor.dotted.collection import DottedDict -CONFIG_PATH_DEFAULT = '/etc/irrd.yaml' +CONFIG_PATH_DEFAULT = "/etc/irrd.yaml" logger = logging.getLogger(__name__) -PASSWORD_HASH_DUMMY_VALUE = 'DummyValue' -SOURCE_NAME_RE = re.compile('^[A-Z][A-Z0-9-]*[A-Z0-9]$') -RPKI_IRR_PSEUDO_SOURCE = 'RPKI' +PASSWORD_HASH_DUMMY_VALUE = "DummyValue" +SOURCE_NAME_RE = re.compile("^[A-Z][A-Z0-9-]*[A-Z0-9]$") +RPKI_IRR_PSEUDO_SOURCE = "RPKI" ROUTEPREF_IMPORT_TIME = 3600 -AUTH_SET_CREATION_COMMON_KEY = 'COMMON' +AUTH_SET_CREATION_COMMON_KEY = "COMMON" SOCKET_DEFAULT_TIMEOUT = 30 LOGGING = { - 'version': 1, - 'disable_existing_loggers': False, - 'formatters': { - 'verbose': { - 'format': '%(asctime)s irrd[%(process)d]: [%(name)s#%(levelname)s] %(message)s' - }, + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "verbose": {"format": "%(asctime)s irrd[%(process)d]: [%(name)s#%(levelname)s] %(message)s"}, }, - 'handlers': { - 'console': { - 'class': 'logging.StreamHandler', - 'formatter': 'verbose' - }, + "handlers": { + "console": {"class": "logging.StreamHandler", "formatter": "verbose"}, }, - 'loggers': { + "loggers": { # Tune down some very loud and not very useful loggers from libraries. - 'passlib.registry': { - 'level': 'INFO', + "passlib.registry": { + "level": "INFO", }, - 'gnupg': { - 'level': 'INFO', + "gnupg": { + "level": "INFO", }, # Must be specified explicitly to disable tracing middleware, # which adds substantial overhead - 'uvicorn.error': { - 'level': 'INFO', + "uvicorn.error": { + "level": "INFO", }, - 'sqlalchemy': { - 'level': 'WARNING', + "sqlalchemy": { + "level": "WARNING", }, - '': { - 'handlers': ['console'], - 'level': 'INFO', + "": { + "handlers": ["console"], + "level": "INFO", }, - } + }, } @@ -81,33 +76,35 @@ class Configuration: The Configuration class stores the current IRRD configuration, checks the validity of the settings, and offers graceful reloads. 
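
    Typical use goes through the module-level helpers rather than direct
    construction. A minimal sketch (the path is illustrative):

        config_init("/etc/irrd.yaml")  # build, validate and commit the global instance
        get_setting("log.level", default="INFO")  # read from the live config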
""" + user_config_staging: DottedDict user_config_live: DottedDict - def __init__(self, user_config_path: Optional[str]=None, commit=True): + def __init__(self, user_config_path: Optional[str] = None, commit=True): """ Load the default config and load and check the user provided config. If a logfile was specified, direct logs there. """ from .known_keys import KNOWN_CONFIG_KEYS, KNOWN_SOURCES_KEYS + self.known_config_keys = KNOWN_CONFIG_KEYS self.known_sources_keys = KNOWN_SOURCES_KEYS self.user_config_path = user_config_path if user_config_path else CONFIG_PATH_DEFAULT - default_config_path = str(Path(__file__).resolve().parents[0] / 'default_config.yaml') + default_config_path = str(Path(__file__).resolve().parents[0] / "default_config.yaml") with open(default_config_path) as default_config: default_config_yaml = yaml.safe_load(default_config) - self.default_config = DottedDict(default_config_yaml['irrd']) + self.default_config = DottedDict(default_config_yaml["irrd"]) self.logging_config = LOGGING errors = self._staging_reload_check(log_success=False) if errors: - raise ConfigurationError(f'Errors found in configuration, unable to start: {errors}') + raise ConfigurationError(f"Errors found in configuration, unable to start: {errors}") if commit: self._commit_staging() - logging_config_path = self.get_setting_live('log.logging_config_path') - logfile_path = self.get_setting_live('log.logfile_path') + logging_config_path = self.get_setting_live("log.logging_config_path") + logfile_path = self.get_setting_live("log.logfile_path") if logging_config_path: spec = importlib.util.spec_from_file_location("logging_config", logging_config_path) config_module = importlib.util.module_from_spec(spec) # type: ignore @@ -115,19 +112,19 @@ def __init__(self, user_config_path: Optional[str]=None, commit=True): self.logging_config = config_module.LOGGING # type: ignore logging.config.dictConfig(self.logging_config) elif logfile_path: - LOGGING['handlers']['file'] = { # type:ignore - 'class': 'logging.handlers.WatchedFileHandler', - 'filename': logfile_path, - 'formatter': 'verbose', + LOGGING["handlers"]["file"] = { # type:ignore + "class": "logging.handlers.WatchedFileHandler", + "filename": logfile_path, + "formatter": "verbose", } # noinspection PyTypeChecker - LOGGING['loggers']['']['handlers'] = ['file'] # type:ignore + LOGGING["loggers"][""]["handlers"] = ["file"] # type:ignore logging.config.dictConfig(LOGGING) # Re-commit to apply loglevel self._commit_staging() - def get_setting_live(self, setting_name: str, default: Optional[Any]=None) -> Any: + def get_setting_live(self, setting_name: str, default: Optional[Any] = None) -> Any: """ Get a setting from the live config. In order, this will look in: @@ -140,15 +137,15 @@ def get_setting_live(self, setting_name: str, default: Optional[Any]=None) -> An If it is not found in any, the value of the default paramater is returned, which is None by default. 
""" - if setting_name.startswith('sources'): - components = setting_name.split('.') + if setting_name.startswith("sources"): + components = setting_name.split(".") if len(components) == 3 and components[2] not in self.known_sources_keys: - raise ValueError(f'Unknown setting {setting_name}') - elif not setting_name.startswith('access_lists'): + raise ValueError(f"Unknown setting {setting_name}") + elif not setting_name.startswith("access_lists"): if self.known_config_keys.get(setting_name) is None: - raise ValueError(f'Unknown setting {setting_name}') + raise ValueError(f"Unknown setting {setting_name}") - env_key = 'IRRD_' + setting_name.upper().replace('.', '_') + env_key = "IRRD_" + setting_name.upper().replace(".", "_") if env_key in os.environ: return os.environ[env_key] if testing_overrides: @@ -167,7 +164,7 @@ def reload(self) -> bool: """ errors = self._staging_reload_check() if errors: - logger.error(f'Errors found in configuration, continuing with current settings: {errors}') + logger.error(f"Errors found in configuration, continuing with current settings: {errors}") return False self._commit_staging() @@ -178,9 +175,9 @@ def _commit_staging(self) -> None: Activate the current staging config as the live config. """ self.user_config_live = self.user_config_staging - logging.getLogger('').setLevel(self.get_setting_live('log.level', default='INFO')) - if hasattr(sys, '_called_from_test'): - logging.getLogger('').setLevel('DEBUG') + logging.getLogger("").setLevel(self.get_setting_live("log.level", default="INFO")) + if hasattr(sys, "_called_from_test"): + logging.getLogger("").setLevel("DEBUG") def _staging_reload_check(self, log_success=True) -> List[str]: """ @@ -191,7 +188,7 @@ def _staging_reload_check(self, log_success=True) -> List[str]: # While in testing, Configuration does not demand a valid config file # This simplifies test setup, as most tests do not need it. # If a non-default path is set during testing, it is still checked. 
- if hasattr(sys, '_called_from_test') and self.user_config_path == CONFIG_PATH_DEFAULT: + if hasattr(sys, "_called_from_test") and self.user_config_path == CONFIG_PATH_DEFAULT: self.user_config_staging = DottedDict({}) return [] @@ -199,17 +196,19 @@ def _staging_reload_check(self, log_success=True) -> List[str]: with open(self.user_config_path) as fh: user_config_yaml = yaml.safe_load(fh) except OSError as oe: - return [f'Error opening config file {self.user_config_path}: {oe}'] + return [f"Error opening config file {self.user_config_path}: {oe}"] except yaml.YAMLError as ye: - return [f'Error parsing YAML file: {ye}'] + return [f"Error parsing YAML file: {ye}"] - if not isinstance(user_config_yaml, dict) or 'irrd' not in user_config_yaml: + if not isinstance(user_config_yaml, dict) or "irrd" not in user_config_yaml: return [f'Could not find root item "irrd" in config file {self.user_config_path}'] - self.user_config_staging = DottedDict(user_config_yaml['irrd']) + self.user_config_staging = DottedDict(user_config_yaml["irrd"]) errors = self._check_staging_config() if not errors and log_success: - logger.info(f'Configuration successfully (re)loaded from {self.user_config_path} in PID {os.getpid()}') + logger.info( + f"Configuration successfully (re)loaded from {self.user_config_path} in PID {os.getpid()}" + ) return errors def _check_staging_config(self) -> List[str]: @@ -223,172 +222,224 @@ def _check_staging_config(self) -> List[str]: def _validate_subconfig(key, value): if isinstance(value, (DottedDict, dict)): for key2, value2 in value.items(): - subkey = key + '.' + key2 + subkey = key + "." + key2 known_sub = self.known_config_keys.get(subkey) if known_sub is None: - errors.append(f'Unknown setting key: {subkey}') + errors.append(f"Unknown setting key: {subkey}") _validate_subconfig(subkey, value2) for key, value in config.items(): - if key in ['sources', 'access_lists']: + if key in ["sources", "access_lists"]: continue if self.known_config_keys.get(key) is None: - errors.append(f'Unknown setting key: {key}') + errors.append(f"Unknown setting key: {key}") _validate_subconfig(key, value) - if not self._check_is_str(config, 'database_url'): - errors.append('Setting database_url is required.') + if not self._check_is_str(config, "database_url"): + errors.append("Setting database_url is required.") - if not self._check_is_str(config, 'redis_url'): - errors.append('Setting redis_url is required.') + if not self._check_is_str(config, "redis_url"): + errors.append("Setting redis_url is required.") - if not self._check_is_str(config, 'piddir') or not os.path.isdir(config['piddir']): - errors.append('Setting piddir is required and must point to an existing directory.') + if not self._check_is_str(config, "piddir") or not os.path.isdir(config["piddir"]): + errors.append("Setting piddir is required and must point to an existing directory.") - if not str(config.get('route_object_preference.update_timer', '0')).isnumeric(): - errors.append('Setting route_object_preference.update_timer must be a number.') + if not str(config.get("route_object_preference.update_timer", "0")).isnumeric(): + errors.append("Setting route_object_preference.update_timer must be a number.") expected_access_lists = { - config.get('server.whois.access_list'), - config.get('server.http.status_access_list'), + config.get("server.whois.access_list"), + config.get("server.http.status_access_list"), } - if not self._check_is_str(config, 'email.from') or '@' not in config.get('email.from', ''): - errors.append('Setting 
email.from is required and must be an email address.') - if not self._check_is_str(config, 'email.smtp'): - errors.append('Setting email.smtp is required.') - if not self._check_is_str(config, 'email.recipient_override', required=False) \ - or '@' not in config.get('email.recipient_override', '@'): - errors.append('Setting email.recipient_override must be an email address if set.') - - string_not_required = ['email.footer', 'server.whois.access_list', - 'server.http.status_access_list', 'rpki.notify_invalid_subject', - 'rpki.notify_invalid_header', 'rpki.slurm_source', 'user', 'group'] + if not self._check_is_str(config, "email.from") or "@" not in config.get("email.from", ""): + errors.append("Setting email.from is required and must be an email address.") + if not self._check_is_str(config, "email.smtp"): + errors.append("Setting email.smtp is required.") + if not self._check_is_str( + config, "email.recipient_override", required=False + ) or "@" not in config.get("email.recipient_override", "@"): + errors.append("Setting email.recipient_override must be an email address if set.") + + string_not_required = [ + "email.footer", + "server.whois.access_list", + "server.http.status_access_list", + "rpki.notify_invalid_subject", + "rpki.notify_invalid_header", + "rpki.slurm_source", + "user", + "group", + ] for setting in string_not_required: if not self._check_is_str(config, setting, required=False): - errors.append(f'Setting {setting} must be a string, if defined.') + errors.append(f"Setting {setting} must be a string, if defined.") - if bool(config.get('user')) != bool(config.get('group')): - errors.append('Settings user and group must both be defined, or neither.') + if bool(config.get("user")) != bool(config.get("group")): + errors.append("Settings user and group must both be defined, or neither.") - if not self._check_is_str(config, 'auth.gnupg_keyring'): - errors.append('Setting auth.gnupg_keyring is required.') + if not self._check_is_str(config, "auth.gnupg_keyring"): + errors.append("Setting auth.gnupg_keyring is required.") from irrd.updates.parser_state import RPSLSetAutnumAuthenticationMode + valid_auth = [mode.value for mode in RPSLSetAutnumAuthenticationMode] - for set_name, params in config.get('auth.set_creation', {}).items(): - if not isinstance(params.get('prefix_required', False), bool): - errors.append(f'Setting auth.set_creation.{set_name}.prefix_required must be a bool') - if params.get('autnum_authentication') and params['autnum_authentication'].lower() not in valid_auth: - errors.append(f'Setting auth.set_creation.{set_name}.autnum_authentication must be one of {valid_auth} if set') + for set_name, params in config.get("auth.set_creation", {}).items(): + if not isinstance(params.get("prefix_required", False), bool): + errors.append(f"Setting auth.set_creation.{set_name}.prefix_required must be a bool") + if ( + params.get("autnum_authentication") + and params["autnum_authentication"].lower() not in valid_auth + ): + errors.append( + f"Setting auth.set_creation.{set_name}.autnum_authentication must be one of" + f" {valid_auth} if set" + ) from irrd.rpsl.passwords import PasswordHasherAvailability + valid_hasher_availability = [avl.value for avl in PasswordHasherAvailability] - for hasher_name, setting in config.get('auth.password_hashers', {}).items(): + for hasher_name, setting in config.get("auth.password_hashers", {}).items(): if setting.lower() not in valid_hasher_availability: - errors.append(f'Setting auth.password_hashers.{hasher_name} must be one of 
{valid_hasher_availability}') + errors.append( + f"Setting auth.password_hashers.{hasher_name} must be one of {valid_hasher_availability}" + ) - for name, access_list in config.get('access_lists', {}).items(): + for name, access_list in config.get("access_lists", {}).items(): for item in access_list: try: IP(item) except ValueError as ve: - errors.append(f'Invalid item in access list {name}: {ve}.') + errors.append(f"Invalid item in access list {name}: {ve}.") - for prefix in config.get('scopefilter.prefixes', []): + for prefix in config.get("scopefilter.prefixes", []): try: IP(prefix) except ValueError as ve: - errors.append(f'Invalid item in prefix scopefilter: {prefix}: {ve}.') + errors.append(f"Invalid item in prefix scopefilter: {prefix}: {ve}.") - for asn in config.get('scopefilter.asns', []): + for asn in config.get("scopefilter.asns", []): try: - if '-' in str(asn): - first, last = asn.split('-') + if "-" in str(asn): + first, last = asn.split("-") int(first) int(last) else: int(asn) except ValueError: - errors.append(f'Invalid item in asn scopefilter: {asn}.') + errors.append(f"Invalid item in asn scopefilter: {asn}.") - known_sources = set(config.get('sources', {}).keys()) + known_sources = set(config.get("sources", {}).keys()) has_authoritative_sources = False - for name, details in config.get('sources', {}).items(): + for name, details in config.get("sources", {}).items(): unknown_keys = set(details.keys()) - self.known_sources_keys if unknown_keys: errors.append(f'Unknown key(s) under source {name}: {", ".join(unknown_keys)}') - if details.get('authoritative'): + if details.get("authoritative"): has_authoritative_sources = True - if config.get('rpki.roa_source') and name == RPKI_IRR_PSEUDO_SOURCE: - errors.append(f'Setting sources contains reserved source name: {RPKI_IRR_PSEUDO_SOURCE}') + if config.get("rpki.roa_source") and name == RPKI_IRR_PSEUDO_SOURCE: + errors.append(f"Setting sources contains reserved source name: {RPKI_IRR_PSEUDO_SOURCE}") if not SOURCE_NAME_RE.match(name): - errors.append(f'Invalid source name: {name}') - - if details.get('suspension_enabled') and not details.get('authoritative'): - errors.append(f'Setting suspension_enabled for source {name} can not be enabled without enabling ' - f'authoritative.') - - nrtm_mirror = details.get('nrtm_host') and details.get('import_serial_source') - if details.get('keep_journal') and not (nrtm_mirror or details.get('authoritative')): - errors.append(f'Setting keep_journal for source {name} can not be enabled unless either authoritative ' - f'is enabled, or all three of nrtm_host, nrtm_port and import_serial_source.') - if details.get('nrtm_host') and not details.get('import_serial_source'): - errors.append(f'Setting nrtm_host for source {name} can not be enabled without setting ' - f'import_serial_source.') - - if details.get('authoritative') and (details.get('nrtm_host') or details.get('import_source')): - errors.append(f'Setting authoritative for source {name} can not be enabled when either ' - f'nrtm_host or import_source are set.') - - if config.get('database_readonly') and (details.get('authoritative') or details.get('nrtm_host') or details.get('import_source')): - errors.append(f'Source {name} can not have authoritative, import_source or nrtm_host set ' - f'when database_readonly is enabled.') + errors.append(f"Invalid source name: {name}") + + if details.get("suspension_enabled") and not details.get("authoritative"): + errors.append( + f"Setting suspension_enabled for source {name} can not be enabled without 
enabling " + "authoritative." + ) + + nrtm_mirror = details.get("nrtm_host") and details.get("import_serial_source") + if details.get("keep_journal") and not (nrtm_mirror or details.get("authoritative")): + errors.append( + f"Setting keep_journal for source {name} can not be enabled unless either authoritative " + "is enabled, or all three of nrtm_host, nrtm_port and import_serial_source." + ) + if details.get("nrtm_host") and not details.get("import_serial_source"): + errors.append( + f"Setting nrtm_host for source {name} can not be enabled without setting " + "import_serial_source." + ) + + if details.get("authoritative") and (details.get("nrtm_host") or details.get("import_source")): + errors.append( + f"Setting authoritative for source {name} can not be enabled when either " + "nrtm_host or import_source are set." + ) + + if config.get("database_readonly") and ( + details.get("authoritative") or details.get("nrtm_host") or details.get("import_source") + ): + errors.append( + f"Source {name} can not have authoritative, import_source or nrtm_host set " + "when database_readonly is enabled." + ) number_fields = [ - 'nrtm_port', 'import_timer', 'export_timer', - 'route_object_preference', 'nrtm_query_serial_range_limit', + "nrtm_port", + "import_timer", + "export_timer", + "route_object_preference", + "nrtm_query_serial_range_limit", ] for field_name in number_fields: if not str(details.get(field_name, 0)).isnumeric(): - errors.append(f'Setting {field_name} for source {name} must be a number.') + errors.append(f"Setting {field_name} for source {name} must be a number.") - if details.get('nrtm_access_list'): - expected_access_lists.add(details.get('nrtm_access_list')) - if details.get('nrtm_access_list_unfiltered'): - expected_access_lists.add(details.get('nrtm_access_list_unfiltered')) + if details.get("nrtm_access_list"): + expected_access_lists.add(details.get("nrtm_access_list")) + if details.get("nrtm_access_list_unfiltered"): + expected_access_lists.add(details.get("nrtm_access_list_unfiltered")) - if config.get('rpki.roa_source', 'https://rpki.gin.ntt.net/api/export.json'): + if config.get("rpki.roa_source", "https://rpki.gin.ntt.net/api/export.json"): known_sources.add(RPKI_IRR_PSEUDO_SOURCE) - if has_authoritative_sources and config.get('rpki.notify_invalid_enabled') is None: - errors.append('RPKI-aware mode is enabled and authoritative sources are configured, ' - 'but rpki.notify_invalid_enabled is not set. Set to true or false.' - 'DANGER: care is required with this setting in testing setups with ' - 'live data, as it may send bulk emails to real resource contacts ' - 'unless email.recipient_override is also set. ' - 'Read documentation carefully.') - - unknown_default_sources = set(config.get('sources_default', [])).difference(known_sources) + if has_authoritative_sources and config.get("rpki.notify_invalid_enabled") is None: + errors.append( + "RPKI-aware mode is enabled and authoritative sources are configured, " + "but rpki.notify_invalid_enabled is not set. Set to true or false." + "DANGER: care is required with this setting in testing setups with " + "live data, as it may send bulk emails to real resource contacts " + "unless email.recipient_override is also set. " + "Read documentation carefully." 
+            )
+
+        unknown_default_sources = set(config.get("sources_default", [])).difference(known_sources)
         if unknown_default_sources:
-            errors.append(f'Setting sources_default contains unknown sources: {", ".join(unknown_default_sources)}')
-
-        if not str(config.get('rpki.roa_import_timer', '0')).isnumeric():
-            errors.append('Setting rpki.roa_import_timer must be set to a number.')
-
-        if config.get('log.level') and not config.get('log.level') in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']:
-            errors.append(f'Invalid log.level: {config.get("log.level")}. '
-                          f'Valid settings for log.level are `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`.')
-        if config.get('log.logging_config_path') and (config.get('log.logfile_path') or config.get('log.level')):
-            errors.append('Setting log.logging_config_path can not be combined with'
-                          'log.logfile_path or log.level')
-
-        access_lists = set(config.get('access_lists', {}).keys())
-        unresolved_access_lists = [x for x in expected_access_lists.difference(access_lists) if x and isinstance(x, str)]
+            errors.append(
+                f'Setting sources_default contains unknown sources: {", ".join(unknown_default_sources)}'
+            )
+
+        if not str(config.get("rpki.roa_import_timer", "0")).isnumeric():
+            errors.append("Setting rpki.roa_import_timer must be set to a number.")
+
+        if config.get("log.level") and config.get("log.level") not in [
+            "DEBUG",
+            "INFO",
+            "WARNING",
+            "ERROR",
+            "CRITICAL",
+        ]:
+            errors.append(
+                f'Invalid log.level: {config.get("log.level")}. '
+                "Valid settings for log.level are `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`."
+            )
+        if config.get("log.logging_config_path") and (
+            config.get("log.logfile_path") or config.get("log.level")
+        ):
+            errors.append(
+                "Setting log.logging_config_path can not be combined with log.logfile_path or log.level"
+            )
+
+        access_lists = set(config.get("access_lists", {}).keys())
+        unresolved_access_lists = [
+            x for x in expected_access_lists.difference(access_lists) if x and isinstance(x, str)
+        ]
         unresolved_access_lists.sort()
         if unresolved_access_lists:
-            errors.append(f'Access lists {", ".join(unresolved_access_lists)} referenced in settings, but not defined.')
+            errors.append(
+                f'Access lists {", ".join(unresolved_access_lists)} referenced in settings, but not defined.'
+            )

         return errors
@@ -427,13 +478,13 @@ def is_config_initialised() -> bool:
     return configuration is not None


-def get_setting(setting_name: str, default: Optional[Any]=None) -> Any:
+def get_setting(setting_name: str, default: Optional[Any] = None) -> Any:
     """
     Convenience wrapper to get the value of a setting.
     """
     configuration = get_configuration()
     if not configuration:  # pragma: no cover
-        raise Exception('get_setting() called before configuration was initialised')
+        raise Exception("get_setting() called before configuration was initialised")
     return configuration.get_setting_live(setting_name, default)
diff --git a/irrd/conf/defaults.py b/irrd/conf/defaults.py
index 9c6da5d17..319365a21 100644
--- a/irrd/conf/defaults.py
+++ b/irrd/conf/defaults.py
@@ -1,6 +1,6 @@
 # In addition to these settings, simple
 # defaults are stored in default_config.yaml.
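# Note: numeric settings may also arrive as strings (for example via
# IRRD_* environment variables), and validation uses str(...).isnumeric(),
# so a string default like '43' below passes the numeric checks.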
-DEFAULT_SOURCE_NRTM_PORT = '43' +DEFAULT_SOURCE_NRTM_PORT = "43" DEFAULT_SOURCE_IMPORT_TIMER = 300 DEFAULT_SOURCE_EXPORT_TIMER = 3600 diff --git a/irrd/conf/known_keys.py b/irrd/conf/known_keys.py index 5d9a91df9..dc55def92 100644 --- a/irrd/conf/known_keys.py +++ b/irrd/conf/known_keys.py @@ -1,100 +1,101 @@ from irrd.conf import AUTH_SET_CREATION_COMMON_KEY from irrd.rpsl.passwords import PASSWORD_HASHERS_ALL -from irrd.vendor.dotted.collection import DottedDict from irrd.rpsl.rpsl_objects import OBJECT_CLASS_MAPPING, RPSLSet +from irrd.vendor.dotted.collection import DottedDict # Note that sources are checked separately, # and 'access_lists' is always permitted -KNOWN_CONFIG_KEYS = DottedDict({ - 'database_url': {}, - 'database_readonly': {}, - 'redis_url': {}, - 'piddir': {}, - 'user': {}, - 'group': {}, - 'server': { - 'http': { - 'interface': {}, - 'port': {}, - 'status_access_list': {}, - 'event_stream_access_list': {}, - 'workers': {}, - 'forwarded_allowed_ips': {}, +KNOWN_CONFIG_KEYS = DottedDict( + { + "database_url": {}, + "database_readonly": {}, + "redis_url": {}, + "piddir": {}, + "user": {}, + "group": {}, + "server": { + "http": { + "interface": {}, + "port": {}, + "status_access_list": {}, + "event_stream_access_list": {}, + "workers": {}, + "forwarded_allowed_ips": {}, + }, + "whois": { + "interface": {}, + "port": {}, + "access_list": {}, + "max_connections": {}, + }, + }, + "route_object_preference": {"update_timer": {}}, + "email": { + "from": {}, + "footer": {}, + "smtp": {}, + "recipient_override": {}, + "notification_header": {}, + }, + "auth": { + "override_password": {}, + "authenticate_parents_route_creation": {}, + "gnupg_keyring": {}, + "set_creation": { + rpsl_object_class: {"prefix_required": {}, "autnum_authentication": {}} + for rpsl_object_class in [ + set_object.rpsl_object_class + for set_object in OBJECT_CLASS_MAPPING.values() + if issubclass(set_object, RPSLSet) + ] + + [AUTH_SET_CREATION_COMMON_KEY] + }, + "password_hashers": {hasher_name.lower(): {} for hasher_name in PASSWORD_HASHERS_ALL.keys()}, + }, + "rpki": { + "roa_source": {}, + "roa_import_timer": {}, + "slurm_source": {}, + "pseudo_irr_remarks": {}, + "notify_invalid_enabled": {}, + "notify_invalid_subject": {}, + "notify_invalid_header": {}, + }, + "scopefilter": { + "prefixes": {}, + "asns": {}, }, - 'whois': { - 'interface': {}, - 'port': {}, - 'access_list': {}, - 'max_connections': {}, + "log": { + "logfile_path": {}, + "level": {}, + "logging_config_path": {}, }, - }, - 'route_object_preference': {'update_timer': {}}, - 'email': { - 'from': {}, - 'footer': {}, - 'smtp': {}, - 'recipient_override': {}, - 'notification_header': {}, - }, - 'auth': { - 'override_password': {}, - 'authenticate_parents_route_creation': {}, - 'gnupg_keyring': {}, - 'set_creation': { - rpsl_object_class: {'prefix_required': {}, 'autnum_authentication': {}} - for rpsl_object_class in [ - set_object.rpsl_object_class - for set_object in OBJECT_CLASS_MAPPING.values() - if issubclass(set_object, RPSLSet) - ] + [AUTH_SET_CREATION_COMMON_KEY] + "sources_default": {}, + "compatibility": { + "inetnum_search_disabled": {}, + "ipv4_only_route_set_members": {}, }, - 'password_hashers': { - hasher_name.lower(): {} for hasher_name in PASSWORD_HASHERS_ALL.keys() - } - }, - 'rpki': { - 'roa_source': {}, - 'roa_import_timer': {}, - 'slurm_source': {}, - 'pseudo_irr_remarks': {}, - 'notify_invalid_enabled': {}, - 'notify_invalid_subject': {}, - 'notify_invalid_header': {}, - }, - 'scopefilter': { - 'prefixes': {}, - 'asns': 
{}, - }, - 'log': { - 'logfile_path': {}, - 'level': {}, - 'logging_config_path': {}, - }, - 'sources_default': {}, - 'compatibility': { - 'inetnum_search_disabled': {}, - 'ipv4_only_route_set_members': {}, } -}) +) KNOWN_SOURCES_KEYS = { - 'authoritative', - 'keep_journal', - 'nrtm_host', - 'nrtm_port', - 'import_source', - 'import_serial_source', - 'import_timer', - 'object_class_filter', - 'export_destination', - 'export_destination_unfiltered', - 'export_timer', - 'nrtm_access_list', - 'nrtm_access_list_unfiltered', - 'nrtm_query_serial_range_limit', - 'strict_import_keycert_objects', - 'rpki_excluded', - 'scopefilter_excluded', - 'suspension_enabled', - 'route_object_preference', + "authoritative", + "keep_journal", + "nrtm_host", + "nrtm_port", + "import_source", + "import_serial_source", + "import_timer", + "object_class_filter", + "export_destination", + "export_destination_unfiltered", + "export_timer", + "nrtm_access_list", + "nrtm_access_list_unfiltered", + "nrtm_query_serial_range_limit", + "strict_import_keycert_objects", + "rpki_excluded", + "scopefilter_excluded", + "suspension_enabled", + "route_object_preference", } diff --git a/irrd/conf/test_conf.py b/irrd/conf/test_conf.py index b5e0540fd..a2a9b8b6c 100644 --- a/irrd/conf/test_conf.py +++ b/irrd/conf/test_conf.py @@ -1,44 +1,51 @@ import os +import signal import textwrap +from typing import Dict import pytest -import signal import yaml -from typing import Dict -from . import get_setting, ConfigurationError, config_init, is_config_initialised, get_configuration +from . import ( + ConfigurationError, + config_init, + get_configuration, + get_setting, + is_config_initialised, +) @pytest.fixture() def save_yaml_config(tmpdir, monkeypatch): def _save(config: Dict, run_init=True): - tmp_file = tmpdir + '/config.yaml' - with open(tmp_file, 'w') as fh: + tmp_file = tmpdir + "/config.yaml" + with open(tmp_file, "w") as fh: fh.write(yaml.safe_dump(config)) if run_init: config_init(str(tmp_file)) + return _save class TestConfiguration: def test_file_not_existing(self, monkeypatch, tmpdir): with pytest.raises(ConfigurationError) as ce: - config_init(str(tmpdir + '/doesnotexist.yaml')) - assert 'Error opening config file' in str(ce.value) + config_init(str(tmpdir + "/doesnotexist.yaml")) + assert "Error opening config file" in str(ce.value) def test_load_invalid_yaml(self, monkeypatch, tmpdir): - tmp_file = tmpdir + '/config.yaml' - fh = open(tmp_file, 'w') - fh.write(' >foo') + tmp_file = tmpdir + "/config.yaml" + fh = open(tmp_file, "w") + fh.write(" >foo") fh.close() with pytest.raises(ConfigurationError) as ce: config_init(str(tmp_file)) - assert 'Error parsing YAML file' in str(ce.value) + assert "Error parsing YAML file" in str(ce.value) def test_load_string_file(self, save_yaml_config): with pytest.raises(ConfigurationError) as ce: - save_yaml_config('foo') + save_yaml_config("foo") assert 'Could not find root item "irrd" in config file' in str(ce.value) def test_load_empty_config(self, save_yaml_config): @@ -47,100 +54,92 @@ def test_load_empty_config(self, save_yaml_config): assert 'Could not find root item "irrd" in config file' in str(ce.value) def test_load_valid_reload_valid_config(self, monkeypatch, save_yaml_config, tmpdir, caplog): - logfile = str(tmpdir + '/logfile.txt') + logfile = str(tmpdir + "/logfile.txt") config = { - 'irrd': { - 'database_url': 'db-url', - 'redis_url': 'redis-url', - 'piddir': str(tmpdir), - 'email': { - 'from': 'example@example.com', - 'smtp': '192.0.2.1' - }, - 'route_object_preference': { 
- 'update_timer': 10, - }, - 'rpki': { - 'roa_source': None, + "irrd": { + "database_url": "db-url", + "redis_url": "redis-url", + "piddir": str(tmpdir), + "email": {"from": "example@example.com", "smtp": "192.0.2.1"}, + "route_object_preference": { + "update_timer": 10, }, - 'scopefilter': { - 'prefixes': ['10/8'], - 'asns': ['23456', '10-20'] + "rpki": { + "roa_source": None, }, - 'access_lists': { - 'valid-list': { - '192/24', - '192.0.2.1', - '2001:db8::32', - '2001:db8::1', + "scopefilter": {"prefixes": ["10/8"], "asns": ["23456", "10-20"]}, + "access_lists": { + "valid-list": { + "192/24", + "192.0.2.1", + "2001:db8::32", + "2001:db8::1", } }, - 'auth': { - 'gnupg_keyring': str(tmpdir), - 'authenticate_parents_route_creation': True, - 'set_creation': { - 'as-set': { - 'prefix_required': True, - 'autnum_authentication': 'opportunistic', + "auth": { + "gnupg_keyring": str(tmpdir), + "authenticate_parents_route_creation": True, + "set_creation": { + "as-set": { + "prefix_required": True, + "autnum_authentication": "opportunistic", }, - 'COMMON': { - 'prefix_required': True, - 'autnum_authentication': 'required', + "COMMON": { + "prefix_required": True, + "autnum_authentication": "required", }, }, - 'password_hashers': { - 'bcrypt-pw': 'legacy', + "password_hashers": { + "bcrypt-pw": "legacy", }, }, - 'sources_default': ['TESTDB2', 'TESTDB'], - 'sources': { - 'TESTDB': { - 'authoritative': True, - 'keep_journal': True, - 'suspension_enabled': True, - 'nrtm_query_serial_range_limit': 10, + "sources_default": ["TESTDB2", "TESTDB"], + "sources": { + "TESTDB": { + "authoritative": True, + "keep_journal": True, + "suspension_enabled": True, + "nrtm_query_serial_range_limit": 10, }, - 'TESTDB2': { - 'nrtm_host': '192.0.2.1', - 'nrtm_port': 43, - 'import_serial_source': 'ftp://example.com/serial', - 'keep_journal': True, - 'route_object_preference': 200, + "TESTDB2": { + "nrtm_host": "192.0.2.1", + "nrtm_port": 43, + "import_serial_source": "ftp://example.com/serial", + "keep_journal": True, + "route_object_preference": 200, }, - 'TESTDB3': { - 'export_destination_unfiltered': '/tmp', - 'nrtm_access_list_unfiltered': 'valid-list', + "TESTDB3": { + "export_destination_unfiltered": "/tmp", + "nrtm_access_list_unfiltered": "valid-list", }, # RPKI source permitted, rpki.roa_source not set - 'RPKI': {}, + "RPKI": {}, }, - 'log': { - 'level': 'DEBUG', - 'logfile_path': logfile - }, - + "log": {"level": "DEBUG", "logfile_path": logfile}, } } save_yaml_config(config) assert is_config_initialised() - config['irrd']['sources_default'] = ['TESTDB2'] + config["irrd"]["sources_default"] = ["TESTDB2"] save_yaml_config(config, run_init=False) # Unchanged, no reload performed - assert list(get_setting('sources_default')) == ['TESTDB2', 'TESTDB'] + assert list(get_setting("sources_default")) == ["TESTDB2", "TESTDB"] os.kill(os.getpid(), signal.SIGHUP) - assert list(get_setting('sources_default')) == ['TESTDB2'] + assert list(get_setting("sources_default")) == ["TESTDB2"] logfile_contents = open(logfile).read() - assert 'Configuration successfully (re)loaded from ' in logfile_contents + assert "Configuration successfully (re)loaded from " in logfile_contents def test_load_custom_logging_config(self, monkeypatch, save_yaml_config, tmpdir, caplog): - logfile = str(tmpdir + '/logfile.txt') - logging_config_path = str(tmpdir + '/logging.py') - with open(logging_config_path, 'w') as fh: - fh.write(textwrap.dedent(""" + logfile = str(tmpdir + "/logfile.txt") + logging_config_path = str(tmpdir + "/logging.py") + with 
open(logging_config_path, "w") as fh: + fh.write( + textwrap.dedent( + """ LOGGING = { 'version': 1, 'disable_existing_loggers': False, @@ -152,7 +151,9 @@ def test_load_custom_logging_config(self, monkeypatch, save_yaml_config, tmpdir, 'handlers': { 'file': { 'class': 'logging.handlers.WatchedFileHandler', - 'filename': '""" + logfile + """', + 'filename': '""" + + logfile + + """', 'formatter': 'verbose', }, }, @@ -163,236 +164,252 @@ def test_load_custom_logging_config(self, monkeypatch, save_yaml_config, tmpdir, }, } } - """)) + """ + ) + ) config = { - 'irrd': { - 'database_url': 'db-url', - 'redis_url': 'redis-url', - 'piddir': str(tmpdir), - 'email': { - 'from': 'example@example.com', - 'smtp': '192.0.2.1' - }, - 'rpki': { - 'roa_source': None, + "irrd": { + "database_url": "db-url", + "redis_url": "redis-url", + "piddir": str(tmpdir), + "email": {"from": "example@example.com", "smtp": "192.0.2.1"}, + "rpki": { + "roa_source": None, }, - 'auth': { - 'gnupg_keyring': str(tmpdir) + "auth": {"gnupg_keyring": str(tmpdir)}, + "log": { + "logging_config_path": logging_config_path, }, - 'log': { - 'logging_config_path': logging_config_path, - }, - } } save_yaml_config(config) assert is_config_initialised() - assert get_configuration().logging_config['handlers']['file']['filename'] == logfile + assert get_configuration().logging_config["handlers"]["file"]["filename"] == logfile def test_load_valid_reload_invalid_config(self, save_yaml_config, tmpdir, caplog): - save_yaml_config({ - 'irrd': { - 'database_url': 'db-url', - 'redis_url': 'redis-url', - 'piddir': str(tmpdir), - 'email': { - 'from': 'example@example.com', - 'smtp': '192.0.2.1' - }, - 'access_lists': { - 'valid-list': { - '192/24', - '192.0.2.1', - '2001:db8::32', - '2001:db8::1', - } - }, - 'auth': { - 'gnupg_keyring': str(tmpdir), - }, - 'rpki': { - 'roa_source': 'https://example.com/roa.json', - 'notify_invalid_enabled': False, - }, - 'sources_default': ['TESTDB2', 'TESTDB', 'RPKI'], - 'sources': { - 'TESTDB': { - 'authoritative': True, - 'keep_journal': True, + save_yaml_config( + { + "irrd": { + "database_url": "db-url", + "redis_url": "redis-url", + "piddir": str(tmpdir), + "email": {"from": "example@example.com", "smtp": "192.0.2.1"}, + "access_lists": { + "valid-list": { + "192/24", + "192.0.2.1", + "2001:db8::32", + "2001:db8::1", + } }, - 'TESTDB2': { - 'nrtm_host': '192.0.2.1', - 'nrtm_port': 43, - 'import_serial_source': 'ftp://example.com/serial', - 'keep_journal': True, - 'import_timer': '1234', + "auth": { + "gnupg_keyring": str(tmpdir), }, - }, - + "rpki": { + "roa_source": "https://example.com/roa.json", + "notify_invalid_enabled": False, + }, + "sources_default": ["TESTDB2", "TESTDB", "RPKI"], + "sources": { + "TESTDB": { + "authoritative": True, + "keep_journal": True, + }, + "TESTDB2": { + "nrtm_host": "192.0.2.1", + "nrtm_port": 43, + "import_serial_source": "ftp://example.com/serial", + "keep_journal": True, + "import_timer": "1234", + }, + }, + } } - }) + ) save_yaml_config({}, run_init=False) os.kill(os.getpid(), signal.SIGHUP) - assert list(get_setting('sources_default')) == ['TESTDB2', 'TESTDB', 'RPKI'] - assert 'Errors found in configuration, continuing with current settings' in caplog.text + assert list(get_setting("sources_default")) == ["TESTDB2", "TESTDB", "RPKI"] + assert "Errors found in configuration, continuing with current settings" in caplog.text assert 'Could not find root item "irrd"' in caplog.text def test_load_invalid_config(self, save_yaml_config, tmpdir): config = { - 'irrd': { - 
'database_readonly': True, - 'piddir': str(tmpdir + '/does-not-exist'), - 'user': 'a', - 'server': { - 'whois': { - 'access_list': 'doesnotexist', + "irrd": { + "database_readonly": True, + "piddir": str(tmpdir + "/does-not-exist"), + "user": "a", + "server": { + "whois": { + "access_list": "doesnotexist", }, - 'http': { - 'status_access_list': ['foo'], + "http": { + "status_access_list": ["foo"], }, }, - 'email': { - 'footer': {'a': 1}, - 'recipient_override': 'invalid-mail', + "email": { + "footer": {"a": 1}, + "recipient_override": "invalid-mail", }, - 'access_lists': { - 'bad-list': { - '192.0.2.2.1' - }, + "access_lists": { + "bad-list": {"192.0.2.2.1"}, }, - 'auth': { - 'set_creation': { - 'as-set': { - 'prefix_required': 'not-a-bool', - 'autnum_authentication': 'unknown-value', + "auth": { + "set_creation": { + "as-set": { + "prefix_required": "not-a-bool", + "autnum_authentication": "unknown-value", }, - 'not-a-real-set': { - 'prefix_required': True, + "not-a-real-set": { + "prefix_required": True, }, }, - 'password_hashers': { - 'unknown-hasher': 'legacy', - 'crypt-pw': 'invalid-setting', + "password_hashers": { + "unknown-hasher": "legacy", + "crypt-pw": "invalid-setting", }, }, - 'route_object_preference': { - 'update_timer': 'not-a-number', + "route_object_preference": { + "update_timer": "not-a-number", }, - 'rpki': { - 'roa_source': 'https://example.com/roa.json', - 'roa_import_timer': 'foo', - 'notify_invalid_subject': [], - 'notify_invalid_header': [], + "rpki": { + "roa_source": "https://example.com/roa.json", + "roa_import_timer": "foo", + "notify_invalid_subject": [], + "notify_invalid_header": [], }, - 'scopefilter': { - 'prefixes': ['invalid-prefix'], - 'asns': ['invalid', '10-invalid'], + "scopefilter": { + "prefixes": ["invalid-prefix"], + "asns": ["invalid", "10-invalid"], }, - 'sources_default': ['DOESNOTEXIST-DB'], - 'sources': { - 'TESTDB': { - 'keep_journal': True, - 'import_timer': 'foo', - 'export_timer': 'bar', - 'nrtm_host': '192.0.2.1', - 'unknown': True, - 'suspension_enabled': True, - 'nrtm_query_serial_range_limit': 'not-a-number', + "sources_default": ["DOESNOTEXIST-DB"], + "sources": { + "TESTDB": { + "keep_journal": True, + "import_timer": "foo", + "export_timer": "bar", + "nrtm_host": "192.0.2.1", + "unknown": True, + "suspension_enabled": True, + "nrtm_query_serial_range_limit": "not-a-number", }, - 'TESTDB2': { - 'authoritative': True, - 'nrtm_host': '192.0.2.1', - 'nrtm_port': 'not a number', - 'nrtm_access_list': 'invalid-list', + "TESTDB2": { + "authoritative": True, + "nrtm_host": "192.0.2.1", + "nrtm_port": "not a number", + "nrtm_access_list": "invalid-list", }, - 'TESTDB3': { - 'authoritative': True, - 'import_source': '192.0.2.1', - 'nrtm_access_list_unfiltered': 'invalid-list', - 'route_object_preference': 'not-a-number', + "TESTDB3": { + "authoritative": True, + "import_source": "192.0.2.1", + "nrtm_access_list_unfiltered": "invalid-list", + "route_object_preference": "not-a-number", }, # Not permitted, rpki.roa_source is set - 'RPKI': {}, - 'lowercase': {}, - 'invalid char': {}, + "RPKI": {}, + "lowercase": {}, + "invalid char": {}, }, - 'log': { - 'level': 'INVALID', - 'logging_config_path': 'path', - 'unknown': True, + "log": { + "level": "INVALID", + "logging_config_path": "path", + "unknown": True, }, - 'unknown_setting': False, + "unknown_setting": False, } } with pytest.raises(ConfigurationError) as ce: save_yaml_config(config) - assert 'Setting database_url is required.' in str(ce.value) - assert 'Setting redis_url is required.' 
in str(ce.value) - assert 'Setting piddir is required and must point to an existing directory.' in str(ce.value) - assert 'Setting email.from is required and must be an email address.' in str(ce.value) - assert 'Setting email.smtp is required.' in str(ce.value) - assert 'Setting email.footer must be a string, if defined.' in str(ce.value) - assert 'Setting email.recipient_override must be an email address if set.' in str(ce.value) - assert 'Settings user and group must both be defined, or neither.' in str(ce.value) - assert 'Setting auth.gnupg_keyring is required.' in str(ce.value) - assert 'Unknown setting key: auth.set_creation.not-a-real-set.prefix_required' in str(ce.value) - assert 'Setting auth.set_creation.as-set.prefix_required must be a bool' in str(ce.value) - assert 'Setting auth.set_creation.as-set.autnum_authentication must be one of' in str(ce.value) - assert 'Unknown setting key: auth.password_hashers.unknown-hash' in str(ce.value) - assert 'Setting auth.password_hashers.crypt-pw must be one of' in str(ce.value) - assert 'Access lists doesnotexist, invalid-list referenced in settings, but not defined.' in str(ce.value) - assert 'Setting server.http.status_access_list must be a string, if defined.' in str(ce.value) - assert 'Invalid item in access list bad-list: IPv4 Address with more than 4 bytes.' in str(ce.value) - assert 'Invalid item in prefix scopefilter: invalid-prefix' in str(ce.value) - assert 'Invalid item in asn scopefilter: invalid.' in str(ce.value) - assert 'Invalid item in asn scopefilter: 10-invalid.' in str(ce.value) - assert 'Setting sources contains reserved source name: RPKI' in str(ce.value) - assert 'Setting suspension_enabled for source TESTDB can not be enabled without enabling authoritative.' in str(ce.value) - assert 'Setting keep_journal for source TESTDB can not be enabled unless either ' in str(ce.value) - assert 'Setting nrtm_host for source TESTDB can not be enabled without setting import_serial_source.' in str(ce.value) - assert 'Setting authoritative for source TESTDB2 can not be enabled when either nrtm_host or import_source are set.' in str(ce.value) - assert 'Setting authoritative for source TESTDB3 can not be enabled when either nrtm_host or import_source are set.' in str(ce.value) - assert 'Source TESTDB can not have authoritative, import_source or nrtm_host set when database_readonly is enabled.' in str(ce.value) - assert 'Source TESTDB3 can not have authoritative, import_source or nrtm_host set when database_readonly is enabled.' in str(ce.value) - assert 'Setting nrtm_port for source TESTDB2 must be a number.' in str(ce.value) - assert 'Setting rpki.roa_import_timer must be set to a number.' in str(ce.value) - assert 'Setting rpki.notify_invalid_subject must be a string, if defined.' in str(ce.value) - assert 'Setting rpki.notify_invalid_header must be a string, if defined.' in str(ce.value) - assert 'Setting import_timer for source TESTDB must be a number.' in str(ce.value) - assert 'Setting export_timer for source TESTDB must be a number.' in str(ce.value) - assert 'Setting route_object_preference for source TESTDB3 must be a number.' in str(ce.value) - assert 'Setting route_object_preference.update_timer must be a number.' in str(ce.value) - assert 'Setting nrtm_query_serial_range_limit for source TESTDB must be a number.' 
in str(ce.value) - assert 'Invalid source name: lowercase' in str(ce.value) - assert 'Invalid source name: invalid char' in str(ce.value) - assert 'but rpki.notify_invalid_enabled is not set' in str(ce.value) - assert 'Setting sources_default contains unknown sources: DOESNOTEXIST-DB' in str(ce.value) - assert 'Invalid log.level: INVALID' in str(ce.value) - assert 'Setting log.logging_config_path can not be combined' in str(ce.value) - assert 'Unknown setting key: unknown_setting' in str(ce.value) - assert 'Unknown setting key: log.unknown' in str(ce.value) - assert 'Unknown key(s) under source TESTDB: unknown' in str(ce.value) + assert "Setting database_url is required." in str(ce.value) + assert "Setting redis_url is required." in str(ce.value) + assert "Setting piddir is required and must point to an existing directory." in str(ce.value) + assert "Setting email.from is required and must be an email address." in str(ce.value) + assert "Setting email.smtp is required." in str(ce.value) + assert "Setting email.footer must be a string, if defined." in str(ce.value) + assert "Setting email.recipient_override must be an email address if set." in str(ce.value) + assert "Settings user and group must both be defined, or neither." in str(ce.value) + assert "Setting auth.gnupg_keyring is required." in str(ce.value) + assert "Unknown setting key: auth.set_creation.not-a-real-set.prefix_required" in str(ce.value) + assert "Setting auth.set_creation.as-set.prefix_required must be a bool" in str(ce.value) + assert "Setting auth.set_creation.as-set.autnum_authentication must be one of" in str(ce.value) + assert "Unknown setting key: auth.password_hashers.unknown-hash" in str(ce.value) + assert "Setting auth.password_hashers.crypt-pw must be one of" in str(ce.value) + assert "Access lists doesnotexist, invalid-list referenced in settings, but not defined." in str( + ce.value + ) + assert "Setting server.http.status_access_list must be a string, if defined." in str(ce.value) + assert "Invalid item in access list bad-list: IPv4 Address with more than 4 bytes." in str(ce.value) + assert "Invalid item in prefix scopefilter: invalid-prefix" in str(ce.value) + assert "Invalid item in asn scopefilter: invalid." in str(ce.value) + assert "Invalid item in asn scopefilter: 10-invalid." in str(ce.value) + assert "Setting sources contains reserved source name: RPKI" in str(ce.value) + assert ( + "Setting suspension_enabled for source TESTDB can not be enabled without enabling authoritative." + in str(ce.value) + ) + assert "Setting keep_journal for source TESTDB can not be enabled unless either " in str(ce.value) + assert ( + "Setting nrtm_host for source TESTDB can not be enabled without setting import_serial_source." + in str(ce.value) + ) + assert ( + "Setting authoritative for source TESTDB2 can not be enabled when either nrtm_host or" + " import_source are set." + in str(ce.value) + ) + assert ( + "Setting authoritative for source TESTDB3 can not be enabled when either nrtm_host or" + " import_source are set." + in str(ce.value) + ) + assert ( + "Source TESTDB can not have authoritative, import_source or nrtm_host set when database_readonly" + " is enabled." + in str(ce.value) + ) + assert ( + "Source TESTDB3 can not have authoritative, import_source or nrtm_host set when database_readonly" + " is enabled." + in str(ce.value) + ) + assert "Setting nrtm_port for source TESTDB2 must be a number." in str(ce.value) + assert "Setting rpki.roa_import_timer must be set to a number." 
in str(ce.value) + assert "Setting rpki.notify_invalid_subject must be a string, if defined." in str(ce.value) + assert "Setting rpki.notify_invalid_header must be a string, if defined." in str(ce.value) + assert "Setting import_timer for source TESTDB must be a number." in str(ce.value) + assert "Setting export_timer for source TESTDB must be a number." in str(ce.value) + assert "Setting route_object_preference for source TESTDB3 must be a number." in str(ce.value) + assert "Setting route_object_preference.update_timer must be a number." in str(ce.value) + assert "Setting nrtm_query_serial_range_limit for source TESTDB must be a number." in str(ce.value) + assert "Invalid source name: lowercase" in str(ce.value) + assert "Invalid source name: invalid char" in str(ce.value) + assert "but rpki.notify_invalid_enabled is not set" in str(ce.value) + assert "Setting sources_default contains unknown sources: DOESNOTEXIST-DB" in str(ce.value) + assert "Invalid log.level: INVALID" in str(ce.value) + assert "Setting log.logging_config_path can not be combined" in str(ce.value) + assert "Unknown setting key: unknown_setting" in str(ce.value) + assert "Unknown setting key: log.unknown" in str(ce.value) + assert "Unknown key(s) under source TESTDB: unknown" in str(ce.value) class TestGetSetting: - setting_name = 'server.whois.interface' - env_name = 'IRRD_SERVER_WHOIS_INTERFACE' + setting_name = "server.whois.interface" + env_name = "IRRD_SERVER_WHOIS_INTERFACE" def test_get_setting_default(self, monkeypatch): monkeypatch.delenv(self.env_name, raising=False) - assert get_setting(self.setting_name) == '::0' + assert get_setting(self.setting_name) == "::0" def test_get_setting_env(self, monkeypatch): - monkeypatch.setenv(self.env_name, 'env_value') - assert get_setting(self.setting_name) == 'env_value' + monkeypatch.setenv(self.env_name, "env_value") + assert get_setting(self.setting_name) == "env_value" def test_get_setting_unknown(self, monkeypatch): with pytest.raises(ValueError): - get_setting('unknown') + get_setting("unknown") with pytest.raises(ValueError): - get_setting('log.unknown') + get_setting("log.unknown") with pytest.raises(ValueError): - get_setting('sources.TEST.unknown') + get_setting("sources.TEST.unknown") diff --git a/irrd/daemon/main.py b/irrd/daemon/main.py index 34bbf2f96..a79975531 100755 --- a/irrd/daemon/main.py +++ b/irrd/daemon/main.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # flake8: noqa: E402 import argparse +import grp import logging import multiprocessing import os @@ -9,10 +10,9 @@ import sys import time from pathlib import Path -from typing import Tuple, Optional +from typing import Optional, Tuple import daemon -import grp import psutil from daemon.daemon import change_process_owner from pid import PidFile, PidFileError @@ -20,54 +20,66 @@ logger = logging.getLogger(__name__) sys.path.append(str(Path(__file__).resolve().parents[2])) -from irrd.utils.process_support import ExceptionLoggingProcess, set_traceback_handler -from irrd.storage.preload import PreloadStoreManager -from irrd.server.whois.server import start_whois_server -from irrd.server.http.server import run_http_server +from irrd import ENV_MAIN_PROCESS_PID, __version__ +from irrd.conf import CONFIG_PATH_DEFAULT, config_init, get_configuration, get_setting from irrd.mirroring.scheduler import MirrorScheduler -from irrd.conf import config_init, CONFIG_PATH_DEFAULT, get_setting, get_configuration -from irrd import __version__, ENV_MAIN_PROCESS_PID - +from irrd.server.http.server import run_http_server +from 
irrd.server.whois.server import start_whois_server +from irrd.storage.preload import PreloadStoreManager +from irrd.utils.process_support import ExceptionLoggingProcess, set_traceback_handler # This file does not have a unit test, but is instead tested through # the integration tests. Writing a unit test would be too complex. + def main(): description = """IRRd main process""" parser = argparse.ArgumentParser(description=description) - parser.add_argument('--config', dest='config_file_path', type=str, - help=f'use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})') - parser.add_argument('--foreground', dest='foreground', action='store_true', - help=f"run IRRd in the foreground, don't detach") + parser.add_argument( + "--config", + dest="config_file_path", + type=str, + help=f"use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})", + ) + parser.add_argument( + "--foreground", + dest="foreground", + action="store_true", + help=f"run IRRd in the foreground, don't detach", + ) args = parser.parse_args() - mirror_frequency = int(os.environ.get('IRRD_SCHEDULER_TIMER_OVERRIDE', 15)) + mirror_frequency = int(os.environ.get("IRRD_SCHEDULER_TIMER_OVERRIDE", 15)) daemon_kwargs = { - 'umask': 0o022, + "umask": 0o022, } if args.foreground: - daemon_kwargs['detach_process'] = False - daemon_kwargs['stdout'] = sys.stdout - daemon_kwargs['stderr'] = sys.stderr + daemon_kwargs["detach_process"] = False + daemon_kwargs["stdout"] = sys.stdout + daemon_kwargs["stderr"] = sys.stderr # Since Python 3.8, the default method is spawn for MacOS, # which creates several issues. For consistency, we force to fork. - multiprocessing.set_start_method('fork') + multiprocessing.set_start_method("fork") # config_init with commit may only be called within DaemonContext, # but this call here causes fast failure for most misconfigurations config_init(args.config_file_path, commit=False) - staged_logfile_path = get_configuration().user_config_staging.get('log.logfile_path') - staged_logging_config_path = get_configuration().user_config_staging.get('log.logging_config_path') - if not any([ - staged_logfile_path, - staged_logging_config_path, - args.foreground, - ]): - logging.critical('Unable to start: when not running in the foreground, you must set ' - 'either log.logfile_path or log.logging_config_path in the settings') + staged_logfile_path = get_configuration().user_config_staging.get("log.logfile_path") + staged_logging_config_path = get_configuration().user_config_staging.get("log.logging_config_path") + if not any( + [ + staged_logfile_path, + staged_logging_config_path, + args.foreground, + ] + ): + logging.critical( + "Unable to start: when not running in the foreground, you must set " + "either log.logfile_path or log.logging_config_path in the settings" + ) return uid, gid = get_configured_owner(from_staging=True) @@ -75,7 +87,9 @@ def main(): os.setegid(gid) os.seteuid(uid) if staged_logfile_path and not os.access(staged_logfile_path, os.W_OK, effective_ids=True): - logging.critical(f'Unable to start: logfile {staged_logfile_path} not writable by UID {uid} / GID {gid}') + logging.critical( + f"Unable to start: logfile {staged_logfile_path} not writable by UID {uid} / GID {gid}" + ) return with daemon.DaemonContext(**daemon_kwargs): @@ -83,25 +97,27 @@ def main(): uid, gid = get_configured_owner() # Running as root is permitted on CI - if not os.environ.get('CI') and not uid and os.geteuid() == 0: - logging.critical('Unable to start: user and group must be defined in settings ' - 'when 
starting IRRd as root') + if not os.environ.get("CI") and not uid and os.geteuid() == 0: + logging.critical( + "Unable to start: user and group must be defined in settings when starting IRRd as root" + ) return - piddir = get_setting('piddir') - logger.info('IRRd attempting to secure PID') + piddir = get_setting("piddir") + logger.info("IRRd attempting to secure PID") try: - with PidFile(pidname='irrd', piddir=piddir): - logger.info(f'IRRd {__version__} starting, PID {os.getpid()}, PID file in {piddir}') - run_irrd(mirror_frequency=mirror_frequency, - config_file_path=args.config_file_path if args.config_file_path else CONFIG_PATH_DEFAULT, - uid=uid, - gid=gid, - ) + with PidFile(pidname="irrd", piddir=piddir): + logger.info(f"IRRd {__version__} starting, PID {os.getpid()}, PID file in {piddir}") + run_irrd( + mirror_frequency=mirror_frequency, + config_file_path=args.config_file_path if args.config_file_path else CONFIG_PATH_DEFAULT, + uid=uid, + gid=gid, + ) except PidFileError as pfe: - logger.error(f'Failed to start IRRd, unable to lock PID file irrd.pid in {piddir}: {pfe}') + logger.error(f"Failed to start IRRd, unable to lock PID file irrd.pid in {piddir}: {pfe}") except Exception as e: - logger.error(f'Error occurred in main process, terminating. Error follows:') + logger.error(f"Error occurred in main process, terminating. Error follows:") logger.exception(e) os.kill(os.getpid(), signal.SIGTERM) @@ -117,9 +133,7 @@ def run_irrd(mirror_frequency: int, config_file_path: str, uid: Optional[int], g set_traceback_handler() whois_process = ExceptionLoggingProcess( - target=start_whois_server, - name='irrd-whois-server-listener', - kwargs={'uid': uid, 'gid': gid} + target=start_whois_server, name="irrd-whois-server-listener", kwargs={"uid": uid, "gid": gid} ) whois_process.start() if uid and gid: @@ -128,11 +142,13 @@ def run_irrd(mirror_frequency: int, config_file_path: str, uid: Optional[int], g mirror_scheduler = MirrorScheduler() preload_manager = None - if not get_setting(f'database_readonly'): - preload_manager = PreloadStoreManager(name='irrd-preload-store-manager') + if not get_setting(f"database_readonly"): + preload_manager = PreloadStoreManager(name="irrd-preload-store-manager") preload_manager.start() - uvicorn_process = ExceptionLoggingProcess(target=run_http_server, name='irrd-http-server-listener', args=(config_file_path, )) + uvicorn_process = ExceptionLoggingProcess( + target=run_http_server, name="irrd-http-server-listener", args=(config_file_path,) + ) uvicorn_process.start() def sighup_handler(signum, frame): @@ -145,8 +161,11 @@ def sighup_handler(signum, frame): for process in children: process.send_signal(signal.SIGHUP) if children: - logging.info('Main process received SIGHUP with valid config, sent SIGHUP to ' - f'child processes {[c.pid for c in children]}') + logging.info( + "Main process received SIGHUP with valid config, sent SIGHUP to " + f"child processes {[c.pid for c in children]}" + ) + signal.signal(signal.SIGHUP, sighup_handler) def sigterm_handler(signum, frame): @@ -161,11 +180,13 @@ def sigterm_handler(signum, frame): # do the best we can. 
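# Sketch (not part of this change): the SIGTERM fan-out used here, reduced to
# a minimal hypothetical form. The main process forwards the signal to every
# child and tolerates the race where a child exits before it can be
# signalled, which is the case the surrounding try/except guards against.
import os
import signal

import psutil

def sigterm_fanout(signum, frame):
    for child in psutil.Process(os.getpid()).children(recursive=True):
        try:
            child.send_signal(signal.SIGTERM)
        except psutil.NoSuchProcess:
            pass  # child exited already; nothing left to signal

# Registered the same way as in main(): signal.signal(signal.SIGTERM, sigterm_fanout)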
pass if children: - logging.info('Main process received SIGTERM, sent SIGTERM to ' - f'child processes {[c.pid for c in children]}') + logging.info( + f"Main process received SIGTERM, sent SIGTERM to child processes {[c.pid for c in children]}" + ) nonlocal terminated terminated = True + signal.signal(signal.SIGTERM, sigterm_handler) sleeps = mirror_frequency @@ -178,7 +199,7 @@ def sigterm_handler(signum, frame): time.sleep(1) sleeps += 1 - logging.debug(f'Main process waiting for child processes to terminate') + logging.debug(f"Main process waiting for child processes to terminate") for child_process in whois_process, uvicorn_process, preload_manager: if child_process: child_process.join(timeout=3) @@ -191,27 +212,29 @@ def sigterm_handler(signum, frame): except Exception: pass if children: - logging.info('Some processes left alive after SIGTERM, send SIGKILL to ' - f'child processes {[c.pid for c in children]}') + logging.info( + "Some processes left alive after SIGTERM, send SIGKILL to " + f"child processes {[c.pid for c in children]}" + ) - logging.info(f'Main process exiting') + logging.info(f"Main process exiting") def get_configured_owner(from_staging=False) -> Tuple[Optional[int], Optional[int]]: uid = gid = None if not from_staging: - user = get_setting('user') - group = get_setting('group') + user = get_setting("user") + group = get_setting("group") else: config = get_configuration() assert config - user = config.user_config_staging.get('user') - group = config.user_config_staging.get('group') + user = config.user_config_staging.get("user") + group = config.user_config_staging.get("group") if user and group: uid = pwd.getpwnam(user).pw_uid gid = grp.getgrnam(group).gr_gid return uid, gid -if __name__ == '__main__': # pragma: no cover +if __name__ == "__main__": # pragma: no cover main() diff --git a/irrd/integration_tests/constants.py b/irrd/integration_tests/constants.py index c3e368cdf..94b19c350 100644 --- a/irrd/integration_tests/constants.py +++ b/irrd/integration_tests/constants.py @@ -1,5 +1,5 @@ -EMAIL_SEPARATOR = '\n=-=-=-=-=-=-=-=-=-=-MESSAGE BREAK=-=-=-=-=-=-=-=-=-=-=-=-=-=--=-\n' -EMAIL_END = b'\n=-=-=-=-=-=-=-=-=-=-END OF MESSAGES=-=-=-=-=-=-=-=-=-=-=-=-=-=--=-\n' -EMAIL_RETURN_MSGS_COMMAND = 'TEST_RETURN_MESSAGES' -EMAIL_DISCARD_MSGS_COMMAND = 'TEST_DISCARD_MESSAGES' +EMAIL_SEPARATOR = "\n=-=-=-=-=-=-=-=-=-=-MESSAGE BREAK=-=-=-=-=-=-=-=-=-=-=-=-=-=--=-\n" +EMAIL_END = b"\n=-=-=-=-=-=-=-=-=-=-END OF MESSAGES=-=-=-=-=-=-=-=-=-=-=-=-=-=--=-\n" +EMAIL_RETURN_MSGS_COMMAND = "TEST_RETURN_MESSAGES" +EMAIL_DISCARD_MSGS_COMMAND = "TEST_DISCARD_MESSAGES" EMAIL_SMTP_PORT = 2500 diff --git a/irrd/integration_tests/run.py b/irrd/integration_tests/run.py index 9c3805ee3..b496d016c 100644 --- a/irrd/integration_tests/run.py +++ b/irrd/integration_tests/run.py @@ -1,34 +1,53 @@ -# flake8: noqa: W293 -import sys -import time -import unittest - -import ujson - import base64 import email import os -import requests import signal import socket -import sqlalchemy as sa import subprocess +import sys import textwrap -import yaml -from alembic import command, config +import time +import unittest from pathlib import Path +import requests +import sqlalchemy as sa +import ujson +import yaml +from alembic import command, config from python_graphql_client import GraphqlClient -from irrd.conf import config_init, PASSWORD_HASH_DUMMY_VALUE -from irrd.utils.rpsl_samples import (SAMPLE_MNTNER, SAMPLE_PERSON, SAMPLE_KEY_CERT, SIGNED_PERSON_UPDATE_VALID, - SAMPLE_AS_SET, SAMPLE_AUT_NUM, SAMPLE_DOMAIN, 
SAMPLE_FILTER_SET, SAMPLE_INET_RTR, - SAMPLE_INET6NUM, SAMPLE_INETNUM, SAMPLE_PEERING_SET, SAMPLE_ROLE, SAMPLE_ROUTE, - SAMPLE_ROUTE_SET, SAMPLE_ROUTE6, SAMPLE_RTR_SET, SAMPLE_AS_BLOCK) +from irrd.conf import PASSWORD_HASH_DUMMY_VALUE, config_init +from irrd.utils.rpsl_samples import ( + SAMPLE_AS_BLOCK, + SAMPLE_AS_SET, + SAMPLE_AUT_NUM, + SAMPLE_DOMAIN, + SAMPLE_FILTER_SET, + SAMPLE_INET6NUM, + SAMPLE_INET_RTR, + SAMPLE_INETNUM, + SAMPLE_KEY_CERT, + SAMPLE_MNTNER, + SAMPLE_PEERING_SET, + SAMPLE_PERSON, + SAMPLE_ROLE, + SAMPLE_ROUTE, + SAMPLE_ROUTE6, + SAMPLE_ROUTE_SET, + SAMPLE_RTR_SET, + SIGNED_PERSON_UPDATE_VALID, +) from irrd.utils.whois_client import whois_query, whois_query_irrd -from .constants import (EMAIL_SMTP_PORT, EMAIL_DISCARD_MSGS_COMMAND, EMAIL_RETURN_MSGS_COMMAND, EMAIL_SEPARATOR, - EMAIL_END) + from ..storage import translate_url +from .constants import ( + EMAIL_DISCARD_MSGS_COMMAND, + EMAIL_END, + EMAIL_RETURN_MSGS_COMMAND, + EMAIL_SEPARATOR, + EMAIL_SMTP_PORT, +) IRRD_ROOT_PATH = str(Path(__file__).resolve().parents[2]) sys.path.append(IRRD_ROOT_PATH) @@ -46,29 +65,31 @@ remarks: remark """ -SAMPLE_MNTNER_CLEAN = SAMPLE_MNTNER.replace('mnt-by: OTHER1-MNT,OTHER2-MNT\n', '') -LARGE_UPDATE = '\n\n'.join([ - SAMPLE_AS_BLOCK, - SAMPLE_AS_SET, - SAMPLE_AUT_NUM, - SAMPLE_AUT_NUM.replace('aut-num: as065537', 'aut-num: as65538'), - SAMPLE_AUT_NUM.replace('aut-num: as065537', 'aut-num: as65539'), - SAMPLE_AUT_NUM.replace('aut-num: as065537', 'aut-num: as65540'), - SAMPLE_DOMAIN, - SAMPLE_FILTER_SET, - SAMPLE_INET_RTR, - SAMPLE_INET6NUM, - SAMPLE_INETNUM, - SAMPLE_KEY_CERT, - SAMPLE_PEERING_SET, - SAMPLE_PERSON.replace('PERSON-TEST', 'DUMY2-TEST'), - SAMPLE_ROLE, - SAMPLE_ROUTE, - SAMPLE_ROUTE_SET, - SAMPLE_ROUTE6, - SAMPLE_RTR_SET, - AS_SET_REFERRING_OTHER_SET, -]) +SAMPLE_MNTNER_CLEAN = SAMPLE_MNTNER.replace("mnt-by: OTHER1-MNT,OTHER2-MNT\n", "") +LARGE_UPDATE = "\n\n".join( + [ + SAMPLE_AS_BLOCK, + SAMPLE_AS_SET, + SAMPLE_AUT_NUM, + SAMPLE_AUT_NUM.replace("aut-num: as065537", "aut-num: as65538"), + SAMPLE_AUT_NUM.replace("aut-num: as065537", "aut-num: as65539"), + SAMPLE_AUT_NUM.replace("aut-num: as065537", "aut-num: as65540"), + SAMPLE_DOMAIN, + SAMPLE_FILTER_SET, + SAMPLE_INET_RTR, + SAMPLE_INET6NUM, + SAMPLE_INETNUM, + SAMPLE_KEY_CERT, + SAMPLE_PEERING_SET, + SAMPLE_PERSON.replace("PERSON-TEST", "DUMY2-TEST"), + SAMPLE_ROLE, + SAMPLE_ROUTE, + SAMPLE_ROUTE_SET, + SAMPLE_ROUTE6, + SAMPLE_RTR_SET, + AS_SET_REFERRING_OTHER_SET, + ] +) class TestIntegration: @@ -81,6 +102,7 @@ class TestIntegration: Note that this test will not be included in the default py.test discovery, this is intentional. """ + port_http1 = 6080 port_whois1 = 6043 port_http2 = 6081 @@ -89,88 +111,103 @@ class TestIntegration: def test_irrd_integration(self, tmpdir): self.assertCountEqual = unittest.TestCase().assertCountEqual # IRRD_DATABASE_URL and IRRD_REDIS_URL override the yaml config, so should be removed - if 'IRRD_DATABASE_URL' in os.environ: - del os.environ['IRRD_DATABASE_URL'] - if 'IRRD_REDIS_URL' in os.environ: - del os.environ['IRRD_REDIS_URL'] + if "IRRD_DATABASE_URL" in os.environ: + del os.environ["IRRD_DATABASE_URL"] + if "IRRD_REDIS_URL" in os.environ: + del os.environ["IRRD_REDIS_URL"] # PYTHONPATH needs to contain the twisted plugin path to support the mailserver. 
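# Sketch (not part of this change): the env deletions above matter because
# IRRd maps environment variables onto settings by upper-casing the dotted
# setting path with an IRRD_ prefix, and the environment takes precedence
# over the YAML config. The mapping convention, as exercised by
# TestGetSetting earlier in this diff:
def setting_to_env_name(setting_name: str) -> str:
    return "IRRD_" + setting_name.replace(".", "_").upper()

assert setting_to_env_name("server.whois.interface") == "IRRD_SERVER_WHOIS_INTERFACE"
assert setting_to_env_name("database_url") == "IRRD_DATABASE_URL"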
- os.environ['PYTHONPATH'] = IRRD_ROOT_PATH - os.environ['IRRD_SCHEDULER_TIMER_OVERRIDE'] = '1' + os.environ["PYTHONPATH"] = IRRD_ROOT_PATH + os.environ["IRRD_SCHEDULER_TIMER_OVERRIDE"] = "1" self.tmpdir = tmpdir self._start_mailserver() self._start_irrds() # Attempt to load a mntner with valid auth, but broken references. - self._submit_update(self.config_path1, SAMPLE_MNTNER + '\n\noverride: override-password') + self._submit_update(self.config_path1, SAMPLE_MNTNER + "\n\noverride: override-password") messages = self._retrieve_mails() assert len(messages) == 1 mail_text = self._extract_message_body(messages[0]) - assert messages[0]['Subject'] == 'FAILED: my subject' - assert messages[0]['From'] == 'from@example.com' - assert messages[0]['To'] == 'Sasha ' - assert '\nCreate FAILED: [mntner] TEST-MNT\n' in mail_text - assert '\nERROR: Object PERSON-TEST referenced in field admin-c not found in database TEST - must reference one of role, person.\n' in mail_text - assert '\nERROR: Object OTHER1-MNT referenced in field mnt-by not found in database TEST - must reference mntner.\n' in mail_text - assert '\nERROR: Object OTHER2-MNT referenced in field mnt-by not found in database TEST - must reference mntner.\n' in mail_text - assert 'email footer' in mail_text - assert 'Generated by IRRd version ' in mail_text + assert messages[0]["Subject"] == "FAILED: my subject" + assert messages[0]["From"] == "from@example.com" + assert messages[0]["To"] == "Sasha " + assert "\nCreate FAILED: [mntner] TEST-MNT\n" in mail_text + assert ( + "\nERROR: Object PERSON-TEST referenced in field admin-c not found in database TEST - must" + " reference one of role, person.\n" + in mail_text + ) + assert ( + "\nERROR: Object OTHER1-MNT referenced in field mnt-by not found in database TEST - must" + " reference mntner.\n" + in mail_text + ) + assert ( + "\nERROR: Object OTHER2-MNT referenced in field mnt-by not found in database TEST - must" + " reference mntner.\n" + in mail_text + ) + assert "email footer" in mail_text + assert "Generated by IRRd version " in mail_text # Load a regular valid mntner and person into the DB, and verify # the contents of the result. - self._submit_update(self.config_path1, - SAMPLE_MNTNER_CLEAN + '\n\n' + SAMPLE_PERSON + '\n\noverride: override-password') + self._submit_update( + self.config_path1, + SAMPLE_MNTNER_CLEAN + "\n\n" + SAMPLE_PERSON + "\n\noverride: override-password", + ) messages = self._retrieve_mails() assert len(messages) == 1 mail_text = self._extract_message_body(messages[0]) - assert messages[0]['Subject'] == 'SUCCESS: my subject' - assert messages[0]['From'] == 'from@example.com' - assert messages[0]['To'] == 'Sasha ' - assert '\nCreate succeeded: [mntner] TEST-MNT\n' in mail_text - assert '\nCreate succeeded: [person] PERSON-TEST\n' in mail_text - assert 'email footer' in mail_text - assert 'Generated by IRRd version ' in mail_text + assert messages[0]["Subject"] == "SUCCESS: my subject" + assert messages[0]["From"] == "from@example.com" + assert messages[0]["To"] == "Sasha " + assert "\nCreate succeeded: [mntner] TEST-MNT\n" in mail_text + assert "\nCreate succeeded: [person] PERSON-TEST\n" in mail_text + assert "email footer" in mail_text + assert "Generated by IRRd version " in mail_text # Check whether the objects can be queried from irrd #1, # whether the hash is masked, and whether encoding is correct. 
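# Sketch (not part of this change): "the hash is masked" means IRRd replaces
# password hashes in mntner auth attributes with PASSWORD_HASH_DUMMY_VALUE
# ("DummyValue") in query output, so hashes never leak via whois. A toy
# illustration of such masking; the regex is hypothetical, not IRRd's code:
import re

PASSWORD_HASH_DUMMY_VALUE = "DummyValue"

def mask_auth_hashes(rpsl_text: str) -> str:
    # Rewrites e.g. "auth: MD5-PW $1$..." to "auth: MD5-PW DummyValue".
    return re.sub(
        r"^(auth:\s+\S+-PW)\s+.*$",
        rf"\1 {PASSWORD_HASH_DUMMY_VALUE}",
        rpsl_text,
        flags=re.MULTILINE,
    )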
- mntner_text = whois_query('127.0.0.1', self.port_whois1, 'TEST-MNT') - assert 'TEST-MNT' in mntner_text + mntner_text = whois_query("127.0.0.1", self.port_whois1, "TEST-MNT") + assert "TEST-MNT" in mntner_text assert PASSWORD_HASH_DUMMY_VALUE in mntner_text - assert 'unįcöde tæst 🌈🦄' in mntner_text - assert 'PERSON-TEST' in mntner_text + assert "unįcöde tæst 🌈🦄" in mntner_text + assert "PERSON-TEST" in mntner_text # After three seconds, a new export should have been generated by irrd #1, # loaded by irrd #2, and the objects should be available in irrd #2 time.sleep(3) - mntner_text = whois_query('127.0.0.1', self.port_whois2, 'TEST-MNT') - assert 'TEST-MNT' in mntner_text + mntner_text = whois_query("127.0.0.1", self.port_whois2, "TEST-MNT") + assert "TEST-MNT" in mntner_text assert PASSWORD_HASH_DUMMY_VALUE in mntner_text - assert 'unįcöde tæst 🌈🦄' in mntner_text - assert 'PERSON-TEST' in mntner_text + assert "unįcöde tæst 🌈🦄" in mntner_text + assert "PERSON-TEST" in mntner_text # Load a key-cert. This should cause notifications to mnt-nfy (2x). # Change is authenticated by valid password. - self._submit_update(self.config_path1, SAMPLE_KEY_CERT + '\npassword: md5-password') + self._submit_update(self.config_path1, SAMPLE_KEY_CERT + "\npassword: md5-password") messages = self._retrieve_mails() assert len(messages) == 3 - assert messages[0]['Subject'] == 'SUCCESS: my subject' - assert messages[0]['From'] == 'from@example.com' - assert messages[0]['To'] == 'Sasha ' - assert 'Create succeeded: [key-cert] PGPKEY-80F238C6' in self._extract_message_body(messages[0]) - - self._check_recipients_in_mails(messages[1:], [ - 'mnt-nfy@example.net', 'mnt-nfy2@example.net' - ]) - - self._check_text_in_mails(messages[1:], [ - '\n> Message-ID: <1325754288.4989.6.camel@hostname>\n', - '\nCreate succeeded for object below: [key-cert] PGPKEY-80F238C6:\n', - 'email footer', - 'Generated by IRRd version ', - ]) + assert messages[0]["Subject"] == "SUCCESS: my subject" + assert messages[0]["From"] == "from@example.com" + assert messages[0]["To"] == "Sasha " + assert "Create succeeded: [key-cert] PGPKEY-80F238C6" in self._extract_message_body(messages[0]) + + self._check_recipients_in_mails(messages[1:], ["mnt-nfy@example.net", "mnt-nfy2@example.net"]) + + self._check_text_in_mails( + messages[1:], + [ + "\n> Message-ID: <1325754288.4989.6.camel@hostname>\n", + "\nCreate succeeded for object below: [key-cert] PGPKEY-80F238C6:\n", + "email footer", + "Generated by IRRd version ", + ], + ) for message in messages[1:]: - assert message['Subject'] == 'Notification of TEST database changes' - assert message['From'] == 'from@example.com' + assert message["Subject"] == "Notification of TEST database changes" + assert message["From"] == "from@example.com" # Use the new PGP key to make an update to PERSON-TEST. 
Should # again trigger mnt-nfy messages, and a mail to the notify address @@ -179,344 +216,382 @@ def test_irrd_integration(self, tmpdir): messages = self._retrieve_mails() assert len(messages) == 4 mail_text = self._extract_message_body(messages[0]) - assert messages[0]['Subject'] == 'SUCCESS: my subject' - assert messages[0]['From'] == 'from@example.com' - assert messages[0]['To'] == 'Sasha ' - assert '\nModify succeeded: [person] PERSON-TEST\n' in mail_text - - self._check_recipients_in_mails(messages[1:], [ - 'mnt-nfy@example.net', 'mnt-nfy2@example.net', 'notify@example.com', - ]) - - self._check_text_in_mails(messages[1:], [ - '\n> Message-ID: <1325754288.4989.6.camel@hostname>\n', - '\nModify succeeded for object below: [person] PERSON-TEST:\n', - '\n@@ -1,4 +1,4 @@\n', - '\nNew version of this object:\n', - ]) + assert messages[0]["Subject"] == "SUCCESS: my subject" + assert messages[0]["From"] == "from@example.com" + assert messages[0]["To"] == "Sasha " + assert "\nModify succeeded: [person] PERSON-TEST\n" in mail_text + + self._check_recipients_in_mails( + messages[1:], + [ + "mnt-nfy@example.net", + "mnt-nfy2@example.net", + "notify@example.com", + ], + ) + + self._check_text_in_mails( + messages[1:], + [ + "\n> Message-ID: <1325754288.4989.6.camel@hostname>\n", + "\nModify succeeded for object below: [person] PERSON-TEST:\n", + "\n@@ -1,4 +1,4 @@\n", + "\nNew version of this object:\n", + ], + ) for message in messages[1:]: - assert message['Subject'] == 'Notification of TEST database changes' - assert message['From'] == 'from@example.com' + assert message["Subject"] == "Notification of TEST database changes" + assert message["From"] == "from@example.com" # Check that the person is updated on irrd #1 - person_text = whois_query('127.0.0.1', self.port_whois1, 'PERSON-TEST') - assert 'PERSON-TEST' in person_text - assert 'Test person changed by PGP signed update' in person_text + person_text = whois_query("127.0.0.1", self.port_whois1, "PERSON-TEST") + assert "PERSON-TEST" in person_text + assert "Test person changed by PGP signed update" in person_text # After 2s, NRTM from irrd #2 should have picked up the change. time.sleep(2) - person_text = whois_query('127.0.0.1', self.port_whois2, 'PERSON-TEST') - assert 'PERSON-TEST' in person_text - assert 'Test person changed by PGP signed update' in person_text + person_text = whois_query("127.0.0.1", self.port_whois2, "PERSON-TEST") + assert "PERSON-TEST" in person_text + assert "Test person changed by PGP signed update" in person_text # Submit an update back to the original person object, with an invalid # password and invalid override. Should trigger notification to upd-to. 
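# Sketch (not part of this change): "override: invalid" fails in the step
# below because the override password is verified against the md5-crypt hash
# configured as auth.override_password (visible in base_config further down
# in this diff). Roughly, using passlib, which IRRd already depends on:
from passlib.hash import md5_crypt

configured_hash = "$1$J6KycItM$MbPaBU6iFSGFV299Rk7Di0"  # from base_config below
# Expected to hold, given the earlier "override: override-password" submissions succeed:
assert md5_crypt.verify("override-password", configured_hash)
assert not md5_crypt.verify("invalid", configured_hash)  # rejected in this step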
- self._submit_update(self.config_path1, SAMPLE_PERSON + '\npassword: invalid\noverride: invalid\n') + self._submit_update(self.config_path1, SAMPLE_PERSON + "\npassword: invalid\noverride: invalid\n") messages = self._retrieve_mails() assert len(messages) == 2 mail_text = self._extract_message_body(messages[0]) - assert messages[0]['Subject'] == 'FAILED: my subject' - assert messages[0]['From'] == 'from@example.com' - assert messages[0]['To'] == 'Sasha ' - assert '\nModify FAILED: [person] PERSON-TEST\n' in mail_text - assert '\nERROR: Authorisation for person PERSON-TEST failed: must be authenticated by one of: TEST-MNT\n' in mail_text + assert messages[0]["Subject"] == "FAILED: my subject" + assert messages[0]["From"] == "from@example.com" + assert messages[0]["To"] == "Sasha " + assert "\nModify FAILED: [person] PERSON-TEST\n" in mail_text + assert ( + "\nERROR: Authorisation for person PERSON-TEST failed: must be authenticated by one of:" + " TEST-MNT\n" + in mail_text + ) mail_text = self._extract_message_body(messages[1]) - assert messages[1]['Subject'] == 'Notification of TEST database changes' - assert messages[1]['From'] == 'from@example.com' - assert messages[1]['To'] == 'upd-to@example.net' - assert '\nModify FAILED AUTHORISATION for object below: [person] PERSON-TEST:\n' in mail_text + assert messages[1]["Subject"] == "Notification of TEST database changes" + assert messages[1]["From"] == "from@example.com" + assert messages[1]["To"] == "upd-to@example.net" + assert "\nModify FAILED AUTHORISATION for object below: [person] PERSON-TEST:\n" in mail_text # Object should not have changed by latest update. - person_text = whois_query('127.0.0.1', self.port_whois1, 'PERSON-TEST') - assert 'PERSON-TEST' in person_text - assert 'Test person changed by PGP signed update' in person_text + person_text = whois_query("127.0.0.1", self.port_whois1, "PERSON-TEST") + assert "PERSON-TEST" in person_text + assert "Test person changed by PGP signed update" in person_text # Submit a delete with a valid password for PERSON-TEST. # This should be rejected, because it creates a dangling reference. # No mail should be sent to upd-to. - self._submit_update(self.config_path1, SAMPLE_PERSON + 'password: md5-password\ndelete: delete\n') + self._submit_update(self.config_path1, SAMPLE_PERSON + "password: md5-password\ndelete: delete\n") messages = self._retrieve_mails() assert len(messages) == 1 mail_text = self._extract_message_body(messages[0]) - assert messages[0]['Subject'] == 'FAILED: my subject' - assert messages[0]['From'] == 'from@example.com' - assert messages[0]['To'] == 'Sasha ' - assert '\nDelete FAILED: [person] PERSON-TEST\n' in mail_text - assert '\nERROR: Object PERSON-TEST to be deleted, but still referenced by mntner TEST-MNT\n' in mail_text - assert '\nERROR: Object PERSON-TEST to be deleted, but still referenced by key-cert PGPKEY-80F238C6\n' in mail_text + assert messages[0]["Subject"] == "FAILED: my subject" + assert messages[0]["From"] == "from@example.com" + assert messages[0]["To"] == "Sasha " + assert "\nDelete FAILED: [person] PERSON-TEST\n" in mail_text + assert ( + "\nERROR: Object PERSON-TEST to be deleted, but still referenced by mntner TEST-MNT\n" + in mail_text + ) + assert ( + "\nERROR: Object PERSON-TEST to be deleted, but still referenced by key-cert PGPKEY-80F238C6\n" + in mail_text + ) # Object should not have changed by latest update. 
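# Sketch (not part of this change): the rejected delete above is referential
# integrity in miniature: an object may not be removed while other objects
# still reference it. A toy version of that check, with made-up data:
references_to = {
    "PERSON-TEST": ["TEST-MNT", "PGPKEY-80F238C6"],  # objects still pointing at it
}

def can_delete(rpsl_pk: str) -> bool:
    return not references_to.get(rpsl_pk)

assert not can_delete("PERSON-TEST")  # still referenced, as the error mails report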
- person_text = whois_query('127.0.0.1', self.port_whois1, 'PERSON-TEST') - assert 'PERSON-TEST' in person_text - assert 'Test person changed by PGP signed update' in person_text + person_text = whois_query("127.0.0.1", self.port_whois1, "PERSON-TEST") + assert "PERSON-TEST" in person_text + assert "Test person changed by PGP signed update" in person_text # Submit a valid delete for all our new objects. - self._submit_update(self.config_path1, - f'{SAMPLE_PERSON}delete: delete\n\n{SAMPLE_KEY_CERT}delete: delete\n\n' + - f'{SAMPLE_MNTNER_CLEAN}delete: delete\npassword: crypt-password\n') + self._submit_update( + self.config_path1, + f"{SAMPLE_PERSON}delete: delete\n\n{SAMPLE_KEY_CERT}delete: delete\n\n" + + f"{SAMPLE_MNTNER_CLEAN}delete: delete\npassword: crypt-password\n", + ) messages = self._retrieve_mails() # Expected mails are status, mnt-nfy on mntner (2x), and notify on mntner # (notify on PERSON-TEST was removed in the PGP signed update) assert len(messages) == 4 mail_text = self._extract_message_body(messages[0]) - assert messages[0]['Subject'] == 'SUCCESS: my subject' - assert messages[0]['From'] == 'from@example.com' - assert messages[0]['To'] == 'Sasha ' - assert '\nDelete succeeded: [person] PERSON-TEST\n' in mail_text - assert '\nDelete succeeded: [mntner] TEST-MNT\n' in mail_text - assert '\nDelete succeeded: [key-cert] PGPKEY-80F238C6\n' in mail_text - - self._check_recipients_in_mails(messages[1:], [ - 'mnt-nfy@example.net', 'mnt-nfy2@example.net', 'notify@example.net', - ]) - - mnt_nfy_msgs = [msg for msg in messages if msg['To'] in ['mnt-nfy@example.net', 'mnt-nfy2@example.net']] - self._check_text_in_mails(mnt_nfy_msgs, [ - '\n> Message-ID: <1325754288.4989.6.camel@hostname>\n', - '\nDelete succeeded for object below: [person] PERSON-TEST:\n', - '\nDelete succeeded for object below: [mntner] TEST-MNT:\n', - '\nDelete succeeded for object below: [key-cert] PGPKEY-80F238C6:\n', - 'unįcöde tæst 🌈🦄\n', - # The object submitted to be deleted has the original name, - # but when sending delete notifications, they should include the - # object as currently in the DB, not as submitted in the email. - 'Test person changed by PGP signed update\n', - ]) + assert messages[0]["Subject"] == "SUCCESS: my subject" + assert messages[0]["From"] == "from@example.com" + assert messages[0]["To"] == "Sasha " + assert "\nDelete succeeded: [person] PERSON-TEST\n" in mail_text + assert "\nDelete succeeded: [mntner] TEST-MNT\n" in mail_text + assert "\nDelete succeeded: [key-cert] PGPKEY-80F238C6\n" in mail_text + + self._check_recipients_in_mails( + messages[1:], + [ + "mnt-nfy@example.net", + "mnt-nfy2@example.net", + "notify@example.net", + ], + ) + + mnt_nfy_msgs = [ + msg for msg in messages if msg["To"] in ["mnt-nfy@example.net", "mnt-nfy2@example.net"] + ] + self._check_text_in_mails( + mnt_nfy_msgs, + [ + "\n> Message-ID: <1325754288.4989.6.camel@hostname>\n", + "\nDelete succeeded for object below: [person] PERSON-TEST:\n", + "\nDelete succeeded for object below: [mntner] TEST-MNT:\n", + "\nDelete succeeded for object below: [key-cert] PGPKEY-80F238C6:\n", + "unįcöde tæst 🌈🦄\n", + # The object submitted to be deleted has the original name, + # but when sending delete notifications, they should include the + # object as currently in the DB, not as submitted in the email. 
+ "Test person changed by PGP signed update\n", + ], + ) for message in messages[1:]: - assert message['Subject'] == 'Notification of TEST database changes' - assert message['From'] == 'from@example.com' + assert message["Subject"] == "Notification of TEST database changes" + assert message["From"] == "from@example.com" # Notify attribute mails are only about the objects concerned. - notify_msg = [msg for msg in messages if msg['To'] == 'notify@example.net'][0] + notify_msg = [msg for msg in messages if msg["To"] == "notify@example.net"][0] mail_text = self._extract_message_body(notify_msg) - assert notify_msg['Subject'] == 'Notification of TEST database changes' - assert notify_msg['From'] == 'from@example.com' - assert '\n> Message-ID: <1325754288.4989.6.camel@hostname>\n' in mail_text - assert '\nDelete succeeded for object below: [person] PERSON-TEST:\n' not in mail_text - assert '\nDelete succeeded for object below: [mntner] TEST-MNT:\n' in mail_text - assert '\nDelete succeeded for object below: [key-cert] PGPKEY-80F238C6:\n' not in mail_text + assert notify_msg["Subject"] == "Notification of TEST database changes" + assert notify_msg["From"] == "from@example.com" + assert "\n> Message-ID: <1325754288.4989.6.camel@hostname>\n" in mail_text + assert "\nDelete succeeded for object below: [person] PERSON-TEST:\n" not in mail_text + assert "\nDelete succeeded for object below: [mntner] TEST-MNT:\n" in mail_text + assert "\nDelete succeeded for object below: [key-cert] PGPKEY-80F238C6:\n" not in mail_text # Object should be deleted - person_text = whois_query('127.0.0.1', self.port_whois1, 'PERSON-TEST') - assert 'No entries found for the selected source(s)' in person_text - assert 'PERSON-TEST' not in person_text + person_text = whois_query("127.0.0.1", self.port_whois1, "PERSON-TEST") + assert "No entries found for the selected source(s)" in person_text + assert "PERSON-TEST" not in person_text # Object should be deleted from irrd #2 as well through NRTM. 
time.sleep(2) - person_text = whois_query('127.0.0.1', self.port_whois2, 'PERSON-TEST') - assert 'No entries found for the selected source(s)' in person_text - assert 'PERSON-TEST' not in person_text + person_text = whois_query("127.0.0.1", self.port_whois2, "PERSON-TEST") + assert "No entries found for the selected source(s)" in person_text + assert "PERSON-TEST" not in person_text # Load the mntner and person again, using the override password # Note that the route/route6 objects are RPKI valid on IRRd #1, # and RPKI-invalid on IRRd #2 - self._submit_update(self.config_path1, - SAMPLE_MNTNER_CLEAN + '\n\n' + SAMPLE_PERSON + '\n\noverride: override-password') + self._submit_update( + self.config_path1, + SAMPLE_MNTNER_CLEAN + "\n\n" + SAMPLE_PERSON + "\n\noverride: override-password", + ) messages = self._retrieve_mails() assert len(messages) == 1 mail_text = self._extract_message_body(messages[0]) - assert messages[0]['Subject'] == 'SUCCESS: my subject' - assert messages[0]['From'] == 'from@example.com' - assert messages[0]['To'] == 'Sasha ' - assert '\nCreate succeeded: [mntner] TEST-MNT\n' in mail_text - assert '\nCreate succeeded: [person] PERSON-TEST\n' in mail_text - assert 'email footer' in mail_text - assert 'Generated by IRRd version ' in mail_text + assert messages[0]["Subject"] == "SUCCESS: my subject" + assert messages[0]["From"] == "from@example.com" + assert messages[0]["To"] == "Sasha " + assert "\nCreate succeeded: [mntner] TEST-MNT\n" in mail_text + assert "\nCreate succeeded: [person] PERSON-TEST\n" in mail_text + assert "email footer" in mail_text + assert "Generated by IRRd version " in mail_text # Load samples of all known objects, using the mntner password - self._submit_update(self.config_path1, LARGE_UPDATE + '\n\npassword: md5-password') + self._submit_update(self.config_path1, LARGE_UPDATE + "\n\npassword: md5-password") messages = self._retrieve_mails() assert len(messages) == 3 mail_text = self._extract_message_body(messages[0]) - assert messages[0]['Subject'] == 'SUCCESS: my subject' - assert messages[0]['From'] == 'from@example.com' - assert messages[0]['To'] == 'Sasha ' - assert '\nINFO: AS number as065537 was reformatted as AS65537\n' in mail_text - assert '\nCreate succeeded: [filter-set] FLTR-SETTEST\n' in mail_text - assert '\nINFO: Address range 192.0.2.0 - 192.0.02.255 was reformatted as 192.0.2.0 - 192.0.2.255\n' in mail_text - assert '\nINFO: Address prefix 192.0.02.0/24 was reformatted as 192.0.2.0/24\n' in mail_text - assert '\nINFO: Route set member 2001:0dB8::/48 was reformatted as 2001:db8::/48\n' in mail_text + assert messages[0]["Subject"] == "SUCCESS: my subject" + assert messages[0]["From"] == "from@example.com" + assert messages[0]["To"] == "Sasha " + assert "\nINFO: AS number as065537 was reformatted as AS65537\n" in mail_text + assert "\nCreate succeeded: [filter-set] FLTR-SETTEST\n" in mail_text + assert ( + "\nINFO: Address range 192.0.2.0 - 192.0.02.255 was reformatted as 192.0.2.0 - 192.0.2.255\n" + in mail_text + ) + assert "\nINFO: Address prefix 192.0.02.0/24 was reformatted as 192.0.2.0/24\n" in mail_text + assert "\nINFO: Route set member 2001:0dB8::/48 was reformatted as 2001:db8::/48\n" in mail_text # Check whether the objects can be queried from irrd #1, # and whether the hash is masked. 
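# Sketch (not part of this change): whois_query() is IRRd's own test helper;
# a minimal stdlib-only stand-in for the plain-TCP whois exchange the queries
# below rely on might look like this. Hypothetical, for illustration only:
import socket

def simple_whois_query(host: str, port: int, query: str) -> str:
    with socket.create_connection((host, port), timeout=5) as s:
        s.sendall(f"{query}\r\n".encode("utf-8"))
        chunks = []
        while True:
            data = s.recv(4096)
            if not data:  # whois servers close the connection after answering
                break
            chunks.append(data)
    return b"".join(chunks).decode("utf-8", errors="replace")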
- mntner_text = whois_query('127.0.0.1', self.port_whois1, 'TEST-MNT') - assert 'TEST-MNT' in mntner_text + mntner_text = whois_query("127.0.0.1", self.port_whois1, "TEST-MNT") + assert "TEST-MNT" in mntner_text assert PASSWORD_HASH_DUMMY_VALUE in mntner_text - assert 'unįcöde tæst 🌈🦄' in mntner_text - assert 'PERSON-TEST' in mntner_text + assert "unįcöde tæst 🌈🦄" in mntner_text + assert "PERSON-TEST" in mntner_text # (This is the first instance of an object with unicode chars # appearing on the NRTM stream.) time.sleep(3) - mntner_text = whois_query('127.0.0.1', self.port_whois2, 'TEST-MNT') - assert 'TEST-MNT' in mntner_text + mntner_text = whois_query("127.0.0.1", self.port_whois2, "TEST-MNT") + assert "TEST-MNT" in mntner_text assert PASSWORD_HASH_DUMMY_VALUE in mntner_text - assert 'unįcöde tæst 🌈🦄' in mntner_text - assert 'PERSON-TEST' in mntner_text + assert "unįcöde tæst 🌈🦄" in mntner_text + assert "PERSON-TEST" in mntner_text # These queries have different responses on #1 than #2, # as all IPv4 routes are RPKI invalid on #2. - query_result = whois_query_irrd('127.0.0.1', self.port_whois1, '!gAS65537') - assert query_result == '192.0.2.0/24' - query_result = whois_query_irrd('127.0.0.1', self.port_whois1, '!gAS65547') - assert query_result == '192.0.2.0/32' # Pseudo-IRR object from RPKI - query_result = whois_query_irrd('127.0.0.1', self.port_whois1, '!6AS65537') - assert query_result == '2001:db8::/48' - query_result = whois_query_irrd('127.0.0.1', self.port_whois1, '!iRS-TEST') - assert set(query_result.split(' ')) == {'192.0.2.0/24', '2001:db8::/48', 'RS-OTHER-SET'} - query_result = whois_query_irrd('127.0.0.1', self.port_whois1, '!aAS65537:AS-SETTEST') - assert set(query_result.split(' ')) == {'192.0.2.0/24', '2001:db8::/48'} - query_result = whois_query_irrd('127.0.0.1', self.port_whois1, '!aAS65537:AS-TESTREF') - assert set(query_result.split(' ')) == {'192.0.2.0/24', '2001:db8::/48'} - query_result = whois_query_irrd('127.0.0.1', self.port_whois1, '!a4AS65537:AS-TESTREF') - assert query_result == '192.0.2.0/24' - query_result = whois_query_irrd('127.0.0.1', self.port_whois1, '!a6AS65537:AS-TESTREF') - assert query_result == '2001:db8::/48' - query_result = whois_query_irrd('127.0.0.1', self.port_whois1, '!r192.0.2.0/24') - assert 'example route' in query_result - query_result = whois_query_irrd('127.0.0.1', self.port_whois1, '!r192.0.2.0/25,l') - assert 'example route' in query_result - query_result = whois_query_irrd('127.0.0.1', self.port_whois1, '!r192.0.2.0/24,L') - assert 'example route' in query_result - query_result = whois_query_irrd('127.0.0.1', self.port_whois1, '!r192.0.2.0/23,M') - assert 'example route' in query_result - query_result = whois_query_irrd('127.0.0.1', self.port_whois1, '!r192.0.2.0/24,M') - assert 'RPKI' in query_result # Does not match the /24, does match the RPKI pseudo-IRR /32 - query_result = whois_query_irrd('127.0.0.1', self.port_whois1, '!r192.0.2.0/24,o') - assert query_result == 'AS65537' - query_result = whois_query('127.0.0.1', self.port_whois1, '-x 192.0.02.0/24') - assert 'example route' in query_result - query_result = whois_query('127.0.0.1', self.port_whois1, '-l 192.0.02.0/25') - assert 'example route' in query_result - query_result = whois_query('127.0.0.1', self.port_whois1, '-L 192.0.02.0/24') - assert 'example route' in query_result - query_result = whois_query('127.0.0.1', self.port_whois1, '-M 192.0.02.0/23') - assert 'example route' in query_result - query_result = whois_query('127.0.0.1', self.port_whois1, '-i member-of 
RS-test') - assert 'example route' in query_result - - query_result = whois_query_irrd('127.0.0.1', self.port_whois2, '!gAS65537') + query_result = whois_query_irrd("127.0.0.1", self.port_whois1, "!gAS65537") + assert query_result == "192.0.2.0/24" + query_result = whois_query_irrd("127.0.0.1", self.port_whois1, "!gAS65547") + assert query_result == "192.0.2.0/32" # Pseudo-IRR object from RPKI + query_result = whois_query_irrd("127.0.0.1", self.port_whois1, "!6AS65537") + assert query_result == "2001:db8::/48" + query_result = whois_query_irrd("127.0.0.1", self.port_whois1, "!iRS-TEST") + assert set(query_result.split(" ")) == {"192.0.2.0/24", "2001:db8::/48", "RS-OTHER-SET"} + query_result = whois_query_irrd("127.0.0.1", self.port_whois1, "!aAS65537:AS-SETTEST") + assert set(query_result.split(" ")) == {"192.0.2.0/24", "2001:db8::/48"} + query_result = whois_query_irrd("127.0.0.1", self.port_whois1, "!aAS65537:AS-TESTREF") + assert set(query_result.split(" ")) == {"192.0.2.0/24", "2001:db8::/48"} + query_result = whois_query_irrd("127.0.0.1", self.port_whois1, "!a4AS65537:AS-TESTREF") + assert query_result == "192.0.2.0/24" + query_result = whois_query_irrd("127.0.0.1", self.port_whois1, "!a6AS65537:AS-TESTREF") + assert query_result == "2001:db8::/48" + query_result = whois_query_irrd("127.0.0.1", self.port_whois1, "!r192.0.2.0/24") + assert "example route" in query_result + query_result = whois_query_irrd("127.0.0.1", self.port_whois1, "!r192.0.2.0/25,l") + assert "example route" in query_result + query_result = whois_query_irrd("127.0.0.1", self.port_whois1, "!r192.0.2.0/24,L") + assert "example route" in query_result + query_result = whois_query_irrd("127.0.0.1", self.port_whois1, "!r192.0.2.0/23,M") + assert "example route" in query_result + query_result = whois_query_irrd("127.0.0.1", self.port_whois1, "!r192.0.2.0/24,M") + assert "RPKI" in query_result # Does not match the /24, does match the RPKI pseudo-IRR /32 + query_result = whois_query_irrd("127.0.0.1", self.port_whois1, "!r192.0.2.0/24,o") + assert query_result == "AS65537" + query_result = whois_query("127.0.0.1", self.port_whois1, "-x 192.0.02.0/24") + assert "example route" in query_result + query_result = whois_query("127.0.0.1", self.port_whois1, "-l 192.0.02.0/25") + assert "example route" in query_result + query_result = whois_query("127.0.0.1", self.port_whois1, "-L 192.0.02.0/24") + assert "example route" in query_result + query_result = whois_query("127.0.0.1", self.port_whois1, "-M 192.0.02.0/23") + assert "example route" in query_result + query_result = whois_query("127.0.0.1", self.port_whois1, "-i member-of RS-test") + assert "example route" in query_result + + query_result = whois_query_irrd("127.0.0.1", self.port_whois2, "!gAS65537") assert not query_result - query_result = whois_query_irrd('127.0.0.1', self.port_whois2, '!6AS65537') - assert query_result == '2001:db8::/48' - query_result = whois_query_irrd('127.0.0.1', self.port_whois2, '!iRS-TEST') - assert query_result == '2001:db8::/48 RS-OTHER-SET' - query_result = whois_query_irrd('127.0.0.1', self.port_whois2, '!aAS65537:AS-SETTEST') - assert query_result == '2001:db8::/48' - query_result = whois_query_irrd('127.0.0.1', self.port_whois2, '!aAS65537:AS-TESTREF') - assert query_result == '2001:db8::/48' - query_result = whois_query('127.0.0.1', self.port_whois2, '-x 192.0.02.0/24') - assert 'example route' not in query_result - query_result = whois_query_irrd('127.0.0.1', self.port_whois2, '!r192.0.2.0/24,L') - assert 'RPKI' in query_result # Pseudo-IRR 
object 0/0 from RPKI + query_result = whois_query_irrd("127.0.0.1", self.port_whois2, "!6AS65537") + assert query_result == "2001:db8::/48" + query_result = whois_query_irrd("127.0.0.1", self.port_whois2, "!iRS-TEST") + assert query_result == "2001:db8::/48 RS-OTHER-SET" + query_result = whois_query_irrd("127.0.0.1", self.port_whois2, "!aAS65537:AS-SETTEST") + assert query_result == "2001:db8::/48" + query_result = whois_query_irrd("127.0.0.1", self.port_whois2, "!aAS65537:AS-TESTREF") + assert query_result == "2001:db8::/48" + query_result = whois_query("127.0.0.1", self.port_whois2, "-x 192.0.02.0/24") + assert "example route" not in query_result + query_result = whois_query_irrd("127.0.0.1", self.port_whois2, "!r192.0.2.0/24,L") + assert "RPKI" in query_result # Pseudo-IRR object 0/0 from RPKI # RPKI invalid object should not be in journal - query_result = whois_query('127.0.0.1', self.port_whois2, '-g TEST:3:1-LAST') - assert 'route:192.0.2.0/24' not in query_result.replace(' ', '') + query_result = whois_query("127.0.0.1", self.port_whois2, "-g TEST:3:1-LAST") + assert "route:192.0.2.0/24" not in query_result.replace(" ", "") # These queries should produce identical answers on both instances. for port in self.port_whois1, self.port_whois2: - query_result = whois_query_irrd('127.0.0.1', port, '!iAS65537:AS-SETTEST') - assert set(query_result.split(' ')) == {'AS65537', 'AS65538', 'AS65539', 'AS-OTHERSET'} - query_result = whois_query_irrd('127.0.0.1', port, '!iAS65537:AS-TESTREF') - assert set(query_result.split(' ')) == {'AS65537:AS-SETTEST', 'AS65540'} - query_result = whois_query_irrd('127.0.0.1', port, '!iAS65537:AS-TESTREF,1') - assert set(query_result.split(' ')) == {'AS65537', 'AS65538', 'AS65539', 'AS65540'} - query_result = whois_query_irrd('127.0.0.1', port, '!maut-num,as65537') - assert 'AS65537' in query_result - assert 'TEST-AS' in query_result - query_result = whois_query_irrd('127.0.0.1', port, '!oTEST-MNT') - assert 'AS65537' in query_result - assert 'TEST-AS' in query_result - assert 'AS65536 - AS65538' in query_result - assert 'rtrs-settest' in query_result - query_result = whois_query('127.0.0.1', port, '-T route6 -i member-of RS-TEST') - assert 'No entries found for the selected source(s)' in query_result - query_result = whois_query('127.0.0.1', port, 'dashcare') - assert 'ROLE-TEST' in query_result + query_result = whois_query_irrd("127.0.0.1", port, "!iAS65537:AS-SETTEST") + assert set(query_result.split(" ")) == {"AS65537", "AS65538", "AS65539", "AS-OTHERSET"} + query_result = whois_query_irrd("127.0.0.1", port, "!iAS65537:AS-TESTREF") + assert set(query_result.split(" ")) == {"AS65537:AS-SETTEST", "AS65540"} + query_result = whois_query_irrd("127.0.0.1", port, "!iAS65537:AS-TESTREF,1") + assert set(query_result.split(" ")) == {"AS65537", "AS65538", "AS65539", "AS65540"} + query_result = whois_query_irrd("127.0.0.1", port, "!maut-num,as65537") + assert "AS65537" in query_result + assert "TEST-AS" in query_result + query_result = whois_query_irrd("127.0.0.1", port, "!oTEST-MNT") + assert "AS65537" in query_result + assert "TEST-AS" in query_result + assert "AS65536 - AS65538" in query_result + assert "rtrs-settest" in query_result + query_result = whois_query("127.0.0.1", port, "-T route6 -i member-of RS-TEST") + assert "No entries found for the selected source(s)" in query_result + query_result = whois_query("127.0.0.1", port, "dashcare") + assert "ROLE-TEST" in query_result # Check the mirroring status - query_result = whois_query_irrd('127.0.0.1', 
self.port_whois1, '!J-*') + query_result = whois_query_irrd("127.0.0.1", self.port_whois1, "!J-*") result = ujson.loads(query_result) - assert result['TEST']['serial_newest_journal'] == 29 - assert result['TEST']['serial_last_export'] == 29 - assert result['TEST']['serial_newest_mirror'] is None + assert result["TEST"]["serial_newest_journal"] == 29 + assert result["TEST"]["serial_last_export"] == 29 + assert result["TEST"]["serial_newest_mirror"] is None # irrd #2 missed the first update from NRTM, as they were done at # the same time and loaded from the full export, and one RPKI-invalid object # was not recorded in the journal, so its local serial should # is lower by three - query_result = whois_query_irrd('127.0.0.1', self.port_whois2, '!J-*') + query_result = whois_query_irrd("127.0.0.1", self.port_whois2, "!J-*") result = ujson.loads(query_result) - assert result['TEST']['serial_newest_journal'] == 26 - assert result['TEST']['serial_last_export'] == 26 - assert result['TEST']['serial_newest_mirror'] == 29 + assert result["TEST"]["serial_newest_journal"] == 26 + assert result["TEST"]["serial_last_export"] == 26 + assert result["TEST"]["serial_newest_mirror"] == 29 # Make the v4 route in irrd2 valid - with open(self.roa_source2, 'w') as roa_file: - ujson.dump({'roas': [{'prefix': '198.51.100.0/24', 'asn': 'AS0', 'maxLength': '32', 'ta': 'TA'}]}, roa_file) + with open(self.roa_source2, "w") as roa_file: + ujson.dump( + {"roas": [{"prefix": "198.51.100.0/24", "asn": "AS0", "maxLength": "32", "ta": "TA"}]}, + roa_file, + ) time.sleep(3) - query_result = whois_query_irrd('127.0.0.1', self.port_whois2, '!gAS65537') - assert query_result == '192.0.2.0/24' + query_result = whois_query_irrd("127.0.0.1", self.port_whois2, "!gAS65537") + assert query_result == "192.0.2.0/24" # RPKI invalid object should now be added in the journal - query_result = whois_query('127.0.0.1', self.port_whois2, '-g TEST:3:27-27') - assert 'ADD 27' in query_result - assert '192.0.2.0/24' in query_result - query_result = whois_query_irrd('127.0.0.1', self.port_whois2, '!J-*') + query_result = whois_query("127.0.0.1", self.port_whois2, "-g TEST:3:27-27") + assert "ADD 27" in query_result + assert "192.0.2.0/24" in query_result + query_result = whois_query_irrd("127.0.0.1", self.port_whois2, "!J-*") result = ujson.loads(query_result) - assert result['TEST']['serial_newest_journal'] == 27 - assert result['TEST']['serial_last_export'] == 27 + assert result["TEST"]["serial_newest_journal"] == 27 + assert result["TEST"]["serial_last_export"] == 27 # This was a local journal update from RPKI status change, # so serial_newest_mirror did not update. 
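# Sketch (not part of this change): the ROA files written in this test follow
# the JSON export format of RPKI validators, a top-level "roas" list whose
# entries have prefix/asn/maxLength/ta keys. The same write as a standalone
# snippet, using stdlib json instead of ujson; the path is illustrative:
import json

roa_export = {
    "roas": [
        {"prefix": "198.51.100.0/24", "asn": "AS0", "maxLength": "32", "ta": "TA"},
    ]
}

with open("roa.json", "w") as roa_file:
    json.dump(roa_export, roa_file)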
- assert result['TEST']['serial_newest_mirror'] == 29 + assert result["TEST"]["serial_newest_mirror"] == 29 # Make the v4 route in irrd2 invalid again - with open(self.roa_source2, 'w') as roa_file: - ujson.dump({'roas': [{'prefix': '128/1', 'asn': 'AS0', 'maxLength': '32', 'ta': 'TA'}]}, roa_file) + with open(self.roa_source2, "w") as roa_file: + ujson.dump({"roas": [{"prefix": "128/1", "asn": "AS0", "maxLength": "32", "ta": "TA"}]}, roa_file) time.sleep(3) - query_result = whois_query_irrd('127.0.0.1', self.port_whois2, '!gAS65537') + query_result = whois_query_irrd("127.0.0.1", self.port_whois2, "!gAS65537") assert not query_result # RPKI invalid object should now be deleted in the journal - query_result = whois_query('127.0.0.1', self.port_whois2, '-g TEST:3:28-28') - assert 'DEL 28' in query_result - assert '192.0.2.0/24' in query_result - query_result = whois_query_irrd('127.0.0.1', self.port_whois2, '!J-*') + query_result = whois_query("127.0.0.1", self.port_whois2, "-g TEST:3:28-28") + assert "DEL 28" in query_result + assert "192.0.2.0/24" in query_result + query_result = whois_query_irrd("127.0.0.1", self.port_whois2, "!J-*") result = ujson.loads(query_result) - assert result['TEST']['serial_newest_journal'] == 28 - assert result['TEST']['serial_last_export'] == 28 - assert result['TEST']['serial_newest_mirror'] == 29 + assert result["TEST"]["serial_newest_journal"] == 28 + assert result["TEST"]["serial_last_export"] == 28 + assert result["TEST"]["serial_newest_mirror"] == 29 # Make the v4 route in irrd1 invalid, triggering a mail - with open(self.roa_source1, 'w') as roa_file: - ujson.dump({'roas': [{'prefix': '128/1', 'asn': 'AS0', 'maxLength': '32', 'ta': 'TA'}]}, roa_file) + with open(self.roa_source1, "w") as roa_file: + ujson.dump({"roas": [{"prefix": "128/1", "asn": "AS0", "maxLength": "32", "ta": "TA"}]}, roa_file) # irrd1 is authoritative for the now invalid v4 route, should have sent mail time.sleep(2) messages = self._retrieve_mails() assert len(messages) == 3 mail_text = self._extract_message_body(messages[0]) - assert messages[0]['Subject'] == 'route(6) objects in TEST marked RPKI invalid' - expected_recipients = {'email@example.com', 'mnt-nfy@example.net', 'mnt-nfy2@example.net'} - assert {m['To'] for m in messages} == expected_recipients - assert '192.0.2.0/24' in mail_text + assert messages[0]["Subject"] == "route(6) objects in TEST marked RPKI invalid" + expected_recipients = {"email@example.com", "mnt-nfy@example.net", "mnt-nfy2@example.net"} + assert {m["To"] for m in messages} == expected_recipients + assert "192.0.2.0/24" in mail_text self.check_http() self.check_graphql() def check_http(self): - status1 = requests.get(f'http://127.0.0.1:{self.port_http1}/v1/status/') - status2 = requests.get(f'http://127.0.0.1:{self.port_http2}/v1/status/') + status1 = requests.get(f"http://127.0.0.1:{self.port_http1}/v1/status/") + status2 = requests.get(f"http://127.0.0.1:{self.port_http2}/v1/status/") assert status1.status_code == 200 assert status2.status_code == 200 - assert 'IRRD version' in status1.text - assert 'IRRD version' in status2.text - assert 'TEST' in status1.text - assert 'TEST' in status2.text - assert 'RPKI' in status1.text - assert 'RPKI' in status2.text - assert 'Authoritative: Yes' in status1.text - assert 'Authoritative: Yes' not in status2.text + assert "IRRD version" in status1.text + assert "IRRD version" in status2.text + assert "TEST" in status1.text + assert "TEST" in status2.text + assert "RPKI" in status1.text + assert "RPKI" in status2.text + 
assert "Authoritative: Yes" in status1.text + assert "Authoritative: Yes" not in status2.text def check_graphql(self): client = GraphqlClient(endpoint=f"http://127.0.0.1:{self.port_http1}/graphql/") @@ -549,17 +624,19 @@ def check_graphql(self): } """ result = client.execute(query=query) - assert result['data']['rpslObjects'] == [{ - 'rpslPk': 'PERSON-TEST', - 'mntBy': ['TEST-MNT'], - 'mntByObjs': [{'rpslPk': 'TEST-MNT', 'adminCObjs': [{'rpslPk': 'PERSON-TEST'}]}], - 'journal': [ - {'serialNrtm': 2, 'operation': 'add_or_update', 'origin': 'auth_change'}, - {'serialNrtm': 4, 'operation': 'add_or_update', 'origin': 'auth_change'}, - {'serialNrtm': 5, 'operation': 'delete', 'origin': 'auth_change'}, - {'serialNrtm': 9, 'operation': 'add_or_update', 'origin': 'auth_change'} - ] - }] + assert result["data"]["rpslObjects"] == [ + { + "rpslPk": "PERSON-TEST", + "mntBy": ["TEST-MNT"], + "mntByObjs": [{"rpslPk": "TEST-MNT", "adminCObjs": [{"rpslPk": "PERSON-TEST"}]}], + "journal": [ + {"serialNrtm": 2, "operation": "add_or_update", "origin": "auth_change"}, + {"serialNrtm": 4, "operation": "add_or_update", "origin": "auth_change"}, + {"serialNrtm": 5, "operation": "delete", "origin": "auth_change"}, + {"serialNrtm": 9, "operation": "add_or_update", "origin": "auth_change"}, + ], + } + ] # Test memberOfObjs resolving and IP search query = """query { @@ -574,10 +651,13 @@ def check_graphql(self): } """ result = client.execute(query=query) - self.assertCountEqual(result['data']['rpslObjects'], [ - {'rpslPk': '192.0.2.0/24AS65537', 'memberOfObjs': [{'rpslPk': 'RS-TEST'}]}, - {'rpslPk': '192.0.2.0 - 192.0.2.255'} - ]) + self.assertCountEqual( + result["data"]["rpslObjects"], + [ + {"rpslPk": "192.0.2.0/24AS65537", "memberOfObjs": [{"rpslPk": "RS-TEST"}]}, + {"rpslPk": "192.0.2.0 - 192.0.2.255"}, + ], + ) # Test membersObjs and mbrsByRefObjs resolving query = """query { @@ -595,11 +675,13 @@ def check_graphql(self): } """ result = client.execute(query=query) - assert result['data']['rpslObjects'] == [{ - 'rpslPk': 'AS65537:AS-TESTREF', - 'membersObjs': [{'rpslPk': 'AS65537:AS-SETTEST'}], - 'mbrsByRefObjs': [{'rpslPk': 'TEST-MNT'}], - }] + assert result["data"]["rpslObjects"] == [ + { + "rpslPk": "AS65537:AS-TESTREF", + "membersObjs": [{"rpslPk": "AS65537:AS-SETTEST"}], + "mbrsByRefObjs": [{"rpslPk": "TEST-MNT"}], + } + ] # Test databaseStatus query query = """query { @@ -613,21 +695,25 @@ def check_graphql(self): } """ result = client.execute(query=query) - self.assertCountEqual(result['data']['databaseStatus'], [ - { - 'source': 'TEST', - 'authoritative': True, - 'serialOldestJournal': 1, - 'serialNewestJournal': 30, - 'serialNewestMirror': None - }, { - 'source': 'RPKI', - 'authoritative': False, - 'serialOldestJournal': None, - 'serialNewestJournal': None, - 'serialNewestMirror': None - } - ]) + self.assertCountEqual( + result["data"]["databaseStatus"], + [ + { + "source": "TEST", + "authoritative": True, + "serialOldestJournal": 1, + "serialNewestJournal": 30, + "serialNewestMirror": None, + }, + { + "source": "RPKI", + "authoritative": False, + "serialOldestJournal": None, + "serialNewestJournal": None, + "serialNewestMirror": None, + }, + ], + ) # Test asnPrefixes query query = """query { @@ -638,10 +724,10 @@ def check_graphql(self): } """ result = client.execute(query=query) - asnPrefixes = result['data']['asnPrefixes'] + asnPrefixes = result["data"]["asnPrefixes"] assert len(asnPrefixes) == 1 - assert asnPrefixes[0]['asn'] == 65537 - assert set(asnPrefixes[0]['prefixes']) == {'2001:db8::/48'} + 
assert asnPrefixes[0]["asn"] == 65537 + assert set(asnPrefixes[0]["prefixes"]) == {"2001:db8::/48"} # Test asSetPrefixes query query = """query { @@ -652,10 +738,10 @@ def check_graphql(self): } """ result = client.execute(query=query) - asSetPrefixes = result['data']['asSetPrefixes'] + asSetPrefixes = result["data"]["asSetPrefixes"] assert len(asSetPrefixes) == 1 - assert asSetPrefixes[0]['rpslPk'] == 'AS65537:AS-TESTREF' - assert set(asSetPrefixes[0]['prefixes']) == {'2001:db8::/48'} + assert asSetPrefixes[0]["rpslPk"] == "AS65537:AS-TESTREF" + assert set(asSetPrefixes[0]["prefixes"]) == {"2001:db8::/48"} # Test recursiveSetMembers query query = """query { @@ -667,13 +753,11 @@ def check_graphql(self): } """ result = client.execute(query=query) - recursiveSetMembers = result['data']['recursiveSetMembers'] + recursiveSetMembers = result["data"]["recursiveSetMembers"] assert len(recursiveSetMembers) == 1 - assert recursiveSetMembers[0]['rpslPk'] == 'AS65537:AS-TESTREF' - assert recursiveSetMembers[0]['rootSource'] == 'TEST' - assert set(recursiveSetMembers[0]['members']) == { - 'AS65537', 'AS65538', 'AS65539', 'AS65540' - } + assert recursiveSetMembers[0]["rpslPk"] == "AS65537:AS-TESTREF" + assert recursiveSetMembers[0]["rootSource"] == "TEST" + assert set(recursiveSetMembers[0]["members"]) == {"AS65537", "AS65538", "AS65539", "AS65540"} def _start_mailserver(self): """ @@ -682,11 +766,18 @@ def _start_mailserver(self): It keeps mails in memory, and _retrieve_mails() can retrieve them using special SMTP commands. """ - self.pidfile_mailserver = str(self.tmpdir) + '/mailserver.pid' - self.logfile_mailserver = str(self.tmpdir) + '/mailserver.log' - mailserver_path = IRRD_ROOT_PATH + '/irrd/integration_tests/mailserver.tac' - assert not subprocess.call(['twistd', f'--pidfile={self.pidfile_mailserver}', - f'--logfile={self.logfile_mailserver}', '-y', mailserver_path]) + self.pidfile_mailserver = str(self.tmpdir) + "/mailserver.pid" + self.logfile_mailserver = str(self.tmpdir) + "/mailserver.log" + mailserver_path = IRRD_ROOT_PATH + "/irrd/integration_tests/mailserver.tac" + assert not subprocess.call( + [ + "twistd", + f"--pidfile={self.pidfile_mailserver}", + f"--logfile={self.logfile_mailserver}", + "-y", + mailserver_path, + ] + ) # noinspection PyTypeChecker def _start_irrds(self): @@ -695,29 +786,31 @@ def _start_irrds(self): IRRd #1 has an authoritative database, IRRd #2 mirrors that database from #1. 
""" - self.database_url1 = os.environ['IRRD_DATABASE_URL_INTEGRATION_1'] - self.database_url2 = os.environ['IRRD_DATABASE_URL_INTEGRATION_2'] - self.redis_url1 = os.environ['IRRD_REDIS_URL_INTEGRATION_1'] - self.redis_url2 = os.environ['IRRD_REDIS_URL_INTEGRATION_2'] - - self.config_path1 = str(self.tmpdir) + '/irrd1_config.yaml' - self.config_path2 = str(self.tmpdir) + '/irrd2_config.yaml' - self.logfile1 = str(self.tmpdir) + '/irrd1.log' - self.logfile2 = str(self.tmpdir) + '/irrd2.log' - self.roa_source1 = str(self.tmpdir) + '/roa1.json' - self.roa_source2 = str(self.tmpdir) + '/roa2.json' - self.export_dir1 = str(self.tmpdir) + '/export1/' - self.export_dir2 = str(self.tmpdir) + '/export2/' - self.piddir1 = str(self.tmpdir) + '/piddir1/' - self.piddir2 = str(self.tmpdir) + '/piddir2/' - self.pidfile1 = self.piddir1 + 'irrd.pid' - self.pidfile2 = self.piddir2 + 'irrd.pid' + self.database_url1 = os.environ["IRRD_DATABASE_URL_INTEGRATION_1"] + self.database_url2 = os.environ["IRRD_DATABASE_URL_INTEGRATION_2"] + self.redis_url1 = os.environ["IRRD_REDIS_URL_INTEGRATION_1"] + self.redis_url2 = os.environ["IRRD_REDIS_URL_INTEGRATION_2"] + + self.config_path1 = str(self.tmpdir) + "/irrd1_config.yaml" + self.config_path2 = str(self.tmpdir) + "/irrd2_config.yaml" + self.logfile1 = str(self.tmpdir) + "/irrd1.log" + self.logfile2 = str(self.tmpdir) + "/irrd2.log" + self.roa_source1 = str(self.tmpdir) + "/roa1.json" + self.roa_source2 = str(self.tmpdir) + "/roa2.json" + self.export_dir1 = str(self.tmpdir) + "/export1/" + self.export_dir2 = str(self.tmpdir) + "/export2/" + self.piddir1 = str(self.tmpdir) + "/piddir1/" + self.piddir2 = str(self.tmpdir) + "/piddir2/" + self.pidfile1 = self.piddir1 + "irrd.pid" + self.pidfile2 = self.piddir2 + "irrd.pid" os.mkdir(self.export_dir1) os.mkdir(self.export_dir2) os.mkdir(self.piddir1) os.mkdir(self.piddir2) - print(textwrap.dedent(f""" + print( + textwrap.dedent( + f""" Preparing to start IRRd for integration test. 
IRRd #1 running on HTTP port {self.port_http1}, whois port {self.port_whois1} @@ -731,123 +824,111 @@ def _start_irrds(self): Database URL: {self.database_url2} PID file: {self.pidfile2} Logfile: {self.logfile2} - """)) + """ + ) + ) - with open(self.roa_source1, 'w') as roa_file: - ujson.dump({'roas': [{'prefix': '192.0.2.0/32', 'asn': 'AS65547', 'maxLength': '32', 'ta': 'TA'}]}, roa_file) - with open(self.roa_source2, 'w') as roa_file: - ujson.dump({'roas': [{'prefix': '128/1', 'asn': 'AS0', 'maxLength': '1', 'ta': 'TA'}]}, roa_file) + with open(self.roa_source1, "w") as roa_file: + ujson.dump( + {"roas": [{"prefix": "192.0.2.0/32", "asn": "AS65547", "maxLength": "32", "ta": "TA"}]}, + roa_file, + ) + with open(self.roa_source2, "w") as roa_file: + ujson.dump({"roas": [{"prefix": "128/1", "asn": "AS0", "maxLength": "1", "ta": "TA"}]}, roa_file) base_config = { - 'irrd': { - 'access_lists': { - 'localhost': ['::/32', '127.0.0.1'] - }, - - 'server': { - 'http': { - 'status_access_list': 'localhost', - 'interface': '::1', - 'port': 8080 - }, - 'whois': { - 'interface': '::1', - 'max_connections': 10, - 'port': 8043 - }, + "irrd": { + "access_lists": {"localhost": ["::/32", "127.0.0.1"]}, + "server": { + "http": {"status_access_list": "localhost", "interface": "::1", "port": 8080}, + "whois": {"interface": "::1", "max_connections": 10, "port": 8043}, }, - - 'rpki':{ - 'roa_import_timer': 1, - 'notify_invalid_enabled': True, + "rpki": { + "roa_import_timer": 1, + "notify_invalid_enabled": True, }, - - 'auth': { - 'gnupg_keyring': None, - 'override_password': '$1$J6KycItM$MbPaBU6iFSGFV299Rk7Di0', - 'set_creation': { - 'filter-set': { - 'prefix_required': False, + "auth": { + "gnupg_keyring": None, + "override_password": "$1$J6KycItM$MbPaBU6iFSGFV299Rk7Di0", + "set_creation": { + "filter-set": { + "prefix_required": False, }, - 'peering-set': { - 'prefix_required': False, + "peering-set": { + "prefix_required": False, }, - 'route-set': { - 'prefix_required': False, + "route-set": { + "prefix_required": False, }, - 'rtr-set': { - 'prefix_required': False, + "rtr-set": { + "prefix_required": False, }, }, - 'password_hashers': { - 'crypt-pw': 'enabled', - } - + "password_hashers": { + "crypt-pw": "enabled", + }, }, - - 'email': { - 'footer': 'email footer', - 'from': 'from@example.com', - 'smtp': f'localhost:{EMAIL_SMTP_PORT}', + "email": { + "footer": "email footer", + "from": "from@example.com", + "smtp": f"localhost:{EMAIL_SMTP_PORT}", }, - - 'log': { - 'logfile_path': None, - 'level': 'DEBUG', + "log": { + "logfile_path": None, + "level": "DEBUG", }, - - 'sources': {} + "sources": {}, } } config1 = base_config.copy() - config1['irrd']['piddir'] = self.piddir1 - config1['irrd']['database_url'] = self.database_url1 - config1['irrd']['redis_url'] = self.redis_url1 - config1['irrd']['server']['http']['interface'] = '127.0.0.1' # #306 - config1['irrd']['server']['http']['port'] = self.port_http1 - config1['irrd']['server']['whois']['interface'] = '127.0.0.1' - config1['irrd']['server']['whois']['port'] = self.port_whois1 - config1['irrd']['auth']['gnupg_keyring'] = str(self.tmpdir) + '/gnupg1' - config1['irrd']['log']['logfile_path'] = self.logfile1 - config1['irrd']['rpki']['roa_source'] = 'file://' + self.roa_source1 - config1['irrd']['sources']['TEST'] = { - 'authoritative': True, - 'keep_journal': True, - 'export_destination': self.export_dir1, - 'export_timer': '1', - 'nrtm_access_list': 'localhost', + config1["irrd"]["piddir"] = self.piddir1 + config1["irrd"]["database_url"] = 
self.database_url1 + config1["irrd"]["redis_url"] = self.redis_url1 + config1["irrd"]["server"]["http"]["interface"] = "127.0.0.1" # #306 + config1["irrd"]["server"]["http"]["port"] = self.port_http1 + config1["irrd"]["server"]["whois"]["interface"] = "127.0.0.1" + config1["irrd"]["server"]["whois"]["port"] = self.port_whois1 + config1["irrd"]["auth"]["gnupg_keyring"] = str(self.tmpdir) + "/gnupg1" + config1["irrd"]["log"]["logfile_path"] = self.logfile1 + config1["irrd"]["rpki"]["roa_source"] = "file://" + self.roa_source1 + config1["irrd"]["sources"]["TEST"] = { + "authoritative": True, + "keep_journal": True, + "export_destination": self.export_dir1, + "export_timer": "1", + "nrtm_access_list": "localhost", } - with open(self.config_path1, 'w') as yaml_file: + with open(self.config_path1, "w") as yaml_file: yaml.safe_dump(config1, yaml_file) config2 = base_config.copy() - config2['irrd']['piddir'] = self.piddir2 - config2['irrd']['database_url'] = self.database_url2 - config2['irrd']['redis_url'] = self.redis_url2 - config2['irrd']['server']['http']['port'] = self.port_http2 - config2['irrd']['server']['whois']['port'] = self.port_whois2 - config2['irrd']['auth']['gnupg_keyring'] = str(self.tmpdir) + '/gnupg2' - config2['irrd']['log']['logfile_path'] = self.logfile2 - config2['irrd']['rpki']['roa_source'] = 'file://' + self.roa_source2 - config2['irrd']['sources']['TEST'] = { - 'keep_journal': True, - 'import_serial_source': f'file://{self.export_dir1}/TEST.CURRENTSERIAL', - 'import_source': f'file://{self.export_dir1}/test.db.gz', - 'export_destination': self.export_dir2, - 'import_timer': '1', - 'export_timer': '1', - 'nrtm_host': '127.0.0.1', - 'nrtm_port': str(self.port_whois1), - 'nrtm_access_list': 'localhost', + config2["irrd"]["piddir"] = self.piddir2 + config2["irrd"]["database_url"] = self.database_url2 + config2["irrd"]["redis_url"] = self.redis_url2 + config2["irrd"]["server"]["http"]["port"] = self.port_http2 + config2["irrd"]["server"]["whois"]["port"] = self.port_whois2 + config2["irrd"]["auth"]["gnupg_keyring"] = str(self.tmpdir) + "/gnupg2" + config2["irrd"]["log"]["logfile_path"] = self.logfile2 + config2["irrd"]["rpki"]["roa_source"] = "file://" + self.roa_source2 + config2["irrd"]["sources"]["TEST"] = { + "keep_journal": True, + "import_serial_source": f"file://{self.export_dir1}/TEST.CURRENTSERIAL", + "import_source": f"file://{self.export_dir1}/test.db.gz", + "export_destination": self.export_dir2, + "import_timer": "1", + "export_timer": "1", + "nrtm_host": "127.0.0.1", + "nrtm_port": str(self.port_whois1), + "nrtm_access_list": "localhost", } - with open(self.config_path2, 'w') as yaml_file: + with open(self.config_path2, "w") as yaml_file: yaml.safe_dump(config2, yaml_file) self._prepare_database() - assert not subprocess.call(['irrd/daemon/main.py', f'--config={self.config_path1}']) - assert not subprocess.call(['irrd/daemon/main.py', f'--config={self.config_path2}']) + assert not subprocess.call(["irrd/daemon/main.py", f"--config={self.config_path1}"]) + assert not subprocess.call(["irrd/daemon/main.py", f"--config={self.config_path2}"]) def _prepare_database(self): """ @@ -856,25 +937,25 @@ def _prepare_database(self): """ config_init(self.config_path1) alembic_cfg = config.Config() - alembic_cfg.set_main_option('script_location', f'{IRRD_ROOT_PATH}/irrd/storage/alembic') - command.upgrade(alembic_cfg, 'head') + alembic_cfg.set_main_option("script_location", f"{IRRD_ROOT_PATH}/irrd/storage/alembic") + command.upgrade(alembic_cfg, "head") connection = 
sa.create_engine(translate_url(self.database_url1)).connect() - connection.execute('DELETE FROM rpsl_objects') - connection.execute('DELETE FROM rpsl_database_journal') - connection.execute('DELETE FROM database_status') - connection.execute('DELETE FROM roa_object') + connection.execute("DELETE FROM rpsl_objects") + connection.execute("DELETE FROM rpsl_database_journal") + connection.execute("DELETE FROM database_status") + connection.execute("DELETE FROM roa_object") config_init(self.config_path2) alembic_cfg = config.Config() - alembic_cfg.set_main_option('script_location', f'{IRRD_ROOT_PATH}/irrd/storage/alembic') - command.upgrade(alembic_cfg, 'head') + alembic_cfg.set_main_option("script_location", f"{IRRD_ROOT_PATH}/irrd/storage/alembic") + command.upgrade(alembic_cfg, "head") connection = sa.create_engine(translate_url(self.database_url2)).connect() - connection.execute('DELETE FROM rpsl_objects') - connection.execute('DELETE FROM rpsl_database_journal') - connection.execute('DELETE FROM database_status') - connection.execute('DELETE FROM roa_object') + connection.execute("DELETE FROM rpsl_objects") + connection.execute("DELETE FROM rpsl_database_journal") + connection.execute("DELETE FROM database_status") + connection.execute("DELETE FROM roa_object") def _submit_update(self, config_path, request): """ @@ -882,7 +963,9 @@ def _submit_update(self, config_path, request): with a specific config path. Request is the raw RPSL update, possibly signed with inline PGP. """ - email = textwrap.dedent(""" + email = ( + textwrap.dedent( + """ From submitter@example.com@localhost Thu Jan 5 10:04:48 2018 Received: from [127.0.0.1] (localhost.localdomain [127.0.0.1]) by hostname (Postfix) with ESMTPS id 740AD310597 @@ -898,12 +981,15 @@ def _submit_update(self, config_path, request): Content-Type: text/plain; charset=utf-8 Mime-Version: 1.0 - """).lstrip().encode('utf-8') - email += base64.b64encode(request.encode('utf-8')) - - script = IRRD_ROOT_PATH + '/irrd/scripts/submit_email.py' - p = subprocess.Popen([script, f'--config={config_path}'], - stdin=subprocess.PIPE) + """ + ) + .lstrip() + .encode("utf-8") + ) + email += base64.b64encode(request.encode("utf-8")) + + script = IRRD_ROOT_PATH + "/irrd/scripts/submit_email.py" + p = subprocess.Popen([script, f"--config={config_path}"], stdin=subprocess.PIPE) p.communicate(email) p.wait() @@ -915,19 +1001,22 @@ def _retrieve_mails(self): """ s = socket.socket() s.settimeout(5) - s.connect(('localhost', EMAIL_SMTP_PORT)) + s.connect(("localhost", EMAIL_SMTP_PORT)) - s.sendall(f'{EMAIL_RETURN_MSGS_COMMAND}\r\n'.encode('ascii')) + s.sendall(f"{EMAIL_RETURN_MSGS_COMMAND}\r\n".encode("ascii")) - buffer = b'' + buffer = b"" while EMAIL_END not in buffer: data = s.recv(1024 * 1024) buffer += data - buffer = buffer.split(b'\n', 1)[1] + buffer = buffer.split(b"\n", 1)[1] buffer = buffer.split(EMAIL_END, 1)[0] - s.sendall(f'{EMAIL_DISCARD_MSGS_COMMAND}\r\n'.encode('ascii')) - messages = [email.message_from_string(m.strip().decode('ascii')) for m in buffer.split(EMAIL_SEPARATOR.encode('ascii'))] + s.sendall(f"{EMAIL_DISCARD_MSGS_COMMAND}\r\n".encode("ascii")) + messages = [ + email.message_from_string(m.strip().decode("ascii")) + for m in buffer.split(EMAIL_SEPARATOR.encode("ascii")) + ] return messages def _extract_message_body(self, message): @@ -935,8 +1024,8 @@ def _extract_message_body(self, message): Convenience method to extract the main body from a non-multipart email.Message object. 
""" - charset = message.get_content_charset(failobj='ascii') - return message.get_payload(decode=True).decode(charset, 'backslashreplace') # type: ignore + charset = message.get_content_charset(failobj="ascii") + return message.get_payload(decode=True).decode(charset, "backslashreplace") # type: ignore def _check_text_in_mails(self, messages, expected_texts): """ @@ -946,7 +1035,7 @@ def _check_text_in_mails(self, messages, expected_texts): for expected_text in expected_texts: for message in messages: message_text = self._extract_message_body(message) - assert expected_text in message_text, f'Missing text {expected_text} in mail:\n{message_text}' + assert expected_text in message_text, f"Missing text {expected_text} in mail:\n{message_text}" def _check_recipients_in_mails(self, messages, expected_recipients): """ @@ -961,7 +1050,7 @@ def _check_recipients_in_mails(self, messages, expected_recipients): leftover_expected_recipients = original_expected_recipients.copy() for message in messages: for recipient in original_expected_recipients: - if message['To'] == recipient: + if message["To"] == recipient: leftover_expected_recipients.remove(recipient) assert not leftover_expected_recipients @@ -971,13 +1060,13 @@ def teardown_method(self, method): or not they succeed. It is used to kill any leftover IRRd or SMTP server processes. """ - print('\n') + print("\n") for pidfile in self.pidfile1, self.pidfile2, self.pidfile_mailserver: try: with open(pidfile) as fh: pid = int(fh.read()) - print(f'Terminating PID {pid} from {pidfile}') + print(f"Terminating PID {pid} from {pidfile}") os.kill(pid, signal.SIGTERM) except (FileNotFoundError, ProcessLookupError, ValueError) as exc: - print(f'Failed to kill: {pidfile}: {exc}') + print(f"Failed to kill: {pidfile}: {exc}") pass diff --git a/irrd/mirroring/mirror_runners_export.py b/irrd/mirroring/mirror_runners_export.py index bdc8c0759..bb707cf3c 100644 --- a/irrd/mirroring/mirror_runners_export.py +++ b/irrd/mirroring/mirror_runners_export.py @@ -1,7 +1,6 @@ -import os - import gzip import logging +import os import shutil from pathlib import Path from tempfile import NamedTemporaryFile @@ -11,7 +10,7 @@ from irrd.rpki.status import RPKIStatus from irrd.scopefilter.status import ScopeFilterStatus from irrd.storage.database_handler import DatabaseHandler -from irrd.storage.queries import RPSLDatabaseQuery, DatabaseStatusQuery +from irrd.storage.queries import DatabaseStatusQuery, RPSLDatabaseQuery from irrd.utils.text import remove_auth_hashes as remove_auth_hashes_func EXPORT_PERMISSIONS = 0o644 @@ -37,48 +36,54 @@ def __init__(self, source: str) -> None: def run(self) -> None: self.database_handler = DatabaseHandler() try: - export_destination = get_setting(f'sources.{self.source}.export_destination') + export_destination = get_setting(f"sources.{self.source}.export_destination") if export_destination: - logger.info(f'Starting a source export for {self.source} to {export_destination}') + logger.info(f"Starting a source export for {self.source} to {export_destination}") self._export(export_destination) - export_destination_unfiltered = get_setting(f'sources.{self.source}.export_destination_unfiltered') + export_destination_unfiltered = get_setting( + f"sources.{self.source}.export_destination_unfiltered" + ) if export_destination_unfiltered: - logger.info(f'Starting an unfiltered source export for {self.source} ' - f'to {export_destination_unfiltered}') + logger.info( + f"Starting an unfiltered source export for {self.source} " + f"to 
{export_destination_unfiltered}" + ) self._export(export_destination_unfiltered, remove_auth_hashes=False) self.database_handler.commit() except Exception as exc: - logger.error(f'An exception occurred while attempting to run an export ' - f'for {self.source}: {exc}', exc_info=exc) + logger.error( + f"An exception occurred while attempting to run an export for {self.source}: {exc}", + exc_info=exc, + ) finally: self.database_handler.close() def _export(self, export_destination, remove_auth_hashes=True): - filename_export = Path(export_destination) / f'{self.source.lower()}.db.gz' + filename_export = Path(export_destination) / f"{self.source.lower()}.db.gz" export_tmpfile = NamedTemporaryFile(delete=False) - filename_serial = Path(export_destination) / f'{self.source.upper()}.CURRENTSERIAL' + filename_serial = Path(export_destination) / f"{self.source.upper()}.CURRENTSERIAL" query = DatabaseStatusQuery().source(self.source) try: - serial = next(self.database_handler.execute_query(query))['serial_newest_seen'] + serial = next(self.database_handler.execute_query(query))["serial_newest_seen"] except StopIteration: serial = None - with gzip.open(export_tmpfile.name, 'wb') as fh: + with gzip.open(export_tmpfile.name, "wb") as fh: query = RPSLDatabaseQuery().sources([self.source]) query = query.rpki_status([RPKIStatus.not_found, RPKIStatus.valid]) query = query.scopefilter_status([ScopeFilterStatus.in_scope]) query = query.route_preference_status([RoutePreferenceStatus.visible]) for obj in self.database_handler.execute_query(query): - object_text = obj['object_text'] + object_text = obj["object_text"] if remove_auth_hashes: object_text = remove_auth_hashes_func(object_text) - object_bytes = object_text.encode('utf-8') - fh.write(object_bytes + b'\n') - fh.write(b'# EOF\n') + object_bytes = object_text.encode("utf-8") + fh.write(object_bytes + b"\n") + fh.write(b"# EOF\n") os.chmod(export_tmpfile.name, EXPORT_PERMISSIONS) if filename_export.exists(): @@ -88,9 +93,12 @@ def _export(self, export_destination, remove_auth_hashes=True): shutil.move(export_tmpfile.name, filename_export) if serial is not None: - with open(filename_serial, 'w') as fh: + with open(filename_serial, "w") as fh: fh.write(str(serial)) os.chmod(filename_serial, EXPORT_PERMISSIONS) self.database_handler.record_serial_exported(self.source, serial) - logger.info(f'Export for {self.source} complete at serial {serial}, stored in {filename_export} / {filename_serial}') + logger.info( + f"Export for {self.source} complete at serial {serial}, stored in {filename_export} /" + f" {filename_serial}" + ) diff --git a/irrd/mirroring/mirror_runners_import.py b/irrd/mirroring/mirror_runners_import.py index b037f8ccf..97f7cf39e 100644 --- a/irrd/mirroring/mirror_runners_import.py +++ b/irrd/mirroring/mirror_runners_import.py @@ -4,24 +4,25 @@ import shutil from io import BytesIO from tempfile import NamedTemporaryFile -from typing import Optional, Tuple, Any, IO +from typing import IO, Any, Optional, Tuple from urllib import request -from urllib.parse import urlparse from urllib.error import URLError +from urllib.parse import urlparse import requests -from irrd.conf import get_setting, RPKI_IRR_PSEUDO_SOURCE +from irrd.conf import RPKI_IRR_PSEUDO_SOURCE, get_setting from irrd.conf.defaults import DEFAULT_SOURCE_NRTM_PORT +from irrd.routepref.routepref import update_route_preference_status from irrd.rpki.importer import ROADataImporter, ROAParserException from irrd.rpki.notifications import notify_rpki_invalid_owners from irrd.rpki.validators 
import BulkRouteROAValidator -from irrd.routepref.routepref import update_route_preference_status from irrd.scopefilter.validators import ScopeFilterValidator from irrd.storage.database_handler import DatabaseHandler from irrd.storage.event_stream import EventStreamPublisher from irrd.storage.queries import DatabaseStatusQuery from irrd.utils.whois_client import whois_query + from .parsers import MirrorFileImportParser, NRTMStreamParser logger = logging.getLogger(__name__) @@ -36,6 +37,7 @@ class RPSLMirrorImportUpdateRunner: to run a new import from full export files. Otherwise, will call NRTMImportUpdateStreamRunner to retrieve new updates from NRTM. """ + def __init__(self, source: str) -> None: self.source = source self.full_import_runner = RPSLMirrorFullImportRunner(source) @@ -46,13 +48,18 @@ def run(self) -> None: try: serial_newest_mirror, force_reload = self._status() - nrtm_enabled = bool(get_setting(f'sources.{self.source}.nrtm_host')) - logger.debug(f'Most recent mirrored serial for {self.source}: {serial_newest_mirror}, ' - f'force_reload: {force_reload}, nrtm enabled: {nrtm_enabled}') + nrtm_enabled = bool(get_setting(f"sources.{self.source}.nrtm_host")) + logger.debug( + f"Most recent mirrored serial for {self.source}: {serial_newest_mirror}, " + f"force_reload: {force_reload}, nrtm enabled: {nrtm_enabled}" + ) full_reload = force_reload or not serial_newest_mirror or not nrtm_enabled if full_reload: - self.full_import_runner.run(database_handler=self.database_handler, - serial_newest_mirror=serial_newest_mirror, force_reload=force_reload) + self.full_import_runner.run( + database_handler=self.database_handler, + serial_newest_mirror=serial_newest_mirror, + force_reload=force_reload, + ) else: assert serial_newest_mirror self.update_stream_runner.run(serial_newest_mirror, database_handler=self.database_handler) @@ -65,11 +72,18 @@ def run(self) -> None: except OSError as ose: # I/O errors can occur and should not log a full traceback (#177) - logger.error(f'An error occurred while attempting a mirror update or initial import ' - f'for {self.source}: {ose}') + logger.error( + "An error occurred while attempting a mirror update or initial import " + f"for {self.source}: {ose}" + ) except Exception as exc: - logger.error(f'An exception occurred while attempting a mirror update or initial import ' - f'for {self.source}: {exc}', exc_info=exc) + logger.error( + ( + "An exception occurred while attempting a mirror update or initial import " + f"for {self.source}: {exc}" + ), + exc_info=exc, + ) finally: self.database_handler.close() @@ -78,7 +92,7 @@ def _status(self) -> Tuple[Optional[int], Optional[bool]]: result = self.database_handler.execute_query(query) try: status = next(result) - return status['serial_newest_mirror'], status['force_reload'] + return status["serial_newest_mirror"], status["force_reload"] except StopIteration: return None, None @@ -101,12 +115,12 @@ def _retrieve_file(self, url: str, return_contents=True) -> Tuple[str, bool]: """ url_parsed = urlparse(url) - if url_parsed.scheme in ['ftp', 'http', 'https']: + if url_parsed.scheme in ["ftp", "http", "https"]: return self._retrieve_file_download(url, url_parsed, return_contents) - if url_parsed.scheme == 'file': + if url_parsed.scheme == "file": return self._retrieve_file_local(url_parsed.path, return_contents) - raise ValueError(f'Invalid URL: {url} - scheme {url_parsed.scheme} is not supported') + raise ValueError(f"Invalid URL: {url} - scheme {url_parsed.scheme} is not supported") def 
_retrieve_file_download(self, url, url_parsed, return_contents=False) -> Tuple[str, bool]: """ @@ -127,22 +141,22 @@ def _retrieve_file_download(self, url, url_parsed, return_contents=False) -> Tup destination = NamedTemporaryFile(delete=False) self._download_file(destination, url, url_parsed) if return_contents: - value = destination.getvalue().decode('utf-8').strip() # type: ignore - logger.info(f'Downloaded {url}, contained {value}') + value = destination.getvalue().decode("utf-8").strip() # type: ignore + logger.info(f"Downloaded {url}, contained {value}") return value, False else: - if url.endswith('.gz'): + if url.endswith(".gz"): zipped_file = destination zipped_file.close() destination = NamedTemporaryFile(delete=False) - logger.debug(f'Downloaded file is expected to be gzipped, gunzipping from {zipped_file.name}') - with gzip.open(zipped_file.name, 'rb') as f_in: + logger.debug(f"Downloaded file is expected to be gzipped, gunzipping from {zipped_file.name}") + with gzip.open(zipped_file.name, "rb") as f_in: shutil.copyfileobj(f_in, destination) os.unlink(zipped_file.name) destination.close() - logger.info(f'Downloaded (and gunzipped if applicable) {url} to {destination.name}') + logger.info(f"Downloaded (and gunzipped if applicable) {url} to {destination.name}") return destination.name, True def _download_file(self, destination: IO[Any], url: str, url_parsed): @@ -151,26 +165,26 @@ def _download_file(self, destination: IO[Any], url: str, url_parsed): The file contents are written to the destination parameter, which can be a BytesIO() or a regular file. """ - if url_parsed.scheme == 'ftp': + if url_parsed.scheme == "ftp": try: r = request.urlopen(url) shutil.copyfileobj(r, destination) except URLError as error: - raise OSError(f'Failed to download {url}: {str(error)}') - elif url_parsed.scheme in ['http', 'https']: + raise OSError(f"Failed to download {url}: {str(error)}") + elif url_parsed.scheme in ["http", "https"]: r = requests.get(url, stream=True, timeout=10) if r.status_code == 200: for chunk in r.iter_content(10240): destination.write(chunk) else: - raise OSError(f'Failed to download {url}: {r.status_code}: {str(r.content)}') + raise OSError(f"Failed to download {url}: {r.status_code}: {str(r.content)}") def _retrieve_file_local(self, path, return_contents=False) -> Tuple[str, bool]: if not return_contents: - if path.endswith('.gz'): + if path.endswith(".gz"): destination = NamedTemporaryFile(delete=False) - logger.debug(f'Local file is expected to be gzipped, gunzipping from {path}') - with gzip.open(path, 'rb') as f_in: + logger.debug(f"Local file is expected to be gzipped, gunzipping from {path}") + with gzip.open(path, "rb") as f_in: shutil.copyfileobj(f_in, destination) destination.close() return destination.name, True @@ -190,41 +204,63 @@ class RPSLMirrorFullImportRunner(FileImportRunnerBase): Files are downloaded, gunzipped if needed, and then sent through the MirrorFileImportParser. 
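As a sketch only (paths and values here are hypothetical, echoing the integration test configuration earlier in this change), the relevant settings look roughly like:

        sources:
            TEST:
                import_source: file:///tmp/export1/test.db.gz
                import_serial_source: file:///tmp/export1/TEST.CURRENTSERIAL

    ftp://, http:// and https:// URLs are handled by _retrieve_file_download, file:// URLs by _retrieve_file_local.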
""" + def __init__(self, source: str) -> None: self.source = source - def run(self, database_handler: DatabaseHandler, serial_newest_mirror: Optional[int]=None, force_reload=False): - import_sources = get_setting(f'sources.{self.source}.import_source') + def run( + self, + database_handler: DatabaseHandler, + serial_newest_mirror: Optional[int] = None, + force_reload=False, + ): + import_sources = get_setting(f"sources.{self.source}.import_source") if isinstance(import_sources, str): import_sources = [import_sources] - import_serial_source = get_setting(f'sources.{self.source}.import_serial_source') + import_serial_source = get_setting(f"sources.{self.source}.import_serial_source") if not import_sources: - logger.info(f'Skipping full RPSL import for {self.source}, import_source not set.') + logger.info(f"Skipping full RPSL import for {self.source}, import_source not set.") return - logger.info(f'Running full RPSL import of {self.source} from {import_sources}, serial from {import_serial_source}') + logger.info( + f"Running full RPSL import of {self.source} from {import_sources}, serial from" + f" {import_serial_source}" + ) import_serial = None if import_serial_source: import_serial = int(self._retrieve_file(import_serial_source, return_contents=True)[0]) - if not force_reload and serial_newest_mirror is not None and import_serial <= serial_newest_mirror: - logger.info(f'Current newest serial seen from mirror for {self.source} is ' - f'{serial_newest_mirror}, import_serial is {import_serial}, cancelling import.') + if ( + not force_reload + and serial_newest_mirror is not None + and import_serial <= serial_newest_mirror + ): + logger.info( + f"Current newest serial seen from mirror for {self.source} is " + f"{serial_newest_mirror}, import_serial is {import_serial}, cancelling import." + ) return database_handler.delete_all_rpsl_objects_with_journal(self.source) - import_data = [self._retrieve_file(import_source, return_contents=False) for import_source in import_sources] + import_data = [ + self._retrieve_file(import_source, return_contents=False) for import_source in import_sources + ] roa_validator = None - if get_setting('rpki.roa_source'): + if get_setting("rpki.roa_source"): roa_validator = BulkRouteROAValidator(database_handler) database_handler.disable_journaling() for import_filename, to_delete in import_data: - p = MirrorFileImportParser(source=self.source, filename=import_filename, serial=None, - database_handler=database_handler, roa_validator=roa_validator) + p = MirrorFileImportParser( + source=self.source, + filename=import_filename, + serial=None, + database_handler=database_handler, + roa_validator=roa_validator, + ) p.run_import() if to_delete: os.unlink(import_filename) @@ -238,6 +274,7 @@ class ROAImportRunner(FileImportRunnerBase): The URL file for the ROA export in JSON format is provided in the configuration. 
""" + # API consistency with other importers, source is actually ignored def __init__(self, source=None): pass @@ -263,25 +300,27 @@ def run(self): ) self.database_handler.commit() notified = notify_rpki_invalid_owners(self.database_handler, objs_now_invalid) - logger.info(f'RPKI status updated for all routes, {len(objs_now_valid)} newly valid, ' - f'{len(objs_now_invalid)} newly invalid, ' - f'{len(objs_now_not_found)} newly not_found routes, ' - f'{notified} emails sent to contacts of newly invalid authoritative objects') + logger.info( + f"RPKI status updated for all routes, {len(objs_now_valid)} newly valid, " + f"{len(objs_now_invalid)} newly invalid, " + f"{len(objs_now_not_found)} newly not_found routes, " + f"{notified} emails sent to contacts of newly invalid authoritative objects" + ) except OSError as ose: # I/O errors can occur and should not log a full traceback (#177) - logger.error(f'An error occurred while attempting a ROA import: {ose}') + logger.error(f"An error occurred while attempting a ROA import: {ose}") except ROAParserException as rpe: - logger.error(f'An exception occurred while attempting a ROA import: {rpe}') + logger.error(f"An exception occurred while attempting a ROA import: {rpe}") except Exception as exc: - logger.error(f'An exception occurred while attempting a ROA import: {exc}', exc_info=exc) + logger.error(f"An exception occurred while attempting a ROA import: {exc}", exc_info=exc) finally: self.database_handler.close() def _import_roas(self): - roa_source = get_setting('rpki.roa_source') - slurm_source = get_setting('rpki.slurm_source') - logger.info(f'Running full ROA import from: {roa_source}, SLURM {slurm_source}') + roa_source = get_setting("rpki.roa_source") + slurm_source = get_setting("rpki.slurm_source") + logger.info(f"Running full ROA import from: {roa_source}, SLURM {slurm_source}") self.database_handler.delete_all_roa_objects() self.database_handler.delete_all_rpsl_objects_with_journal( @@ -298,7 +337,10 @@ def _import_roas(self): roa_importer = ROADataImporter(fh.read(), slurm_data, self.database_handler) if roa_to_delete: os.unlink(roa_filename) - logger.info(f'ROA import from {roa_source}, SLURM {slurm_source}, imported {len(roa_importer.roa_objs)} ROAs, running validator') + logger.info( + f"ROA import from {roa_source}, SLURM {slurm_source}, imported {len(roa_importer.roa_objs)} ROAs," + " running validator" + ) return roa_importer.roa_objs @@ -308,6 +350,7 @@ class ScopeFilterUpdateRunner: This runner does not actually import anything, the scope filter is in the configuration. 
""" + # API consistency with other importers, source is actually ignored def __init__(self, source=None): pass @@ -325,13 +368,17 @@ def run(self): rpsl_objs_now_out_scope_prefix=rpsl_objs_now_out_scope_prefix, ) self.database_handler.commit() - logger.info(f'Scopefilter status updated for all routes, ' - f'{len(rpsl_objs_now_in_scope)} newly in scope, ' - f'{len(rpsl_objs_now_out_scope_as)} newly out of scope AS, ' - f'{len(rpsl_objs_now_out_scope_prefix)} newly out of scope prefix') + logger.info( + "Scopefilter status updated for all routes, " + f"{len(rpsl_objs_now_in_scope)} newly in scope, " + f"{len(rpsl_objs_now_out_scope_as)} newly out of scope AS, " + f"{len(rpsl_objs_now_out_scope_prefix)} newly out of scope prefix" + ) except Exception as exc: - logger.error(f'An exception occurred while attempting a scopefilter status update: {exc}', exc_info=exc) + logger.error( + f"An exception occurred while attempting a scopefilter status update: {exc}", exc_info=exc + ) finally: self.database_handler.close() @@ -342,6 +389,7 @@ class RoutePreferenceUpdateRunner: This runner does not actually import anything external, all data is already in our database. """ + # API consistency with other importers, source is actually ignored def __init__(self, source=None): pass @@ -352,9 +400,12 @@ def run(self): try: update_route_preference_status(database_handler) database_handler.commit() - logger.info('route preference update commit complete') + logger.info("route preference update commit complete") except Exception as exc: - logger.error(f'An exception occurred while attempting a route preference status update: {exc}', exc_info=exc) + logger.error( + f"An exception occurred while attempting a route preference status update: {exc}", + exc_info=exc, + ) finally: database_handler.close() @@ -364,30 +415,33 @@ class NRTMImportUpdateStreamRunner: This runner attempts to pull updates from an NRTM stream for a specific mirrored database. 
""" + def __init__(self, source: str) -> None: self.source = source def run(self, serial_newest_mirror: int, database_handler: DatabaseHandler): serial_start = serial_newest_mirror + 1 - nrtm_host = get_setting(f'sources.{self.source}.nrtm_host') - nrtm_port = int(get_setting(f'sources.{self.source}.nrtm_port', DEFAULT_SOURCE_NRTM_PORT)) + nrtm_host = get_setting(f"sources.{self.source}.nrtm_host") + nrtm_port = int(get_setting(f"sources.{self.source}.nrtm_port", DEFAULT_SOURCE_NRTM_PORT)) if not nrtm_host: - logger.debug(f'Skipping NRTM updates for {self.source}, nrtm_host not set.') + logger.debug(f"Skipping NRTM updates for {self.source}, nrtm_host not set.") return end_markings = [ - f'\n%END {self.source}\n', - f'\n% END {self.source}\n', - '\n%ERROR', - '\n% ERROR', - '\n% Warning: there are no newer updates available', - '\n% Warning (1): there are no newer updates available', + f"\n%END {self.source}\n", + f"\n% END {self.source}\n", + "\n%ERROR", + "\n% ERROR", + "\n% Warning: there are no newer updates available", + "\n% Warning (1): there are no newer updates available", ] - logger.info(f'Retrieving NRTM updates for {self.source} from serial {serial_start} on {nrtm_host}:{nrtm_port}') - query = f'-g {self.source}:3:{serial_start}-LAST' + logger.info( + f"Retrieving NRTM updates for {self.source} from serial {serial_start} on {nrtm_host}:{nrtm_port}" + ) + query = f"-g {self.source}:3:{serial_start}-LAST" response = whois_query(nrtm_host, nrtm_port, query, end_markings) - logger.debug(f'Received NRTM response for {self.source}: {response.strip()}') + logger.debug(f"Received NRTM response for {self.source}: {response.strip()}") stream_parser = NRTMStreamParser(self.source, response, database_handler) for operation in stream_parser.operations: diff --git a/irrd/mirroring/nrtm_generator.py b/irrd/mirroring/nrtm_generator.py index 81a99d30a..3b18c1869 100644 --- a/irrd/mirroring/nrtm_generator.py +++ b/irrd/mirroring/nrtm_generator.py @@ -2,7 +2,7 @@ from irrd.conf import get_setting from irrd.storage.database_handler import DatabaseHandler -from irrd.storage.queries import RPSLDatabaseJournalQuery, DatabaseStatusQuery +from irrd.storage.queries import DatabaseStatusQuery, RPSLDatabaseJournalQuery from irrd.utils.text import remove_auth_hashes as remove_auth_hashes_func @@ -11,9 +11,15 @@ class NRTMGeneratorException(Exception): # noqa: N818 class NRTMGenerator: - def generate(self, source: str, version: str, - serial_start_requested: int, serial_end_requested: Optional[int], - database_handler: DatabaseHandler, remove_auth_hashes=True) -> str: + def generate( + self, + source: str, + version: str, + serial_start_requested: int, + serial_end_requested: Optional[int], + database_handler: DatabaseHandler, + remove_auth_hashes=True, + ) -> str: """ Generate an NRTM response for a particular source, serial range and NRTM version. Raises NRTMGeneratorException for various error conditions. @@ -21,59 +27,70 @@ def generate(self, source: str, version: str, For queries where the user requested NRTM updates up to LAST, serial_end_requested is None. 
""" - if not get_setting(f'sources.{source}.keep_journal'): - raise NRTMGeneratorException('No journal kept for this source, unable to serve NRTM queries') + if not get_setting(f"sources.{source}.keep_journal"): + raise NRTMGeneratorException("No journal kept for this source, unable to serve NRTM queries") q = DatabaseStatusQuery().source(source) try: status = next(database_handler.execute_query(q)) except StopIteration: - raise NRTMGeneratorException('There are no journal entries for this source.') + raise NRTMGeneratorException("There are no journal entries for this source.") if serial_end_requested and serial_end_requested < serial_start_requested: - raise NRTMGeneratorException(f'Start of the serial range ({serial_start_requested}) must be lower or ' - f'equal to end of the serial range ({serial_end_requested})') + raise NRTMGeneratorException( + f"Start of the serial range ({serial_start_requested}) must be lower or " + f"equal to end of the serial range ({serial_end_requested})" + ) - serial_start_available = status['serial_oldest_journal'] - serial_end_available = status['serial_newest_journal'] + serial_start_available = status["serial_oldest_journal"] + serial_end_available = status["serial_newest_journal"] if serial_start_available is None or serial_end_available is None: - return '% Warning: there are no updates available' + return "% Warning: there are no updates available" if serial_start_requested < serial_start_available: - raise NRTMGeneratorException(f'Serials {serial_start_requested} - {serial_start_available} do not exist') + raise NRTMGeneratorException( + f"Serials {serial_start_requested} - {serial_start_available} do not exist" + ) if serial_end_requested is not None and serial_end_requested > serial_end_available: - raise NRTMGeneratorException(f'Serials {serial_end_available} - {serial_end_requested} do not exist') + raise NRTMGeneratorException( + f"Serials {serial_end_available} - {serial_end_requested} do not exist" + ) if serial_end_requested is None: if serial_start_requested == serial_end_available + 1: # A specific message is triggered when starting from a serial # that is the current plus one, until LAST - return '% Warning: there are no newer updates available' + return "% Warning: there are no newer updates available" elif serial_start_requested > serial_end_available: raise NRTMGeneratorException( - f'Serials {serial_end_available} - {serial_start_requested} do not exist') + f"Serials {serial_end_available} - {serial_start_requested} do not exist" + ) serial_end_display = serial_end_available if serial_end_requested is None else serial_end_requested - range_limit = get_setting(f'sources.{source}.nrtm_query_serial_range_limit') + range_limit = get_setting(f"sources.{source}.nrtm_query_serial_range_limit") if range_limit and int(range_limit) < (serial_end_display - serial_start_requested): - raise NRTMGeneratorException(f'Serial range requested exceeds maximum range of {range_limit}') + raise NRTMGeneratorException(f"Serial range requested exceeds maximum range of {range_limit}") - q = RPSLDatabaseJournalQuery().sources([source]).serial_nrtm_range(serial_start_requested, serial_end_requested) + q = ( + RPSLDatabaseJournalQuery() + .sources([source]) + .serial_nrtm_range(serial_start_requested, serial_end_requested) + ) operations = list(database_handler.execute_query(q)) - output = f'%START Version: {version} {source} {serial_start_requested}-{serial_end_display}\n' + output = f"%START Version: {version} {source} 
{serial_start_requested}-{serial_end_display}\n" for operation in operations: - output += '\n' + operation['operation'].value - if version == '3': - output += ' ' + str(operation['serial_nrtm']) - text = operation['object_text'] + output += "\n" + operation["operation"].value + if version == "3": + output += " " + str(operation["serial_nrtm"]) + text = operation["object_text"] if remove_auth_hashes: text = remove_auth_hashes_func(text) - output += '\n\n' + text + output += "\n\n" + text - output += f'\n%END {source}' + output += f"\n%END {source}" return output diff --git a/irrd/mirroring/nrtm_operation.py b/irrd/mirroring/nrtm_operation.py index 7c32780b5..ae5a6d064 100644 --- a/irrd/mirroring/nrtm_operation.py +++ b/irrd/mirroring/nrtm_operation.py @@ -1,9 +1,9 @@ import logging -from typing import Optional, List +from typing import List, Optional from irrd.rpki.validators import SingleRouteROAValidator from irrd.rpsl.parser import UnknownRPSLObjectClassException -from irrd.rpsl.rpsl_objects import rpsl_object_from_text, RPSLKeyCert +from irrd.rpsl.rpsl_objects import RPSLKeyCert, rpsl_object_from_text from irrd.scopefilter.validators import ScopeFilterValidator from irrd.storage.database_handler import DatabaseHandler from irrd.storage.models import DatabaseOperation, JournalEntryOrigin @@ -20,9 +20,17 @@ class NRTMOperation: source attribute, but with other PK attribute(s) present. For deletion operations, this is permitted. """ - def __init__(self, source: str, operation: DatabaseOperation, serial: int, object_text: str, - strict_validation_key_cert: bool, rpki_aware: bool = False, - object_class_filter: Optional[List[str]] = None) -> None: + + def __init__( + self, + source: str, + operation: DatabaseOperation, + serial: int, + object_text: str, + strict_validation_key_cert: bool, + rpki_aware: bool = False, + object_class_filter: Optional[List[str]] = None, + ) -> None: self.source = source self.operation = operation self.serial = serial @@ -39,32 +47,40 @@ def save(self, database_handler: DatabaseHandler) -> bool: # is set, parse it again with strict validation to load it in the GPG keychain. obj = rpsl_object_from_text(object_text, strict_validation=False, default_source=default_source) if self.strict_validation_key_cert and obj.__class__ == RPSLKeyCert: - obj = rpsl_object_from_text(object_text, strict_validation=True, default_source=default_source) + obj = rpsl_object_from_text( + object_text, strict_validation=True, default_source=default_source + ) except UnknownRPSLObjectClassException as exc: # Unknown object classes are only logged if they have not been filtered out. if not self.object_class_filter or exc.rpsl_object_class.lower() in self.object_class_filter: - logger.info(f'Ignoring NRTM operation {str(self)}: {exc}') + logger.info(f"Ignoring NRTM operation {str(self)}: {exc}") return False if self.object_class_filter and obj.rpsl_object_class.lower() not in self.object_class_filter: return False if obj.messages.errors(): - errors = '; '.join(obj.messages.errors()) - logger.critical(f'Parsing errors occurred while processing NRTM operation {str(self)}. ' - f'This operation is ignored, causing potential data inconsistencies. ' - f'A new operation for this update, without errors, ' - f'will still be processed and cause the inconsistency to be resolved. 
' - f'Parser error messages: {errors}; original object text follows:\n{self.object_text}') - database_handler.record_mirror_error(self.source, f'Parsing errors: {obj.messages.errors()}, ' - f'original object text follows:\n{self.object_text}') + errors = "; ".join(obj.messages.errors()) + logger.critical( + f"Parsing errors occurred while processing NRTM operation {str(self)}. " + "This operation is ignored, causing potential data inconsistencies. " + "A new operation for this update, without errors, " + "will still be processed and cause the inconsistency to be resolved. " + f"Parser error messages: {errors}; original object text follows:\n{self.object_text}" + ) + database_handler.record_mirror_error( + self.source, + f"Parsing errors: {obj.messages.errors()}, original object text follows:\n{self.object_text}", + ) return False - if 'source' in obj.parsed_data and obj.parsed_data['source'].upper() != self.source: - msg = (f'Incorrect source in NRTM object: stream has source {self.source}, found object with ' - f'source {obj.source()} in operation {self.serial}/{self.operation.value}/{obj.pk()}. ' - f'This operation is ignored, causing potential data inconsistencies.') + if "source" in obj.parsed_data and obj.parsed_data["source"].upper() != self.source: + msg = ( + f"Incorrect source in NRTM object: stream has source {self.source}, found object with " + f"source {obj.source()} in operation {self.serial}/{self.operation.value}/{obj.pk()}. " + "This operation is ignored, causing potential data inconsistencies." + ) database_handler.record_mirror_error(self.source, msg) logger.critical(msg) return False @@ -75,17 +91,17 @@ def save(self, database_handler: DatabaseHandler) -> bool: obj.rpki_status = roa_validator.validate_route(obj.prefix, obj.asn_first, obj.source()) scope_validator = ScopeFilterValidator() obj.scopefilter_status, _ = scope_validator.validate_rpsl_object(obj) - database_handler.upsert_rpsl_object(obj, JournalEntryOrigin.mirror, - source_serial=self.serial) + database_handler.upsert_rpsl_object(obj, JournalEntryOrigin.mirror, source_serial=self.serial) elif self.operation == DatabaseOperation.delete: - database_handler.delete_rpsl_object(rpsl_object=obj, origin=JournalEntryOrigin.mirror, - source_serial=self.serial) + database_handler.delete_rpsl_object( + rpsl_object=obj, origin=JournalEntryOrigin.mirror, source_serial=self.serial + ) - log = f'Completed NRTM operation {str(self)}/{obj.rpsl_object_class}/{obj.pk()}' + log = f"Completed NRTM operation {str(self)}/{obj.rpsl_object_class}/{obj.pk()}" if self.rpki_aware and obj.is_route: - log += f', RPKI status {obj.rpki_status.value}' + log += f", RPKI status {obj.rpki_status.value}" logger.info(log) return True def __repr__(self): - return f'{self.source}/{self.serial}/{self.operation.value}' + return f"{self.source}/{self.serial}/{self.operation.value}" diff --git a/irrd/mirroring/parsers.py b/irrd/mirroring/parsers.py index ef571bece..fe8fb88f4 100644 --- a/irrd/mirroring/parsers.py +++ b/irrd/mirroring/parsers.py @@ -1,20 +1,27 @@ import logging import re -from typing import List, Set, Optional +from typing import List, Optional, Set from irrd.conf import get_setting from irrd.rpki.validators import BulkRouteROAValidator -from irrd.rpsl.parser import UnknownRPSLObjectClassException, RPSLObject -from irrd.rpsl.rpsl_objects import rpsl_object_from_text, RPSLKeyCert +from irrd.rpsl.parser import RPSLObject, UnknownRPSLObjectClassException +from irrd.rpsl.rpsl_objects import RPSLKeyCert, rpsl_object_from_text from 
irrd.scopefilter.validators import ScopeFilterValidator from irrd.storage.database_handler import DatabaseHandler from irrd.storage.models import DatabaseOperation, JournalEntryOrigin -from irrd.utils.text import split_paragraphs_rpsl, remove_last_modified -from .nrtm_operation import NRTMOperation +from irrd.utils.text import remove_last_modified, split_paragraphs_rpsl + from ..storage.queries import RPSLDatabaseQuery +from .nrtm_operation import NRTMOperation logger = logging.getLogger(__name__) -nrtm_start_line_re = re.compile(r'^% *START *Version: *(?P<version>\d+) +(?P<source>[\w-]+) +(?P<first_serial>\d+)-(?P<last_serial>\d+)( FILTERED)?\n$', flags=re.MULTILINE) +nrtm_start_line_re = re.compile( + ( + r"^% *START *Version: *(?P<version>\d+) +(?P<source>[\w-]+)" + r" +(?P<first_serial>\d+)-(?P<last_serial>\d+)( FILTERED)?\n$" + ), + flags=re.MULTILINE, +) class RPSLImportError(Exception): @@ -24,7 +31,7 @@ def __init__(self, message: str) -> None: class MirrorParser: def __init__(self): - object_class_filter = get_setting(f'sources.{self.source}.object_class_filter') + object_class_filter = get_setting(f"sources.{self.source}.object_class_filter") if object_class_filter: if isinstance(object_class_filter, str): object_class_filter = [object_class_filter] @@ -32,7 +39,9 @@ def __init__(self): else: self.object_class_filter = None - self.strict_validation_key_cert = get_setting(f'sources.{self.source}.strict_import_keycert_objects', False) + self.strict_validation_key_cert = get_setting( + f"sources.{self.source}.strict_import_keycert_objects", False + ) class MirrorFileImportParserBase(MirrorParser): @@ -45,15 +54,21 @@ class MirrorFileImportParserBase(MirrorParser): upon encountering an error message. It will return an error string. """ + obj_parsed = 0 # Total objects found obj_errors = 0 # Objects with errors obj_ignored_class = 0 # Objects ignored due to object_class_filter setting obj_unknown = 0 # Objects with unknown classes unknown_object_classes: Set[str] = set() # Set of encountered unknown classes - def __init__(self, source: str, filename: str, - database_handler: DatabaseHandler, direct_error_return: bool=False, - roa_validator: Optional[BulkRouteROAValidator] = None) -> None: + def __init__( + self, + source: str, + filename: str, + database_handler: DatabaseHandler, + direct_error_return: bool = False, + roa_validator: Optional[BulkRouteROAValidator] = None, + ) -> None: self.source = source self.filename = filename self.database_handler = database_handler @@ -83,23 +98,27 @@ def _parse_object(self, rpsl_text: str) -> Optional[RPSLObject]: obj = rpsl_object_from_text(rpsl_text.strip(), strict_validation=True) if obj.messages.errors(): - log_msg = f'Parsing errors: {obj.messages.errors()}, original object text follows:\n{rpsl_text}' + log_msg = ( + f"Parsing errors: {obj.messages.errors()}, original object text follows:\n{rpsl_text}" + ) if self.direct_error_return: raise RPSLImportError(log_msg) self.database_handler.record_mirror_error(self.source, log_msg) - logger.critical(f'Parsing errors occurred while importing from file for {self.source}. ' - f'This object is ignored, causing potential data inconsistencies. A new operation for ' - f'this update, without errors, will still be processed and cause the inconsistency to ' - f'be resolved. Parser error messages: {obj.messages.errors()}; ' - f'original object text follows:\n{rpsl_text}') + logger.critical( + f"Parsing errors occurred while importing from file for {self.source}. " + "This object is ignored, causing potential data inconsistencies. 
A new operation for " "this update, without errors, will still be processed and cause the inconsistency to " f"be resolved. Parser error messages: {obj.messages.errors()}; " f"original object text follows:\n{rpsl_text}" ) self.obj_errors += 1 return None if obj.source() != self.source: - msg = f'Invalid source {obj.source()} for object {obj.pk()}, expected {self.source}' + msg = f"Invalid source {obj.source()} for object {obj.pk()}, expected {self.source}" if self.direct_error_return: raise RPSLImportError(msg) - logger.critical(msg + '. This object is ignored, causing potential data inconsistencies.') + logger.critical(msg + ". This object is ignored, causing potential data inconsistencies.") self.database_handler.record_mirror_error(self.source, msg) self.obj_errors += 1 return None @@ -120,11 +139,11 @@ def _parse_object(self, rpsl_text: str) -> Optional[RPSLObject]: except UnknownRPSLObjectClassException as e: # Ignore legacy IRRd artifacts # https://github.com/irrdnet/irrd4/issues/232 - if e.rpsl_object_class.startswith('*xx'): + if e.rpsl_object_class.startswith("*xx"): self.obj_parsed -= 1 # This object does not exist to us return None if self.direct_error_return: - raise RPSLImportError(f'Unknown object class: {e.rpsl_object_class}') + raise RPSLImportError(f"Unknown object class: {e.rpsl_object_class}") self.obj_unknown += 1 self.unknown_object_classes.add(e.rpsl_object_class) return None @@ -140,17 +159,18 @@ class MirrorFileImportParser(MirrorFileImportParserBase): upon encountering an error message. It will return an error string. """ - def __init__(self, serial: Optional[int]=None, *args, **kwargs): + + def __init__(self, serial: Optional[int] = None, *args, **kwargs): super().__init__(*args, **kwargs) self.serial = serial - logger.debug(f'Starting file import of {self.source} from {self.filename}') + logger.debug(f"Starting file import of {self.source} from {self.filename}") def run_import(self) -> Optional[str]: """ Run the actual import. If direct_error_return is set, returns an error string on encountering the first error. Otherwise, returns None.
""" - f = open(self.filename, encoding='utf-8', errors='backslashreplace') + f = open(self.filename, encoding="utf-8", errors="backslashreplace") for paragraph in split_paragraphs_rpsl(f): try: rpsl_obj = self._parse_object(paragraph) @@ -159,8 +179,7 @@ def run_import(self) -> Optional[str]: return e.message else: if rpsl_obj: - self.database_handler.upsert_rpsl_object( - rpsl_obj, origin=JournalEntryOrigin.mirror) + self.database_handler.upsert_rpsl_object(rpsl_obj, origin=JournalEntryOrigin.mirror) self.log_report() f.close() @@ -171,15 +190,19 @@ def run_import(self) -> Optional[str]: def log_report(self) -> None: obj_successful = self.obj_parsed - self.obj_unknown - self.obj_errors - self.obj_ignored_class - logger.info(f'File import for {self.source}: {self.obj_parsed} objects read, ' - f'{obj_successful} objects inserted, ' - f'ignored {self.obj_errors} due to errors, ' - f'ignored {self.obj_ignored_class} due to object_class_filter, ' - f'source {self.filename}') + logger.info( + f"File import for {self.source}: {self.obj_parsed} objects read, " + f"{obj_successful} objects inserted, " + f"ignored {self.obj_errors} due to errors, " + f"ignored {self.obj_ignored_class} due to object_class_filter, " + f"source {self.filename}" + ) if self.obj_unknown: - unknown_formatted = ', '.join(self.unknown_object_classes) - logger.warning(f'Ignored {self.obj_unknown} objects found in file import for {self.source} due to unknown ' - f'object classes: {unknown_formatted}') + unknown_formatted = ", ".join(self.unknown_object_classes) + logger.warning( + f"Ignored {self.obj_unknown} objects found in file import for {self.source} due to unknown " + f"object classes: {unknown_formatted}" + ) class MirrorUpdateFileImportParser(MirrorFileImportParserBase): @@ -195,9 +218,10 @@ class MirrorUpdateFileImportParser(MirrorFileImportParserBase): upon an encountering an error message. It will return an error string. """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - logger.debug(f'Starting update import for {self.source} from {self.filename}') + logger.debug(f"Starting update import for {self.source} from {self.filename}") self.obj_new = 0 # New objects self.obj_modified = 0 # Modified objects self.obj_retained = 0 # Retained and possibly modified objects @@ -209,7 +233,7 @@ def run_import(self) -> Optional[str]: string on encountering the first error. Otherwise, returns None. 
""" objs_from_file = [] - f = open(self.filename, encoding='utf-8', errors='backslashreplace') + f = open(self.filename, encoding="utf-8", errors="backslashreplace") for paragraph in split_paragraphs_rpsl(f): try: rpsl_obj = self._parse_object(paragraph) @@ -221,17 +245,14 @@ def run_import(self) -> Optional[str]: objs_from_file.append(rpsl_obj) f.close() - query = RPSLDatabaseQuery(ordered_by_sources=False, enable_ordering=False, - column_names=['rpsl_pk', 'object_class']).sources([self.source]) + query = RPSLDatabaseQuery( + ordered_by_sources=False, enable_ordering=False, column_names=["rpsl_pk", "object_class"] + ).sources([self.source]) current_pks = { - (row['rpsl_pk'], row['object_class']) - for row in self.database_handler.execute_query(query) + (row["rpsl_pk"], row["object_class"]) for row in self.database_handler.execute_query(query) } - file_objs_by_pk = { - (obj.pk(), obj.rpsl_object_class): obj - for obj in objs_from_file - } + file_objs_by_pk = {(obj.pk(), obj.rpsl_object_class): obj for obj in objs_from_file} file_pks = set(file_objs_by_pk.keys()) new_pks = file_pks - current_pks deleted_pks = current_pks - file_pks @@ -244,23 +265,28 @@ def run_import(self) -> Optional[str]: for (rpsl_pk, object_class), file_obj in filter(lambda i: i[0] in new_pks, file_objs_by_pk.items()): self.database_handler.upsert_rpsl_object(file_obj, JournalEntryOrigin.synthetic_nrtm) - for (rpsl_pk, object_class) in deleted_pks: + for rpsl_pk, object_class in deleted_pks: self.database_handler.delete_rpsl_object( - rpsl_pk=rpsl_pk, source=self.source, object_class=object_class, + rpsl_pk=rpsl_pk, + source=self.source, + object_class=object_class, origin=JournalEntryOrigin.synthetic_nrtm, ) # This query does not filter on retained_pks. The expectation is that most # objects are retained, and therefore it is much faster to query the entire source. 
- query = RPSLDatabaseQuery(ordered_by_sources=False, enable_ordering=False, - column_names=['rpsl_pk', 'object_class', 'object_text']) + query = RPSLDatabaseQuery( + ordered_by_sources=False, + enable_ordering=False, + column_names=["rpsl_pk", "object_class", "object_text"], + ) query = query.sources([self.source]) for row in self.database_handler.execute_query(query): try: - file_obj = file_objs_by_pk[(row['rpsl_pk'], row['object_class'])] + file_obj = file_objs_by_pk[(row["rpsl_pk"], row["object_class"])] except KeyError: continue - if file_obj.render_rpsl_text() != remove_last_modified(row['object_text']): + if file_obj.render_rpsl_text() != remove_last_modified(row["object_text"]): self.database_handler.upsert_rpsl_object(file_obj, JournalEntryOrigin.synthetic_nrtm) self.obj_modified += 1 @@ -269,19 +295,23 @@ def run_import(self) -> Optional[str]: def log_report(self) -> None: obj_successful = self.obj_parsed - self.obj_unknown - self.obj_errors - self.obj_ignored_class - logger.info(f'File update for {self.source}: {self.obj_parsed} objects read, ' - f'{obj_successful} objects processed, ' - f'{self.obj_new} objects newly inserted, ' - f'{self.obj_deleted} objects newly deleted, ' - f'{self.obj_retained} objects retained, of which ' - f'{self.obj_modified} modified, ' - f'ignored {self.obj_errors} due to errors, ' - f'ignored {self.obj_ignored_class} due to object_class_filter, ' - f'source {self.filename}') + logger.info( + f"File update for {self.source}: {self.obj_parsed} objects read, " + f"{obj_successful} objects processed, " + f"{self.obj_new} objects newly inserted, " + f"{self.obj_deleted} objects newly deleted, " + f"{self.obj_retained} objects retained, of which " + f"{self.obj_modified} modified, " + f"ignored {self.obj_errors} due to errors, " + f"ignored {self.obj_ignored_class} due to object_class_filter, " + f"source {self.filename}" + ) if self.obj_unknown: - unknown_formatted = ', '.join(self.unknown_object_classes) - logger.warning(f'Ignored {self.obj_unknown} objects found in file import for {self.source} due to unknown ' - f'object classes: {unknown_formatted}') + unknown_formatted = ", ".join(self.unknown_object_classes) + logger.warning( + f"Ignored {self.obj_unknown} objects found in file import for {self.source} due to unknown " + f"object classes: {unknown_formatted}" + ) class NRTMStreamParser(MirrorParser): @@ -298,6 +328,7 @@ class NRTMStreamParser(MirrorParser): Raises a ValueError for invalid NRTM data. 
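For illustration (the source name EXAMPLE, the serial, and the object below are invented), a minimal valid stream looks like:

        %START Version: 3 EXAMPLE 11-11

        ADD 11

        route: 192.0.2.0/24
        origin: AS65547
        source: EXAMPLE

        %END EXAMPLE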
""" + first_serial = -1 last_serial = -1 nrtm_source: Optional[str] = None @@ -306,7 +337,7 @@ class NRTMStreamParser(MirrorParser): def __init__(self, source: str, nrtm_data: str, database_handler: DatabaseHandler) -> None: self.source = source self.database_handler = database_handler - self.rpki_aware = bool(get_setting('rpki.roa_source')) + self.rpki_aware = bool(get_setting("rpki.roa_source")) super().__init__() self.operations: List[NRTMOperation] = [] self._split_stream(nrtm_data) @@ -314,28 +345,32 @@ def __init__(self, source: str, nrtm_data: str, database_handler: DatabaseHandle def _split_stream(self, data: str) -> None: """Split a stream into individual operations.""" paragraphs = split_paragraphs_rpsl(data, strip_comments=False) - last_comment_seen = '' + last_comment_seen = "" for paragraph in paragraphs: if self._handle_possible_start_line(paragraph): continue - elif paragraph.startswith('%') or paragraph.startswith('#'): + elif paragraph.startswith("%") or paragraph.startswith("#"): last_comment_seen = paragraph - elif paragraph.startswith('ADD') or paragraph.startswith('DEL'): + elif paragraph.startswith("ADD") or paragraph.startswith("DEL"): self._handle_operation(paragraph, paragraphs) - if self.nrtm_source and last_comment_seen.upper().strip() != f'%END {self.source}': - msg = f'NRTM stream error for {self.source}: last comment paragraph expected to be ' \ - f'"%END {self.source}", but is actually "{last_comment_seen.upper().strip()}" - ' \ - 'could be caused by TCP disconnection during NRTM query or mirror server ' \ - 'returning an error or an otherwise incomplete or invalid response' + if self.nrtm_source and last_comment_seen.upper().strip() != f"%END {self.source}": + msg = ( + f"NRTM stream error for {self.source}: last comment paragraph expected to be " + f'"%END {self.source}", but is actually "{last_comment_seen.upper().strip()}" - ' + "could be caused by TCP disconnection during NRTM query or mirror server " + "returning an error or an otherwise incomplete or invalid response" + ) logger.error(msg) self.database_handler.record_mirror_error(self.source, msg) raise ValueError(msg) - if self._current_op_serial > self.last_serial and self.version != '3': - msg = f'NRTM stream error for {self.source}: expected operations up to and including serial ' \ - f'{self.last_serial}, last operation was {self._current_op_serial}' + if self._current_op_serial > self.last_serial and self.version != "3": + msg = ( + f"NRTM stream error for {self.source}: expected operations up to and including serial " + f"{self.last_serial}, last operation was {self._current_op_serial}" + ) logger.error(msg) self.database_handler.record_mirror_error(self.source, msg) raise ValueError(msg) @@ -350,38 +385,47 @@ def _handle_possible_start_line(self, line: str) -> bool: return False if self.nrtm_source: # nrtm_source can only be defined if this is a second START line - msg = f'Encountered second START line in NRTM stream, first was {self.source} ' \ - f'{self.first_serial}-{self.last_serial}, new line is: {line}' + msg = ( + f"Encountered second START line in NRTM stream, first was {self.source} " + f"{self.first_serial}-{self.last_serial}, new line is: {line}" + ) self.database_handler.record_mirror_error(self.source, msg) logger.error(msg) raise ValueError(msg) - self.version = start_line_match.group('version') - self.nrtm_source = start_line_match.group('source').upper() - self.first_serial = int(start_line_match.group('first_serial')) - self.last_serial = 
int(start_line_match.group('last_serial')) + self.version = start_line_match.group("version") + self.nrtm_source = start_line_match.group("source").upper() + self.first_serial = int(start_line_match.group("first_serial")) + self.last_serial = int(start_line_match.group("last_serial")) if self.source != self.nrtm_source: - msg = f'Invalid NRTM source in START line: expected {self.source} but found ' \ - f'{self.nrtm_source} in line: {line}' + msg = ( + f"Invalid NRTM source in START line: expected {self.source} but found " + f"{self.nrtm_source} in line: {line}" + ) self.database_handler.record_mirror_error(self.source, msg) logger.error(msg) raise ValueError(msg) - if self.version not in ['1', '3']: - msg = f'Invalid NRTM version {self.version} in START line: {line}' + if self.version not in ["1", "3"]: + msg = f"Invalid NRTM version {self.version} in START line: {line}" self.database_handler.record_mirror_error(self.source, msg) logger.error(msg) raise ValueError(msg) - logger.debug(f'Found valid start line for {self.source}, range {self.first_serial}-{self.last_serial}') + logger.debug( + f"Found valid start line for {self.source}, range {self.first_serial}-{self.last_serial}" + ) return True def _handle_operation(self, current_paragraph: str, paragraphs) -> None: """Handle a single ADD/DEL operation.""" if not self.nrtm_source: - msg = f'Encountered operation before valid NRTM START line, paragraph encountered: {current_paragraph}' + msg = ( + "Encountered operation before valid NRTM START line, paragraph encountered:" + f" {current_paragraph}" + ) self.database_handler.record_mirror_error(self.source, msg) logger.error(msg) raise ValueError(msg) @@ -391,14 +435,16 @@ def _handle_operation(self, current_paragraph: str, paragraphs) -> None: else: self._current_op_serial += 1 - if ' ' in current_paragraph: - operation_str, line_serial_str = current_paragraph.split(' ') + if " " in current_paragraph: + operation_str, line_serial_str = current_paragraph.split(" ") line_serial = int(line_serial_str) # Gaps are allowed, but the line serial can never be lower, as that # means operations are served in the wrong order. 
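            # A concrete illustration (editorial sketch, not part of the patch):
            # with a START line announcing serials 10-15, the sequence
            # "ADD 10", "ADD 12", "DEL 15" passes the check below, as gaps in
            # the serial range are permitted. "ADD 12" followed by "ADD 11"
            # does not: after processing serial 12, _current_op_serial has
            # advanced to 13, and 11 is lower, so a ValueError is raised.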
if line_serial < self._current_op_serial: - msg = f'Invalid NRTM serial for {self.source}: ADD/DEL has serial {line_serial}, ' \ - f'expected at least {self._current_op_serial}' + msg = ( + f"Invalid NRTM serial for {self.source}: ADD/DEL has serial {line_serial}, " + f"expected at least {self._current_op_serial}" + ) logger.error(msg) self.database_handler.record_mirror_error(self.source, msg) raise ValueError(msg) @@ -408,7 +454,13 @@ def _handle_operation(self, current_paragraph: str, paragraphs) -> None: operation = DatabaseOperation(operation_str) object_text = next(paragraphs) - nrtm_operation = NRTMOperation(self.source, operation, self._current_op_serial, - object_text, self.strict_validation_key_cert, self.rpki_aware, - self.object_class_filter) + nrtm_operation = NRTMOperation( + self.source, + operation, + self._current_op_serial, + object_text, + self.strict_validation_key_cert, + self.rpki_aware, + self.object_class_filter, + ) self.operations.append(nrtm_operation) diff --git a/irrd/mirroring/scheduler.py b/irrd/mirroring/scheduler.py index eb29cae16..4afc3e95e 100644 --- a/irrd/mirroring/scheduler.py +++ b/irrd/mirroring/scheduler.py @@ -1,19 +1,23 @@ -import time -from collections import defaultdict - import gc import logging import multiprocessing - import signal -from setproctitle import setproctitle +import time +from collections import defaultdict from typing import Dict -from irrd.conf import get_setting, RPKI_IRR_PSEUDO_SOURCE -from irrd.conf.defaults import DEFAULT_SOURCE_IMPORT_TIMER, DEFAULT_SOURCE_EXPORT_TIMER +from setproctitle import setproctitle + +from irrd.conf import RPKI_IRR_PSEUDO_SOURCE, get_setting +from irrd.conf.defaults import DEFAULT_SOURCE_EXPORT_TIMER, DEFAULT_SOURCE_IMPORT_TIMER + from .mirror_runners_export import SourceExportRunner -from .mirror_runners_import import RPSLMirrorImportUpdateRunner, ROAImportRunner, \ - ScopeFilterUpdateRunner, RoutePreferenceUpdateRunner +from .mirror_runners_import import ( + ROAImportRunner, + RoutePreferenceUpdateRunner, + RPSLMirrorImportUpdateRunner, + ScopeFilterUpdateRunner, +) logger = logging.getLogger(__name__) @@ -30,12 +34,14 @@ def close(self): # pragma: no cover close() is not available in Python 3.6, use our own implementation if needed. """ - if hasattr(super, 'close'): + if hasattr(super, "close"): return super().close() if self._popen is not None: if self._popen.poll() is None: - raise ValueError("Cannot close a process while it is still running. " - "You should first call join() or terminate().") + raise ValueError( + "Cannot close a process while it is still running. " + "You should first call join() or terminate()." + ) self._popen = None del self._sentinel self._closed = True @@ -45,7 +51,7 @@ def run(self): # (signal handlers are inherited) signal.signal(signal.SIGTERM, signal.SIG_DFL) - setproctitle(f'irrd-{self.name}') + setproctitle(f"irrd-{self.name}") self.runner.run() @@ -57,6 +63,7 @@ class MirrorScheduler: unless a process is still running for that database (which is likely to be the case in some full imports). 
""" + processes: Dict[str, ScheduledTaskProcess] last_started_time: Dict[str, int] @@ -69,38 +76,44 @@ def __init__(self, *args, **kwargs): self.previous_scopefilter_excluded = None def run(self) -> None: - if get_setting('database_readonly'): + if get_setting("database_readonly"): return - if get_setting('rpki.roa_source'): - import_timer = int(get_setting('rpki.roa_import_timer')) + if get_setting("rpki.roa_source"): + import_timer = int(get_setting("rpki.roa_import_timer")) self.run_if_relevant(RPKI_IRR_PSEUDO_SOURCE, ROAImportRunner, import_timer) - if get_setting("sources") and any([ - source_settings.get("route_object_preference") - for source_settings in get_setting("sources").values() - ]): - import_timer = int(get_setting('route_object_preference.update_timer')) - self.run_if_relevant('routepref', RoutePreferenceUpdateRunner, import_timer) + if get_setting("sources") and any( + [ + source_settings.get("route_object_preference") + for source_settings in get_setting("sources").values() + ] + ): + import_timer = int(get_setting("route_object_preference.update_timer")) + self.run_if_relevant("routepref", RoutePreferenceUpdateRunner, import_timer) if self._check_scopefilter_change(): - self.run_if_relevant('scopefilter', ScopeFilterUpdateRunner, 0) + self.run_if_relevant("scopefilter", ScopeFilterUpdateRunner, 0) sources_started = 0 - for source in get_setting('sources', {}).keys(): + for source in get_setting("sources", {}).keys(): if sources_started >= MAX_SIMULTANEOUS_RUNS: break started_import = False started_export = False - is_mirror = get_setting(f'sources.{source}.import_source') or get_setting(f'sources.{source}.nrtm_host') - import_timer = int(get_setting(f'sources.{source}.import_timer', DEFAULT_SOURCE_IMPORT_TIMER)) + is_mirror = get_setting(f"sources.{source}.import_source") or get_setting( + f"sources.{source}.nrtm_host" + ) + import_timer = int(get_setting(f"sources.{source}.import_timer", DEFAULT_SOURCE_IMPORT_TIMER)) if is_mirror: started_import = self.run_if_relevant(source, RPSLMirrorImportUpdateRunner, import_timer) - runs_export = get_setting(f'sources.{source}.export_destination') or get_setting(f'sources.{source}.export_destination_unfiltered') - export_timer = int(get_setting(f'sources.{source}.export_timer', DEFAULT_SOURCE_EXPORT_TIMER)) + runs_export = get_setting(f"sources.{source}.export_destination") or get_setting( + f"sources.{source}.export_destination_unfiltered" + ) + export_timer = int(get_setting(f"sources.{source}.export_timer", DEFAULT_SOURCE_EXPORT_TIMER)) if runs_export: started_export = self.run_if_relevant(source, SourceExportRunner, export_timer) @@ -113,22 +126,24 @@ def _check_scopefilter_change(self) -> bool: Check whether the scope filter has changed since last call. Always returns True on the first call. 
""" - if not get_setting('scopefilter'): + if not get_setting("scopefilter"): return False - current_prefixes = list(get_setting('scopefilter.prefixes', [])) - current_asns = list(get_setting('scopefilter.asns', [])) + current_prefixes = list(get_setting("scopefilter.prefixes", [])) + current_asns = list(get_setting("scopefilter.asns", [])) current_exclusions = { name - for name, settings in get_setting('sources', {}).items() - if settings.get('scopefilter_excluded') + for name, settings in get_setting("sources", {}).items() + if settings.get("scopefilter_excluded") } - if any([ - self.previous_scopefilter_prefixes != current_prefixes, - self.previous_scopefilter_asns != current_asns, - self.previous_scopefilter_excluded != current_exclusions, - ]): + if any( + [ + self.previous_scopefilter_prefixes != current_prefixes, + self.previous_scopefilter_asns != current_asns, + self.previous_scopefilter_excluded != current_exclusions, + ] + ): self.previous_scopefilter_prefixes = current_prefixes self.previous_scopefilter_asns = current_asns self.previous_scopefilter_excluded = current_exclusions @@ -142,7 +157,7 @@ def run_if_relevant(self, source: str, runner_class, timer: int) -> bool: if not has_expired or process_name in self.processes: return False - logger.debug(f'Started new process {process_name} for mirror import/export for {source}') + logger.debug(f"Started new process {process_name} for mirror import/export for {source}") initiator = runner_class(source=source) process = ScheduledTaskProcess(runner=initiator, name=process_name) self.processes[process_name] = process @@ -151,7 +166,7 @@ def run_if_relevant(self, source: str, runner_class, timer: int) -> bool: return True def terminate_children(self) -> None: # pragma: no cover - logger.info('MirrorScheduler terminating children') + logger.info("MirrorScheduler terminating children") for process in self.processes.values(): try: process.terminate() @@ -168,8 +183,9 @@ def update_process_state(self): try: process.close() except Exception as e: # pragma: no cover - logging.error(f'Failed to close {process_name} (pid {process.pid}), ' - f'possible resource leak: {e}') + logging.error( + f"Failed to close {process_name} (pid {process.pid}), possible resource leak: {e}" + ) del self.processes[process_name] gc_collect_needed = True if gc_collect_needed: diff --git a/irrd/mirroring/tests/test_mirror_runners_export.py b/irrd/mirroring/tests/test_mirror_runners_export.py index 2dfbabd8c..34dd250d5 100644 --- a/irrd/mirroring/tests/test_mirror_runners_export.py +++ b/irrd/mirroring/tests/test_mirror_runners_export.py @@ -4,181 +4,195 @@ from pathlib import Path from unittest.mock import Mock -from irrd.rpki.status import RPKIStatus from irrd.routepref.status import RoutePreferenceStatus +from irrd.rpki.status import RPKIStatus from irrd.scopefilter.status import ScopeFilterStatus from irrd.utils.test_utils import flatten_mock_calls -from ..mirror_runners_export import SourceExportRunner, EXPORT_PERMISSIONS + +from ..mirror_runners_export import EXPORT_PERMISSIONS, SourceExportRunner class TestSourceExportRunner: def test_export(self, tmpdir, config_override, monkeypatch, caplog): - config_override({ - 'sources': { - 'TEST': { - 'export_destination': str(tmpdir), + config_override( + { + "sources": { + "TEST": { + "export_destination": str(tmpdir), + } } } - }) + ) mock_dh = Mock() mock_dq = Mock() mock_dsq = Mock() - monkeypatch.setattr('irrd.mirroring.mirror_runners_export.DatabaseHandler', lambda: mock_dh) - 
monkeypatch.setattr('irrd.mirroring.mirror_runners_export.RPSLDatabaseQuery', lambda: mock_dq) - monkeypatch.setattr('irrd.mirroring.mirror_runners_export.DatabaseStatusQuery', lambda: mock_dsq) + monkeypatch.setattr("irrd.mirroring.mirror_runners_export.DatabaseHandler", lambda: mock_dh) + monkeypatch.setattr("irrd.mirroring.mirror_runners_export.RPSLDatabaseQuery", lambda: mock_dq) + monkeypatch.setattr("irrd.mirroring.mirror_runners_export.DatabaseStatusQuery", lambda: mock_dsq) - responses = cycle([ - repeat({'serial_newest_seen': '424242'}), + responses = cycle( [ - # The CRYPT-PW hash must not appear in the output - {'object_text': 'object 1 🦄\nauth: CRYPT-PW foobar\n'}, - {'object_text': 'object 2 🌈\n'}, - ], - ]) + repeat({"serial_newest_seen": "424242"}), + [ + # The CRYPT-PW hash must not appear in the output + {"object_text": "object 1 🦄\nauth: CRYPT-PW foobar\n"}, + {"object_text": "object 2 🌈\n"}, + ], + ] + ) mock_dh.execute_query = lambda q: next(responses) - runner = SourceExportRunner('TEST') + runner = SourceExportRunner("TEST") runner.run() runner.run() - serial_filename = tmpdir + '/TEST.CURRENTSERIAL' + serial_filename = tmpdir + "/TEST.CURRENTSERIAL" assert oct(os.lstat(serial_filename).st_mode)[-3:] == oct(EXPORT_PERMISSIONS)[-3:] with open(serial_filename) as fh: - assert fh.read() == '424242' + assert fh.read() == "424242" - export_filename = tmpdir + '/test.db.gz' + export_filename = tmpdir + "/test.db.gz" assert oct(os.lstat(export_filename).st_mode)[-3:] == oct(EXPORT_PERMISSIONS)[-3:] with gzip.open(export_filename) as fh: - assert fh.read().decode('utf-8') == 'object 1 🦄\nauth: CRYPT-PW DummyValue # Filtered for security\n\n' \ - 'object 2 🌈\n\n# EOF\n' + assert ( + fh.read().decode("utf-8") + == "object 1 🦄\nauth: CRYPT-PW DummyValue # Filtered for security\n\nobject 2 🌈\n\n# EOF\n" + ) assert flatten_mock_calls(mock_dh) == [ - ['record_serial_exported', ('TEST', '424242'), {}], - ['commit', (), {}], - ['close', (), {}], - ['record_serial_exported', ('TEST', '424242'), {}], - ['commit', (), {}], - ['close', (), {}] + ["record_serial_exported", ("TEST", "424242"), {}], + ["commit", (), {}], + ["close", (), {}], + ["record_serial_exported", ("TEST", "424242"), {}], + ["commit", (), {}], + ["close", (), {}], ] assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST'],), {}], - ['rpki_status', ([RPKIStatus.not_found, RPKIStatus.valid],), {}], - ['scopefilter_status', ([ScopeFilterStatus.in_scope],), {}], - ['route_preference_status', ([RoutePreferenceStatus.visible],), {}], - ['sources', (['TEST'],), {}], - ['rpki_status', ([RPKIStatus.not_found, RPKIStatus.valid],), {}], - ['scopefilter_status', ([ScopeFilterStatus.in_scope],), {}], - ['route_preference_status', ([RoutePreferenceStatus.visible],), {}], + ["sources", (["TEST"],), {}], + ["rpki_status", ([RPKIStatus.not_found, RPKIStatus.valid],), {}], + ["scopefilter_status", ([ScopeFilterStatus.in_scope],), {}], + ["route_preference_status", ([RoutePreferenceStatus.visible],), {}], + ["sources", (["TEST"],), {}], + ["rpki_status", ([RPKIStatus.not_found, RPKIStatus.valid],), {}], + ["scopefilter_status", ([ScopeFilterStatus.in_scope],), {}], + ["route_preference_status", ([RoutePreferenceStatus.visible],), {}], ] - assert 'Starting a source export for TEST' in caplog.text - assert 'Export for TEST complete' in caplog.text + assert "Starting a source export for TEST" in caplog.text + assert "Export for TEST complete" in caplog.text def test_export_unfiltered(self, tmpdir, config_override, monkeypatch, 
caplog): - config_override({ - 'sources': { - 'TEST': { - 'export_destination_unfiltered': str(tmpdir), + config_override( + { + "sources": { + "TEST": { + "export_destination_unfiltered": str(tmpdir), + } } } - }) + ) mock_dh = Mock() mock_dq = Mock() mock_dsq = Mock() - monkeypatch.setattr('irrd.mirroring.mirror_runners_export.DatabaseHandler', lambda: mock_dh) - monkeypatch.setattr('irrd.mirroring.mirror_runners_export.RPSLDatabaseQuery', lambda: mock_dq) - monkeypatch.setattr('irrd.mirroring.mirror_runners_export.DatabaseStatusQuery', lambda: mock_dsq) + monkeypatch.setattr("irrd.mirroring.mirror_runners_export.DatabaseHandler", lambda: mock_dh) + monkeypatch.setattr("irrd.mirroring.mirror_runners_export.RPSLDatabaseQuery", lambda: mock_dq) + monkeypatch.setattr("irrd.mirroring.mirror_runners_export.DatabaseStatusQuery", lambda: mock_dsq) - responses = cycle([ - repeat({'serial_newest_seen': '424242'}), + responses = cycle( [ - # The CRYPT-PW hash should appear in the output - {'object_text': 'object 1 🦄\nauth: CRYPT-PW foobar\n'}, - {'object_text': 'object 2 🌈\n'}, - ], - ]) + repeat({"serial_newest_seen": "424242"}), + [ + # The CRYPT-PW hash should appear in the output + {"object_text": "object 1 🦄\nauth: CRYPT-PW foobar\n"}, + {"object_text": "object 2 🌈\n"}, + ], + ] + ) mock_dh.execute_query = lambda q: next(responses) - runner = SourceExportRunner('TEST') + runner = SourceExportRunner("TEST") runner.run() - serial_filename = tmpdir + '/TEST.CURRENTSERIAL' + serial_filename = tmpdir + "/TEST.CURRENTSERIAL" assert oct(os.lstat(serial_filename).st_mode)[-3:] == oct(EXPORT_PERMISSIONS)[-3:] with open(serial_filename) as fh: - assert fh.read() == '424242' + assert fh.read() == "424242" - export_filename = tmpdir + '/test.db.gz' + export_filename = tmpdir + "/test.db.gz" assert oct(os.lstat(export_filename).st_mode)[-3:] == oct(EXPORT_PERMISSIONS)[-3:] with gzip.open(export_filename) as fh: - assert fh.read().decode('utf-8') == 'object 1 🦄\nauth: CRYPT-PW foobar\n\n' \ - 'object 2 🌈\n\n# EOF\n' + assert fh.read().decode("utf-8") == "object 1 🦄\nauth: CRYPT-PW foobar\n\nobject 2 🌈\n\n# EOF\n" def test_failure(self, tmpdir, config_override, monkeypatch, caplog): - config_override({ - 'sources': { - 'TEST': { - 'export_destination': str(tmpdir), + config_override( + { + "sources": { + "TEST": { + "export_destination": str(tmpdir), + } } } - }) + ) mock_dh = Mock() mock_dsq = Mock() - monkeypatch.setattr('irrd.mirroring.mirror_runners_export.DatabaseHandler', lambda: mock_dh) - monkeypatch.setattr('irrd.mirroring.mirror_runners_export.DatabaseStatusQuery', lambda: mock_dsq) - mock_dh.execute_query = Mock(side_effect=ValueError('expected-test-error')) + monkeypatch.setattr("irrd.mirroring.mirror_runners_export.DatabaseHandler", lambda: mock_dh) + monkeypatch.setattr("irrd.mirroring.mirror_runners_export.DatabaseStatusQuery", lambda: mock_dsq) + mock_dh.execute_query = Mock(side_effect=ValueError("expected-test-error")) - runner = SourceExportRunner('TEST') + runner = SourceExportRunner("TEST") runner.run() - assert 'An exception occurred while attempting to run an export for TEST' in caplog.text - assert 'expected-test-error' in caplog.text + assert "An exception occurred while attempting to run an export for TEST" in caplog.text + assert "expected-test-error" in caplog.text def test_export_no_serial(self, tmpdir, config_override, monkeypatch, caplog): - config_override({ - 'sources': { - 'TEST': { - 'export_destination': str(tmpdir), + config_override( + { + "sources": { + "TEST": { + 
"export_destination": str(tmpdir), + } } } - }) + ) mock_dh = Mock() mock_dq = Mock() mock_dsq = Mock() - monkeypatch.setattr('irrd.mirroring.mirror_runners_export.DatabaseHandler', - lambda: mock_dh) - monkeypatch.setattr('irrd.mirroring.mirror_runners_export.RPSLDatabaseQuery', - lambda: mock_dq) - monkeypatch.setattr('irrd.mirroring.mirror_runners_export.DatabaseStatusQuery', - lambda: mock_dsq) + monkeypatch.setattr("irrd.mirroring.mirror_runners_export.DatabaseHandler", lambda: mock_dh) + monkeypatch.setattr("irrd.mirroring.mirror_runners_export.RPSLDatabaseQuery", lambda: mock_dq) + monkeypatch.setattr("irrd.mirroring.mirror_runners_export.DatabaseStatusQuery", lambda: mock_dsq) - responses = cycle([ - iter([]), + responses = cycle( [ - # The CRYPT-PW hash must not appear in the output - {'object_text': 'object 1 🦄\nauth: CRYPT-PW foobar\n'}, - {'object_text': 'object 2 🌈\n'}, - ], - ]) + iter([]), + [ + # The CRYPT-PW hash must not appear in the output + {"object_text": "object 1 🦄\nauth: CRYPT-PW foobar\n"}, + {"object_text": "object 2 🌈\n"}, + ], + ] + ) mock_dh.execute_query = lambda q: next(responses) - runner = SourceExportRunner('TEST') + runner = SourceExportRunner("TEST") runner.run() runner.run() - serial_filename = Path(tmpdir + '/TEST.CURRENTSERIAL') + serial_filename = Path(tmpdir + "/TEST.CURRENTSERIAL") assert not serial_filename.exists() - export_filename = tmpdir + '/test.db.gz' + export_filename = tmpdir + "/test.db.gz" with gzip.open(export_filename) as fh: - assert fh.read().decode( - 'utf-8') == 'object 1 🦄\nauth: CRYPT-PW DummyValue # Filtered for security\n\n' \ - 'object 2 🌈\n\n# EOF\n' + assert ( + fh.read().decode("utf-8") + == "object 1 🦄\nauth: CRYPT-PW DummyValue # Filtered for security\n\nobject 2 🌈\n\n# EOF\n" + ) - assert 'Starting a source export for TEST' in caplog.text - assert 'Export for TEST complete' in caplog.text + assert "Starting a source export for TEST" in caplog.text + assert "Export for TEST complete" in caplog.text diff --git a/irrd/mirroring/tests/test_mirror_runners_import.py b/irrd/mirroring/tests/test_mirror_runners_import.py index 0c25d367b..ad8abc94e 100644 --- a/irrd/mirroring/tests/test_mirror_runners_import.py +++ b/irrd/mirroring/tests/test_mirror_runners_import.py @@ -6,15 +6,20 @@ import pytest +from irrd.routepref.routepref import update_route_preference_status from irrd.rpki.importer import ROAParserException from irrd.rpki.validators import BulkRouteROAValidator -from irrd.routepref.routepref import update_route_preference_status -from irrd.storage.database_handler import DatabaseHandler from irrd.scopefilter.validators import ScopeFilterValidator +from irrd.storage.database_handler import DatabaseHandler from irrd.utils.test_utils import flatten_mock_calls + from ..mirror_runners_import import ( - RPSLMirrorImportUpdateRunner, RPSLMirrorFullImportRunner, NRTMImportUpdateStreamRunner, - ROAImportRunner, ScopeFilterUpdateRunner, RoutePreferenceUpdateRunner + NRTMImportUpdateStreamRunner, + ROAImportRunner, + RoutePreferenceUpdateRunner, + RPSLMirrorFullImportRunner, + RPSLMirrorImportUpdateRunner, + ScopeFilterUpdateRunner, ) @@ -24,71 +29,84 @@ def test_full_import_call(self, monkeypatch): mock_dq = Mock() mock_full_import_runner = Mock() - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.DatabaseHandler', lambda: mock_dh) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.DatabaseStatusQuery', lambda: mock_dq) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.RPSLMirrorFullImportRunner', 
lambda source: mock_full_import_runner) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.DatabaseHandler", lambda: mock_dh) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.DatabaseStatusQuery", lambda: mock_dq) + monkeypatch.setattr( + "irrd.mirroring.mirror_runners_import.RPSLMirrorFullImportRunner", + lambda source: mock_full_import_runner, + ) mock_dh.execute_query = lambda q: iter([]) - runner = RPSLMirrorImportUpdateRunner(source='TEST') + runner = RPSLMirrorImportUpdateRunner(source="TEST") runner.run() - assert flatten_mock_calls(mock_dq) == [['source', ('TEST',), {}]] - assert flatten_mock_calls(mock_dh) == [['commit', (), {}], ['close', (), {}]] + assert flatten_mock_calls(mock_dq) == [["source", ("TEST",), {}]] + assert flatten_mock_calls(mock_dh) == [["commit", (), {}], ["close", (), {}]] assert len(mock_full_import_runner.mock_calls) == 1 - assert mock_full_import_runner.mock_calls[0][0] == 'run' + assert mock_full_import_runner.mock_calls[0][0] == "run" def test_force_reload(self, monkeypatch, config_override): - config_override({ - 'sources': { - 'TEST': { - 'nrtm_host': '192.0.2.1', + config_override( + { + "sources": { + "TEST": { + "nrtm_host": "192.0.2.1", + } } } - }) + ) mock_dh = Mock() mock_dq = Mock() mock_full_import_runner = Mock() - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.DatabaseHandler', lambda: mock_dh) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.DatabaseStatusQuery', lambda: mock_dq) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.RPSLMirrorFullImportRunner', lambda source: mock_full_import_runner) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.DatabaseHandler", lambda: mock_dh) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.DatabaseStatusQuery", lambda: mock_dq) + monkeypatch.setattr( + "irrd.mirroring.mirror_runners_import.RPSLMirrorFullImportRunner", + lambda source: mock_full_import_runner, + ) - mock_dh.execute_query = lambda q: iter([{'serial_newest_mirror': 424242, 'force_reload': True}]) - runner = RPSLMirrorImportUpdateRunner(source='TEST') + mock_dh.execute_query = lambda q: iter([{"serial_newest_mirror": 424242, "force_reload": True}]) + runner = RPSLMirrorImportUpdateRunner(source="TEST") runner.run() - assert flatten_mock_calls(mock_dq) == [['source', ('TEST',), {}]] - assert flatten_mock_calls(mock_dh) == [['commit', (), {}], ['close', (), {}]] + assert flatten_mock_calls(mock_dq) == [["source", ("TEST",), {}]] + assert flatten_mock_calls(mock_dh) == [["commit", (), {}], ["close", (), {}]] assert len(mock_full_import_runner.mock_calls) == 1 - assert mock_full_import_runner.mock_calls[0][0] == 'run' + assert mock_full_import_runner.mock_calls[0][0] == "run" def test_update_stream_call(self, monkeypatch, config_override): - config_override({ - 'sources': { - 'TEST': { - 'nrtm_host': '192.0.2.1', + config_override( + { + "sources": { + "TEST": { + "nrtm_host": "192.0.2.1", + } } } - }) + ) mock_dh = Mock() mock_dq = Mock() mock_stream_runner = Mock() - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.DatabaseHandler', lambda: mock_dh) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.DatabaseStatusQuery', lambda: mock_dq) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.NRTMImportUpdateStreamRunner', lambda source: mock_stream_runner) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.DatabaseHandler", lambda: mock_dh) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.DatabaseStatusQuery", lambda: mock_dq) + 
monkeypatch.setattr( + "irrd.mirroring.mirror_runners_import.NRTMImportUpdateStreamRunner", + lambda source: mock_stream_runner, + ) - mock_dh.execute_query = lambda q: iter([{'serial_newest_mirror': 424242, 'force_reload': False}]) - runner = RPSLMirrorImportUpdateRunner(source='TEST') + mock_dh.execute_query = lambda q: iter([{"serial_newest_mirror": 424242, "force_reload": False}]) + runner = RPSLMirrorImportUpdateRunner(source="TEST") runner.run() - assert flatten_mock_calls(mock_dq) == [['source', ('TEST',), {}]] - assert flatten_mock_calls(mock_dh) == [['commit', (), {}], ['close', (), {}]] + assert flatten_mock_calls(mock_dq) == [["source", ("TEST",), {}]] + assert flatten_mock_calls(mock_dh) == [["commit", (), {}], ["close", (), {}]] assert len(mock_stream_runner.mock_calls) == 1 - assert mock_stream_runner.mock_calls[0][0] == 'run' + assert mock_stream_runner.mock_calls[0][0] == "run" assert mock_stream_runner.mock_calls[0][1] == (424242,) def test_io_exception_handling(self, monkeypatch, caplog): @@ -96,270 +114,315 @@ def test_io_exception_handling(self, monkeypatch, caplog): mock_dq = Mock() mock_full_import_runner = Mock() - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.DatabaseHandler', lambda: mock_dh) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.DatabaseStatusQuery', lambda: mock_dq) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.RPSLMirrorFullImportRunner', lambda source: mock_full_import_runner) - mock_full_import_runner.run = Mock(side_effect=ConnectionResetError('test-error')) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.DatabaseHandler", lambda: mock_dh) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.DatabaseStatusQuery", lambda: mock_dq) + monkeypatch.setattr( + "irrd.mirroring.mirror_runners_import.RPSLMirrorFullImportRunner", + lambda source: mock_full_import_runner, + ) + mock_full_import_runner.run = Mock(side_effect=ConnectionResetError("test-error")) - mock_dh.execute_query = lambda q: iter([{'serial_newest_mirror': 424242, 'force_reload': False}]) - runner = RPSLMirrorImportUpdateRunner(source='TEST') + mock_dh.execute_query = lambda q: iter([{"serial_newest_mirror": 424242, "force_reload": False}]) + runner = RPSLMirrorImportUpdateRunner(source="TEST") runner.run() - assert flatten_mock_calls(mock_dh) == [['close', (), {}]] - assert 'An error occurred while attempting a mirror update or initial import for TEST' in caplog.text - assert 'test-error' in caplog.text - assert 'Traceback' not in caplog.text + assert flatten_mock_calls(mock_dh) == [["close", (), {}]] + assert "An error occurred while attempting a mirror update or initial import for TEST" in caplog.text + assert "test-error" in caplog.text + assert "Traceback" not in caplog.text def test_unexpected_exception_handling(self, monkeypatch, caplog): mock_dh = Mock() mock_dq = Mock() mock_full_import_runner = Mock() - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.DatabaseHandler', lambda: mock_dh) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.DatabaseStatusQuery', lambda: mock_dq) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.RPSLMirrorFullImportRunner', lambda source: mock_full_import_runner) - mock_full_import_runner.run = Mock(side_effect=Exception('test-error')) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.DatabaseHandler", lambda: mock_dh) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.DatabaseStatusQuery", lambda: mock_dq) + monkeypatch.setattr( + 
"irrd.mirroring.mirror_runners_import.RPSLMirrorFullImportRunner", + lambda source: mock_full_import_runner, + ) + mock_full_import_runner.run = Mock(side_effect=Exception("test-error")) - mock_dh.execute_query = lambda q: iter([{'serial_newest_mirror': 424242, 'force_reload': False}]) - runner = RPSLMirrorImportUpdateRunner(source='TEST') + mock_dh.execute_query = lambda q: iter([{"serial_newest_mirror": 424242, "force_reload": False}]) + runner = RPSLMirrorImportUpdateRunner(source="TEST") runner.run() - assert flatten_mock_calls(mock_dh) == [['close', (), {}]] - assert 'An exception occurred while attempting a mirror update or initial import for TEST' in caplog.text - assert 'test-error' in caplog.text - assert 'Traceback' in caplog.text + assert flatten_mock_calls(mock_dh) == [["close", (), {}]] + assert ( + "An exception occurred while attempting a mirror update or initial import for TEST" in caplog.text + ) + assert "test-error" in caplog.text + assert "Traceback" in caplog.text class TestRPSLMirrorFullImportRunner: def test_run_import_ftp(self, monkeypatch, config_override): - config_override({ - 'rpki': {'roa_source': 'https://example.com/roa.json'}, - 'sources': { - 'TEST': { - 'import_source': ['ftp://host/source1.gz', 'ftp://host/source2'], - 'import_serial_source': 'ftp://host/serial', - } + config_override( + { + "rpki": {"roa_source": "https://example.com/roa.json"}, + "sources": { + "TEST": { + "import_source": ["ftp://host/source1.gz", "ftp://host/source2"], + "import_serial_source": "ftp://host/serial", + } + }, } - }) + ) mock_dh = Mock() request = Mock() MockMirrorFileImportParser.rpsl_data_calls = [] - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.MirrorFileImportParser', MockMirrorFileImportParser) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.request', request) + monkeypatch.setattr( + "irrd.mirroring.mirror_runners_import.MirrorFileImportParser", MockMirrorFileImportParser + ) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.request", request) mock_bulk_validator_init = Mock() - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.BulkRouteROAValidator', mock_bulk_validator_init) + monkeypatch.setattr( + "irrd.mirroring.mirror_runners_import.BulkRouteROAValidator", mock_bulk_validator_init + ) responses = { # gzipped data, contains 'source1' - 'ftp://host/source1.gz': b64decode('H4sIAE4CfFsAAyvOLy1KTjUEAE5Fj0oHAAAA'), - 'ftp://host/source2': b'source2', - 'ftp://host/serial': b'424242', + "ftp://host/source1.gz": b64decode("H4sIAE4CfFsAAyvOLy1KTjUEAE5Fj0oHAAAA"), + "ftp://host/source2": b"source2", + "ftp://host/serial": b"424242", } request.urlopen = lambda url: MockUrlopenResponse(responses[url]) - RPSLMirrorFullImportRunner('TEST').run(mock_dh, serial_newest_mirror=424241) + RPSLMirrorFullImportRunner("TEST").run(mock_dh, serial_newest_mirror=424241) - assert MockMirrorFileImportParser.rpsl_data_calls == ['source1', 'source2'] + assert MockMirrorFileImportParser.rpsl_data_calls == ["source1", "source2"] assert flatten_mock_calls(mock_dh) == [ - ['delete_all_rpsl_objects_with_journal', ('TEST',), {}], - ['disable_journaling', (), {}], - ['record_serial_newest_mirror', ('TEST', 424242), {}], + ["delete_all_rpsl_objects_with_journal", ("TEST",), {}], + ["disable_journaling", (), {}], + ["record_serial_newest_mirror", ("TEST", 424242), {}], ] assert mock_bulk_validator_init.mock_calls[0][1][0] == mock_dh def test_failed_import_ftp(self, monkeypatch, config_override): - config_override({ - 'rpki': {'roa_source': 
'https://example.com/roa.json'}, - 'sources': { - 'TEST': { - 'import_source': 'ftp://host/source1.gz', - } + config_override( + { + "rpki": {"roa_source": "https://example.com/roa.json"}, + "sources": { + "TEST": { + "import_source": "ftp://host/source1.gz", + } + }, } - }) + ) mock_dh = Mock() request = Mock() MockMirrorFileImportParser.rpsl_data_calls = [] - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.MirrorFileImportParser', MockMirrorFileImportParser) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.request', request) + monkeypatch.setattr( + "irrd.mirroring.mirror_runners_import.MirrorFileImportParser", MockMirrorFileImportParser + ) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.request", request) mock_bulk_validator_init = Mock() - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.BulkRouteROAValidator', mock_bulk_validator_init) + monkeypatch.setattr( + "irrd.mirroring.mirror_runners_import.BulkRouteROAValidator", mock_bulk_validator_init + ) - request.urlopen = lambda url: MockUrlopenResponse(b'', fail=True) + request.urlopen = lambda url: MockUrlopenResponse(b"", fail=True) with pytest.raises(IOError): - RPSLMirrorFullImportRunner('TEST').run(mock_dh, serial_newest_mirror=424241) + RPSLMirrorFullImportRunner("TEST").run(mock_dh, serial_newest_mirror=424241) def test_run_import_local_file(self, monkeypatch, config_override, tmpdir): - tmp_import_source1 = tmpdir + '/source1.rpsl.gz' - with open(tmp_import_source1, 'wb') as fh: + tmp_import_source1 = tmpdir + "/source1.rpsl.gz" + with open(tmp_import_source1, "wb") as fh: # gzipped data, contains 'source1' - fh.write(b64decode('H4sIAE4CfFsAAyvOLy1KTjUEAE5Fj0oHAAAA')) - tmp_import_source2 = tmpdir + '/source2.rpsl' - with open(tmp_import_source2, 'w') as fh: - fh.write('source2') - tmp_import_serial = tmpdir + '/serial' - with open(tmp_import_serial, 'w') as fh: - fh.write('424242') - - config_override({ - 'rpki': {'roa_source': None}, - 'sources': { - 'TEST': { - 'import_source': ['file://' + str(tmp_import_source1), 'file://' + str(tmp_import_source2)], - 'import_serial_source': 'file://' + str(tmp_import_serial), - } + fh.write(b64decode("H4sIAE4CfFsAAyvOLy1KTjUEAE5Fj0oHAAAA")) + tmp_import_source2 = tmpdir + "/source2.rpsl" + with open(tmp_import_source2, "w") as fh: + fh.write("source2") + tmp_import_serial = tmpdir + "/serial" + with open(tmp_import_serial, "w") as fh: + fh.write("424242") + + config_override( + { + "rpki": {"roa_source": None}, + "sources": { + "TEST": { + "import_source": [ + "file://" + str(tmp_import_source1), + "file://" + str(tmp_import_source2), + ], + "import_serial_source": "file://" + str(tmp_import_serial), + } + }, } - }) + ) mock_dh = Mock() MockMirrorFileImportParser.rpsl_data_calls = [] - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.MirrorFileImportParser', MockMirrorFileImportParser) + monkeypatch.setattr( + "irrd.mirroring.mirror_runners_import.MirrorFileImportParser", MockMirrorFileImportParser + ) - RPSLMirrorFullImportRunner('TEST').run(mock_dh) + RPSLMirrorFullImportRunner("TEST").run(mock_dh) - assert MockMirrorFileImportParser.rpsl_data_calls == ['source1', 'source2'] + assert MockMirrorFileImportParser.rpsl_data_calls == ["source1", "source2"] assert flatten_mock_calls(mock_dh) == [ - ['delete_all_rpsl_objects_with_journal', ('TEST',), {}], - ['disable_journaling', (), {}], - ['record_serial_newest_mirror', ('TEST', 424242), {}], + ["delete_all_rpsl_objects_with_journal", ("TEST",), {}], + ["disable_journaling", (), {}], + 
["record_serial_newest_mirror", ("TEST", 424242), {}], ] def test_no_serial_ftp(self, monkeypatch, config_override): - config_override({ - 'rpki': {'roa_source': None}, - 'sources': { - 'TEST': { - 'import_source': ['ftp://host/source1.gz', 'ftp://host/source2'], - } + config_override( + { + "rpki": {"roa_source": None}, + "sources": { + "TEST": { + "import_source": ["ftp://host/source1.gz", "ftp://host/source2"], + } + }, } - }) + ) mock_dh = Mock() request = Mock() MockMirrorFileImportParser.rpsl_data_calls = [] - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.MirrorFileImportParser', MockMirrorFileImportParser) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.request', request) + monkeypatch.setattr( + "irrd.mirroring.mirror_runners_import.MirrorFileImportParser", MockMirrorFileImportParser + ) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.request", request) responses = { # gzipped data, contains 'source1' - 'ftp://host/source1.gz': b64decode('H4sIAE4CfFsAAyvOLy1KTjUEAE5Fj0oHAAAA'), - 'ftp://host/source2': b'source2', + "ftp://host/source1.gz": b64decode("H4sIAE4CfFsAAyvOLy1KTjUEAE5Fj0oHAAAA"), + "ftp://host/source2": b"source2", } request.urlopen = lambda url: MockUrlopenResponse(responses[url]) - RPSLMirrorFullImportRunner('TEST').run(mock_dh, serial_newest_mirror=42) + RPSLMirrorFullImportRunner("TEST").run(mock_dh, serial_newest_mirror=42) - assert MockMirrorFileImportParser.rpsl_data_calls == ['source1', 'source2'] + assert MockMirrorFileImportParser.rpsl_data_calls == ["source1", "source2"] assert flatten_mock_calls(mock_dh) == [ - ['delete_all_rpsl_objects_with_journal', ('TEST',), {}], - ['disable_journaling', (), {}], + ["delete_all_rpsl_objects_with_journal", ("TEST",), {}], + ["disable_journaling", (), {}], ] def test_import_cancelled_serial_too_old(self, monkeypatch, config_override, caplog): - config_override({ - 'sources': { - 'TEST': { - 'import_source': ['ftp://host/source1.gz', 'ftp://host/source2'], - 'import_serial_source': 'ftp://host/serial', + config_override( + { + "sources": { + "TEST": { + "import_source": ["ftp://host/source1.gz", "ftp://host/source2"], + "import_serial_source": "ftp://host/serial", + } } } - }) + ) mock_dh = Mock() request = Mock() MockMirrorFileImportParser.rpsl_data_calls = [] - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.MirrorFileImportParser', MockMirrorFileImportParser) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.request', request) + monkeypatch.setattr( + "irrd.mirroring.mirror_runners_import.MirrorFileImportParser", MockMirrorFileImportParser + ) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.request", request) responses = { # gzipped data, contains 'source1' - 'ftp://host/source1.gz': b64decode('H4sIAE4CfFsAAyvOLy1KTjUEAE5Fj0oHAAAA'), - 'ftp://host/source2': b'source2', - 'ftp://host/serial': b'424242', + "ftp://host/source1.gz": b64decode("H4sIAE4CfFsAAyvOLy1KTjUEAE5Fj0oHAAAA"), + "ftp://host/source2": b"source2", + "ftp://host/serial": b"424242", } request.urlopen = lambda url: MockUrlopenResponse(responses[url]) - RPSLMirrorFullImportRunner('TEST').run(mock_dh, serial_newest_mirror=424243) + RPSLMirrorFullImportRunner("TEST").run(mock_dh, serial_newest_mirror=424243) assert not MockMirrorFileImportParser.rpsl_data_calls assert flatten_mock_calls(mock_dh) == [] - assert 'Current newest serial seen for TEST is 424243, import_serial is 424242, cancelling import.' 
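        # Editorial sketch (assumption, not part of the patch): the guard this
        # test exercises in RPSLMirrorFullImportRunner.run() behaves roughly like:
        #
        #     if import_serial <= serial_newest_mirror and not force_reload:
        #         logger.info(
        #             f"Current newest serial seen for {self.source} is "
        #             f"{serial_newest_mirror}, import_serial is {import_serial}, "
        #             "cancelling import."
        #         )
        #         return
        #
        # Here the local database has already seen serial 424243 while the mirror
        # publishes 424242, so no import runs; the following test shows that
        # force_reload=True overrides this guard.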
+ assert "Current newest serial seen for TEST is 424243, import_serial is 424242, cancelling import." def test_import_force_reload_with_serial_too_old(self, monkeypatch, config_override): - config_override({ - 'rpki': {'roa_source': None}, - 'sources': { - 'TEST': { - 'import_source': ['ftp://host/source1.gz', 'ftp://host/source2'], - 'import_serial_source': 'ftp://host/serial', - } + config_override( + { + "rpki": {"roa_source": None}, + "sources": { + "TEST": { + "import_source": ["ftp://host/source1.gz", "ftp://host/source2"], + "import_serial_source": "ftp://host/serial", + } + }, } - }) + ) mock_dh = Mock() request = Mock() MockMirrorFileImportParser.rpsl_data_calls = [] - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.MirrorFileImportParser', MockMirrorFileImportParser) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.request', request) + monkeypatch.setattr( + "irrd.mirroring.mirror_runners_import.MirrorFileImportParser", MockMirrorFileImportParser + ) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.request", request) responses = { # gzipped data, contains 'source1' - 'ftp://host/source1.gz': b64decode('H4sIAE4CfFsAAyvOLy1KTjUEAE5Fj0oHAAAA'), - 'ftp://host/source2': b'source2', - 'ftp://host/serial': b'424242', + "ftp://host/source1.gz": b64decode("H4sIAE4CfFsAAyvOLy1KTjUEAE5Fj0oHAAAA"), + "ftp://host/source2": b"source2", + "ftp://host/serial": b"424242", } request.urlopen = lambda url: MockUrlopenResponse(responses[url]) - RPSLMirrorFullImportRunner('TEST').run(mock_dh, serial_newest_mirror=424243, force_reload=True) + RPSLMirrorFullImportRunner("TEST").run(mock_dh, serial_newest_mirror=424243, force_reload=True) - assert MockMirrorFileImportParser.rpsl_data_calls == ['source1', 'source2'] + assert MockMirrorFileImportParser.rpsl_data_calls == ["source1", "source2"] assert flatten_mock_calls(mock_dh) == [ - ['delete_all_rpsl_objects_with_journal', ('TEST',), {}], - ['disable_journaling', (), {}], - ['record_serial_newest_mirror', ('TEST', 424242), {}], + ["delete_all_rpsl_objects_with_journal", ("TEST",), {}], + ["disable_journaling", (), {}], + ["record_serial_newest_mirror", ("TEST", 424242), {}], ] def test_missing_source_settings_ftp(self, config_override): - config_override({ - 'sources': { - 'TEST': { - 'import_serial_source': 'ftp://host/serial', + config_override( + { + "sources": { + "TEST": { + "import_serial_source": "ftp://host/serial", + } } } - }) + ) mock_dh = Mock() - RPSLMirrorFullImportRunner('TEST').run(mock_dh) + RPSLMirrorFullImportRunner("TEST").run(mock_dh) assert not flatten_mock_calls(mock_dh) def test_unsupported_protocol(self, config_override): - config_override({ - 'sources': { - 'TEST': { - 'import_source': 'ftp://host/source1.gz', - 'import_serial_source': 'gopher://host/serial', + config_override( + { + "sources": { + "TEST": { + "import_source": "ftp://host/source1.gz", + "import_serial_source": "gopher://host/serial", + } } } - }) + ) mock_dh = Mock() with pytest.raises(ValueError) as ve: - RPSLMirrorFullImportRunner('TEST').run(mock_dh) - assert 'scheme gopher is not supported' in str(ve.value) + RPSLMirrorFullImportRunner("TEST").run(mock_dh) + assert "scheme gopher is not supported" in str(ve.value) class MockUrlopenResponse(BytesIO): - def __init__(self, bytes: bytes, fail: bool=False): + def __init__(self, bytes: bytes, fail: bool = False): if fail: - raise URLError('error') + raise URLError("error") super().__init__(bytes) class MockMirrorFileImportParser: rpsl_data_calls: List[str] = [] - def __init__(self, 
source, filename, serial, database_handler, direct_error_return=False, roa_validator=None): + def __init__( + self, source, filename, serial, database_handler, direct_error_return=False, roa_validator=None + ): self.filename = filename - assert source == 'TEST' + assert source == "TEST" assert serial is None def run_import(self): @@ -372,248 +435,282 @@ class TestROAImportRunner: # is shared between ROAImportRunner and RPSLMirrorFullImportRunner, # not all protocols are tested here. def test_run_import_http_file_success(self, monkeypatch, config_override, tmpdir, caplog): - slurm_path = str(tmpdir) + '/slurm.json' - config_override({ - 'rpki': { - 'roa_source': 'https://host/roa.json', - 'slurm_source': 'file://' + slurm_path - } - }) + slurm_path = str(tmpdir) + "/slurm.json" + config_override( + {"rpki": {"roa_source": "https://host/roa.json", "slurm_source": "file://" + slurm_path}} + ) class MockRequestsSuccess: status_code = 200 def __init__(self, url, stream, timeout): - assert url == 'https://host/roa.json' + assert url == "https://host/roa.json" assert stream assert timeout def iter_content(self, size): - return iter([b'roa_', b'data']) + return iter([b"roa_", b"data"]) - with open(slurm_path, 'wb') as fh: - fh.write(b'slurm_data') + with open(slurm_path, "wb") as fh: + fh.write(b"slurm_data") mock_dh = Mock(spec=DatabaseHandler) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.DatabaseHandler', lambda: mock_dh) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.ROADataImporter', MockROADataImporter) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.DatabaseHandler", lambda: mock_dh) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.ROADataImporter", MockROADataImporter) mock_bulk_validator = Mock(spec=BulkRouteROAValidator) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.BulkRouteROAValidator', lambda dh, roas: mock_bulk_validator) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.requests.get', MockRequestsSuccess) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.notify_rpki_invalid_owners', lambda dh, invalids: 1) + monkeypatch.setattr( + "irrd.mirroring.mirror_runners_import.BulkRouteROAValidator", lambda dh, roas: mock_bulk_validator + ) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.requests.get", MockRequestsSuccess) + monkeypatch.setattr( + "irrd.mirroring.mirror_runners_import.notify_rpki_invalid_owners", lambda dh, invalids: 1 + ) mock_bulk_validator.validate_all_routes = lambda: ( - [{'rpsl_pk': 'pk_now_valid1'}, {'rpsl_pk': 'pk_now_valid2'}], - [{'rpsl_pk': 'pk_now_invalid1'}, {'rpsl_pk': 'pk_now_invalid2'}], - [{'rpsl_pk': 'pk_now_unknown1'}, {'rpsl_pk': 'pk_now_unknown2'}], + [{"rpsl_pk": "pk_now_valid1"}, {"rpsl_pk": "pk_now_valid2"}], + [{"rpsl_pk": "pk_now_invalid1"}, {"rpsl_pk": "pk_now_invalid2"}], + [{"rpsl_pk": "pk_now_unknown1"}, {"rpsl_pk": "pk_now_unknown2"}], ) ROAImportRunner().run() assert flatten_mock_calls(mock_dh) == [ - ['disable_journaling', (), {}], - ['delete_all_roa_objects', (), {}], - ['delete_all_rpsl_objects_with_journal', ('RPKI',), {'journal_guaranteed_empty': True}], - ['commit', (), {}], - ['enable_journaling', (), {}], - ['update_rpki_status', (), { - 'rpsl_objs_now_valid': [{'rpsl_pk': 'pk_now_valid1'}, {'rpsl_pk': 'pk_now_valid2'}], - 'rpsl_objs_now_invalid': [{'rpsl_pk': 'pk_now_invalid1'}, {'rpsl_pk': 'pk_now_invalid2'}], - 'rpsl_objs_now_not_found': [{'rpsl_pk': 'pk_now_unknown1'}, {'rpsl_pk': 'pk_now_unknown2'}], - }], - ['commit', (), {}], - 
['close', (), {}] + ["disable_journaling", (), {}], + ["delete_all_roa_objects", (), {}], + ["delete_all_rpsl_objects_with_journal", ("RPKI",), {"journal_guaranteed_empty": True}], + ["commit", (), {}], + ["enable_journaling", (), {}], + [ + "update_rpki_status", + (), + { + "rpsl_objs_now_valid": [{"rpsl_pk": "pk_now_valid1"}, {"rpsl_pk": "pk_now_valid2"}], + "rpsl_objs_now_invalid": [{"rpsl_pk": "pk_now_invalid1"}, {"rpsl_pk": "pk_now_invalid2"}], + "rpsl_objs_now_not_found": [ + {"rpsl_pk": "pk_now_unknown1"}, + {"rpsl_pk": "pk_now_unknown2"}, + ], + }, + ], + ["commit", (), {}], + ["close", (), {}], ] - assert '2 newly valid, 2 newly invalid, 2 newly not_found routes, 1 emails sent to contacts of newly invalid authoritative objects' in caplog.text + assert ( + "2 newly valid, 2 newly invalid, 2 newly not_found routes, 1 emails sent to contacts of newly" + " invalid authoritative objects" + in caplog.text + ) def test_run_import_http_file_failed_download(self, monkeypatch, config_override, tmpdir, caplog): - config_override({ - 'rpki': { - 'roa_source': 'https://host/roa.json', + config_override( + { + "rpki": { + "roa_source": "https://host/roa.json", + } } - }) + ) class MockRequestsSuccess: status_code = 500 - content = 'expected-test-error' + content = "expected-test-error" def __init__(self, url, stream, timeout): - assert url == 'https://host/roa.json' + assert url == "https://host/roa.json" assert stream assert timeout mock_dh = Mock(spec=DatabaseHandler) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.DatabaseHandler', lambda: mock_dh) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.requests.get', MockRequestsSuccess) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.DatabaseHandler", lambda: mock_dh) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.requests.get", MockRequestsSuccess) ROAImportRunner().run() - assert 'Failed to download https://host/roa.json: 500: expected-test-error' in caplog.text + assert "Failed to download https://host/roa.json: 500: expected-test-error" in caplog.text def test_exception_handling(self, monkeypatch, config_override, tmpdir, caplog): - tmp_roa_source = tmpdir + '/roa.json' - with open(tmp_roa_source, 'wb') as fh: - fh.write(b'roa_data') - config_override({ - 'rpki': { - 'roa_source': 'file://' + str(tmp_roa_source), + tmp_roa_source = tmpdir + "/roa.json" + with open(tmp_roa_source, "wb") as fh: + fh.write(b"roa_data") + config_override( + { + "rpki": { + "roa_source": "file://" + str(tmp_roa_source), + } } - }) + ) mock_dh = Mock(spec=DatabaseHandler) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.DatabaseHandler', lambda: mock_dh) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.DatabaseHandler", lambda: mock_dh) - mock_importer = Mock(side_effect=ValueError('expected-test-error-1')) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.ROADataImporter', mock_importer) + mock_importer = Mock(side_effect=ValueError("expected-test-error-1")) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.ROADataImporter", mock_importer) ROAImportRunner().run() - mock_importer = Mock(side_effect=ROAParserException('expected-test-error-2')) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.ROADataImporter', mock_importer) + mock_importer = Mock(side_effect=ROAParserException("expected-test-error-2")) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.ROADataImporter", mock_importer) ROAImportRunner().run() assert flatten_mock_calls(mock_dh) == 2 * [ - 
['disable_journaling', (), {}], - ['delete_all_roa_objects', (), {}], - ['delete_all_rpsl_objects_with_journal', ('RPKI',), {'journal_guaranteed_empty': True}], - ['close', (), {}] + ["disable_journaling", (), {}], + ["delete_all_roa_objects", (), {}], + ["delete_all_rpsl_objects_with_journal", ("RPKI",), {"journal_guaranteed_empty": True}], + ["close", (), {}], ] - assert 'expected-test-error-1' in caplog.text - assert 'expected-test-error-2' in caplog.text + assert "expected-test-error-1" in caplog.text + assert "expected-test-error-2" in caplog.text def test_file_error_handling(self, monkeypatch, config_override, tmpdir, caplog): - tmp_roa_source = tmpdir + '/roa.json' - config_override({ - 'rpki': { - 'roa_source': 'file://' + str(tmp_roa_source), + tmp_roa_source = tmpdir + "/roa.json" + config_override( + { + "rpki": { + "roa_source": "file://" + str(tmp_roa_source), + } } - }) + ) mock_dh = Mock(spec=DatabaseHandler) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.DatabaseHandler', lambda: mock_dh) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.DatabaseHandler", lambda: mock_dh) ROAImportRunner().run() assert flatten_mock_calls(mock_dh) == [ - ['disable_journaling', (), {}], - ['delete_all_roa_objects', (), {}], - ['delete_all_rpsl_objects_with_journal', ('RPKI',), {'journal_guaranteed_empty': True}], - ['close', (), {}] + ["disable_journaling", (), {}], + ["delete_all_roa_objects", (), {}], + ["delete_all_rpsl_objects_with_journal", ("RPKI",), {"journal_guaranteed_empty": True}], + ["close", (), {}], ] - assert 'No such file or directory' in caplog.text + assert "No such file or directory" in caplog.text class MockROADataImporter: def __init__(self, rpki_text: str, slurm_text: str, database_handler: DatabaseHandler): - assert rpki_text == 'roa_data' - assert slurm_text == 'slurm_data' - self.roa_objs = ['roa1', 'roa2'] + assert rpki_text == "roa_data" + assert slurm_text == "slurm_data" + self.roa_objs = ["roa1", "roa2"] class TestScopeFilterUpdateRunner: def test_run(self, monkeypatch, config_override, tmpdir, caplog): mock_dh = Mock(spec=DatabaseHandler) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.DatabaseHandler', lambda: mock_dh) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.DatabaseHandler", lambda: mock_dh) mock_scopefilter = Mock(spec=ScopeFilterValidator) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.ScopeFilterValidator', lambda: mock_scopefilter) + monkeypatch.setattr( + "irrd.mirroring.mirror_runners_import.ScopeFilterValidator", lambda: mock_scopefilter + ) mock_scopefilter.validate_all_rpsl_objects = lambda database_handler: ( - [{'rpsl_pk': 'pk_now_in_scope1'}, {'rpsl_pk': 'pk_now_in_scope2'}], - [{'rpsl_pk': 'pk_now_out_scope_as1'}, {'rpsl_pk': 'pk_now_out_scope_as2'}], - [{'rpsl_pk': 'pk_now_out_scope_prefix1'}, {'rpsl_pk': 'pk_now_out_scope_prefix2'}], + [{"rpsl_pk": "pk_now_in_scope1"}, {"rpsl_pk": "pk_now_in_scope2"}], + [{"rpsl_pk": "pk_now_out_scope_as1"}, {"rpsl_pk": "pk_now_out_scope_as2"}], + [{"rpsl_pk": "pk_now_out_scope_prefix1"}, {"rpsl_pk": "pk_now_out_scope_prefix2"}], ) ScopeFilterUpdateRunner().run() assert flatten_mock_calls(mock_dh) == [ - ['update_scopefilter_status', (), { - 'rpsl_objs_now_in_scope': [{'rpsl_pk': 'pk_now_in_scope1'}, {'rpsl_pk': 'pk_now_in_scope2'}], - 'rpsl_objs_now_out_scope_as': [{'rpsl_pk': 'pk_now_out_scope_as1'}, {'rpsl_pk': 'pk_now_out_scope_as2'}], - 'rpsl_objs_now_out_scope_prefix': [{'rpsl_pk': 'pk_now_out_scope_prefix1'}, {'rpsl_pk': 
'pk_now_out_scope_prefix2'}], - }], - ['commit', (), {}], - ['close', (), {}] + [ + "update_scopefilter_status", + (), + { + "rpsl_objs_now_in_scope": [ + {"rpsl_pk": "pk_now_in_scope1"}, + {"rpsl_pk": "pk_now_in_scope2"}, + ], + "rpsl_objs_now_out_scope_as": [ + {"rpsl_pk": "pk_now_out_scope_as1"}, + {"rpsl_pk": "pk_now_out_scope_as2"}, + ], + "rpsl_objs_now_out_scope_prefix": [ + {"rpsl_pk": "pk_now_out_scope_prefix1"}, + {"rpsl_pk": "pk_now_out_scope_prefix2"}, + ], + }, + ], + ["commit", (), {}], + ["close", (), {}], ] - assert '2 newly in scope, 2 newly out of scope AS, 2 newly out of scope prefix' in caplog.text + assert "2 newly in scope, 2 newly out of scope AS, 2 newly out of scope prefix" in caplog.text def test_exception_handling(self, monkeypatch, config_override, tmpdir, caplog): mock_dh = Mock(spec=DatabaseHandler) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.DatabaseHandler', lambda: mock_dh) - mock_scopefilter = Mock(side_effect=ValueError('expected-test-error')) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.ScopeFilterValidator', mock_scopefilter) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.DatabaseHandler", lambda: mock_dh) + mock_scopefilter = Mock(side_effect=ValueError("expected-test-error")) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.ScopeFilterValidator", mock_scopefilter) ScopeFilterUpdateRunner().run() - assert flatten_mock_calls(mock_dh) == [ - ['close', (), {}] - ] - assert 'expected-test-error' in caplog.text + assert flatten_mock_calls(mock_dh) == [["close", (), {}]] + assert "expected-test-error" in caplog.text class TestRoutePreferenceUpdateRunner: def test_run(self, monkeypatch, config_override, tmpdir, caplog): mock_dh = Mock(spec=DatabaseHandler) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.DatabaseHandler', lambda: mock_dh) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.DatabaseHandler", lambda: mock_dh) mock_update_function = Mock(spec=update_route_preference_status) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.update_route_preference_status', mock_update_function) + monkeypatch.setattr( + "irrd.mirroring.mirror_runners_import.update_route_preference_status", mock_update_function + ) RoutePreferenceUpdateRunner().run() - assert flatten_mock_calls(mock_dh) == [ - ['commit', (), {}], - ['close', (), {}] - ] + assert flatten_mock_calls(mock_dh) == [["commit", (), {}], ["close", (), {}]] def test_exception_handling(self, monkeypatch, config_override, tmpdir, caplog): mock_dh = Mock(spec=DatabaseHandler) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.DatabaseHandler', lambda: mock_dh) - mock_update_function = Mock(side_effect=ValueError('expected-test-error')) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.update_route_preference_status', mock_update_function) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.DatabaseHandler", lambda: mock_dh) + mock_update_function = Mock(side_effect=ValueError("expected-test-error")) + monkeypatch.setattr( + "irrd.mirroring.mirror_runners_import.update_route_preference_status", mock_update_function + ) RoutePreferenceUpdateRunner().run() - assert flatten_mock_calls(mock_dh) == [ - ['close', (), {}] - ] - assert 'expected-test-error' in caplog.text + assert flatten_mock_calls(mock_dh) == [["close", (), {}]] + assert "expected-test-error" in caplog.text class TestNRTMImportUpdateStreamRunner: def test_run_import(self, monkeypatch, config_override): - config_override({ - 'sources': { - 
'TEST': { - 'nrtm_host': '192.0.2.1', - 'nrtm_port': 43, + config_override( + { + "sources": { + "TEST": { + "nrtm_host": "192.0.2.1", + "nrtm_port": 43, + } } } - }) + ) def mock_whois_query(host, port, query, end_markings) -> str: - assert host == '192.0.2.1' + assert host == "192.0.2.1" assert port == 43 - assert query == '-g TEST:3:424243-LAST' - assert 'TEST' in end_markings[0] - return 'response' + assert query == "-g TEST:3:424243-LAST" + assert "TEST" in end_markings[0] + return "response" mock_dh = Mock() - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.NRTMStreamParser', MockNRTMStreamParser) - monkeypatch.setattr('irrd.mirroring.mirror_runners_import.whois_query', mock_whois_query) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.NRTMStreamParser", MockNRTMStreamParser) + monkeypatch.setattr("irrd.mirroring.mirror_runners_import.whois_query", mock_whois_query) - NRTMImportUpdateStreamRunner('TEST').run(424242, mock_dh) + NRTMImportUpdateStreamRunner("TEST").run(424242, mock_dh) def test_missing_source_settings(self, monkeypatch, config_override): - config_override({ - 'sources': { - 'TEST': { - 'nrtm_port': '4343', + config_override( + { + "sources": { + "TEST": { + "nrtm_port": "4343", + } } } - }) + ) mock_dh = Mock() - NRTMImportUpdateStreamRunner('TEST').run(424242, mock_dh) + NRTMImportUpdateStreamRunner("TEST").run(424242, mock_dh) class MockNRTMStreamParser: def __init__(self, source, response, database_handler): - assert source == 'TEST' - assert response == 'response' + assert source == "TEST" + assert response == "response" self.operations = [Mock()] diff --git a/irrd/mirroring/tests/test_nrtm_generator.py b/irrd/mirroring/tests/test_nrtm_generator.py index fc2af5eba..a65ec4eb9 100644 --- a/irrd/mirroring/tests/test_nrtm_generator.py +++ b/irrd/mirroring/tests/test_nrtm_generator.py @@ -1,48 +1,52 @@ -# flake8: noqa: W291, W293 +import textwrap from itertools import cycle, repeat +from unittest.mock import Mock import pytest -import textwrap -from unittest.mock import Mock from irrd.storage.models import DatabaseOperation + from ..nrtm_generator import NRTMGenerator, NRTMGeneratorException @pytest.fixture() def prepare_generator(monkeypatch, config_override): - config_override({ - 'sources': { - 'TEST': { - 'keep_journal': True, - 'nrtm_query_serial_range_limit': 200, + config_override( + { + "sources": { + "TEST": { + "keep_journal": True, + "nrtm_query_serial_range_limit": 200, + } } } - }) + ) mock_dh = Mock() mock_djq = Mock() mock_dsq = Mock() - monkeypatch.setattr('irrd.mirroring.nrtm_generator.RPSLDatabaseJournalQuery', lambda: mock_djq) - monkeypatch.setattr('irrd.mirroring.nrtm_generator.DatabaseStatusQuery', lambda: mock_dsq) + monkeypatch.setattr("irrd.mirroring.nrtm_generator.RPSLDatabaseJournalQuery", lambda: mock_djq) + monkeypatch.setattr("irrd.mirroring.nrtm_generator.DatabaseStatusQuery", lambda: mock_dsq) - responses = cycle([ - repeat({'serial_oldest_journal': 100, 'serial_newest_journal': 200}), + responses = cycle( [ - { - # The CRYPT-PW hash must not appear in the output - 'object_text': 'object 1 🦄\nauth: CRYPT-PW foobar\n', - 'operation': DatabaseOperation.add_or_update, - 'serial_nrtm': 120, - }, - { - 'object_text': 'object 2 🌈\n', - 'operation': DatabaseOperation.delete, - 'serial_nrtm': 180, - }, - ], - ]) + repeat({"serial_oldest_journal": 100, "serial_newest_journal": 200}), + [ + { + # The CRYPT-PW hash must not appear in the output + "object_text": "object 1 🦄\nauth: CRYPT-PW foobar\n", + "operation": 
DatabaseOperation.add_or_update, + "serial_nrtm": 120, + }, + { + "object_text": "object 2 🌈\n", + "operation": DatabaseOperation.delete, + "serial_nrtm": 180, + }, + ], + ] + ) mock_dh.execute_query = lambda q: next(responses) yield NRTMGenerator(), mock_dh @@ -51,9 +55,12 @@ def prepare_generator(monkeypatch, config_override): class TestNRTMGenerator: def test_generate_serial_range_v3(self, prepare_generator): generator, mock_dh = prepare_generator - result = generator.generate('TEST', '3', 110, 190, mock_dh) + result = generator.generate("TEST", "3", 110, 190, mock_dh) - assert result == textwrap.dedent(""" + assert ( + result + == textwrap.dedent( + """ %START Version: 3 TEST 110-190 ADD 120 @@ -65,13 +72,18 @@ def test_generate_serial_range_v3(self, prepare_generator): object 2 🌈 - %END TEST""").strip() + %END TEST""" + ).strip() + ) def test_generate_serial_range_v1(self, prepare_generator): generator, mock_dh = prepare_generator - result = generator.generate('TEST', '1', 110, 190, mock_dh) + result = generator.generate("TEST", "1", 110, 190, mock_dh) - assert result == textwrap.dedent(""" + assert ( + result + == textwrap.dedent( + """ %START Version: 1 TEST 110-190 ADD @@ -83,13 +95,18 @@ def test_generate_serial_range_v1(self, prepare_generator): object 2 🌈 - %END TEST""").strip() + %END TEST""" + ).strip() + ) def test_generate_until_last(self, prepare_generator, config_override): generator, mock_dh = prepare_generator - result = generator.generate('TEST', '3', 110, None, mock_dh) + result = generator.generate("TEST", "3", 110, None, mock_dh) - assert result == textwrap.dedent(""" + assert ( + result + == textwrap.dedent( + """ %START Version: 3 TEST 110-200 ADD 120 @@ -101,88 +118,100 @@ def test_generate_until_last(self, prepare_generator, config_override): object 2 🌈 - %END TEST""").strip() + %END TEST""" + ).strip() + ) def test_serial_range_start_higher_than_low(self, prepare_generator): generator, mock_dh = prepare_generator with pytest.raises(NRTMGeneratorException) as nge: - generator.generate('TEST', '3', 200, 190, mock_dh) - assert 'Start of the serial range (200) must be lower or equal to end of the serial range (190)' in str(nge.value) + generator.generate("TEST", "3", 200, 190, mock_dh) + assert ( + "Start of the serial range (200) must be lower or equal to end of the serial range (190)" + in str(nge.value) + ) def test_serial_start_too_low(self, prepare_generator): generator, mock_dh = prepare_generator with pytest.raises(NRTMGeneratorException) as nge: - generator.generate('TEST', '3', 10, 190, mock_dh) - assert 'Serials 10 - 100 do not exist' in str(nge.value) + generator.generate("TEST", "3", 10, 190, mock_dh) + assert "Serials 10 - 100 do not exist" in str(nge.value) def test_serial_start_too_high(self, prepare_generator): generator, mock_dh = prepare_generator with pytest.raises(NRTMGeneratorException) as nge: - generator.generate('TEST', '3', 202, None, mock_dh) - assert 'Serials 200 - 202 do not exist' in str(nge.value) + generator.generate("TEST", "3", 202, None, mock_dh) + assert "Serials 200 - 202 do not exist" in str(nge.value) def test_serial_end_too_high(self, prepare_generator): generator, mock_dh = prepare_generator with pytest.raises(NRTMGeneratorException) as nge: - generator.generate('TEST', '3', 110, 300, mock_dh) - assert 'Serials 200 - 300 do not exist' in str(nge.value) + generator.generate("TEST", "3", 110, 300, mock_dh) + assert "Serials 200 - 300 do not exist" in str(nge.value) def test_no_new_updates(self, prepare_generator): # This message 
is only triggered when starting from a serial # that is the current plus one, until LAST generator, mock_dh = prepare_generator - result = generator.generate('TEST', '3', 201, None, mock_dh) - assert result == '% Warning: there are no newer updates available' + result = generator.generate("TEST", "3", 201, None, mock_dh) + assert result == "% Warning: there are no newer updates available" def test_no_updates(self, prepare_generator): generator, mock_dh = prepare_generator - responses = repeat({'serial_oldest_journal': None, 'serial_newest_journal': None}) + responses = repeat({"serial_oldest_journal": None, "serial_newest_journal": None}) mock_dh.execute_query = lambda q: responses - result = generator.generate('TEST', '3', 201, None, mock_dh) - assert result == '% Warning: there are no updates available' + result = generator.generate("TEST", "3", 201, None, mock_dh) + assert result == "% Warning: there are no updates available" def test_no_journal_kept(self, prepare_generator, config_override): generator, mock_dh = prepare_generator - config_override({ - 'sources': { - 'TEST': { - 'keep_journal': False, + config_override( + { + "sources": { + "TEST": { + "keep_journal": False, + } } } - }) + ) with pytest.raises(NRTMGeneratorException) as nge: - generator.generate('TEST', '3', 110, 300, mock_dh) - assert 'No journal kept for this source, unable to serve NRTM queries' in str(nge.value) + generator.generate("TEST", "3", 110, 300, mock_dh) + assert "No journal kept for this source, unable to serve NRTM queries" in str(nge.value) def test_no_source_status_entry(self, prepare_generator, config_override): generator, mock_dh = prepare_generator mock_dh.execute_query = Mock(side_effect=StopIteration()) with pytest.raises(NRTMGeneratorException) as nge: - generator.generate('TEST', '3', 110, 300, mock_dh) - assert 'There are no journal entries for this source.' in str(nge.value) + generator.generate("TEST", "3", 110, 300, mock_dh) + assert "There are no journal entries for this source." 
in str(nge.value) def test_v3_range_limit_not_set(self, prepare_generator, config_override): generator, mock_dh = prepare_generator - config_override({ - 'sources': { - 'TEST': { - 'keep_journal': True, + config_override( + { + "sources": { + "TEST": { + "keep_journal": True, + } } } - }) + ) - result = generator.generate('TEST', '3', 110, 190, mock_dh) + result = generator.generate("TEST", "3", 110, 190, mock_dh) - assert result == textwrap.dedent(""" + assert ( + result + == textwrap.dedent( + """ %START Version: 3 TEST 110-190 ADD 120 @@ -194,28 +223,35 @@ def test_v3_range_limit_not_set(self, prepare_generator, config_override): object 2 🌈 - %END TEST""").strip() + %END TEST""" + ).strip() + ) def test_range_limit_exceeded(self, prepare_generator, config_override): generator, mock_dh = prepare_generator - config_override({ - 'sources': { - 'TEST': { - 'keep_journal': True, - 'nrtm_query_serial_range_limit': 50, + config_override( + { + "sources": { + "TEST": { + "keep_journal": True, + "nrtm_query_serial_range_limit": 50, + } } } - }) + ) with pytest.raises(NRTMGeneratorException) as nge: - generator.generate('TEST', '3', 110, 190, mock_dh) - assert 'Serial range requested exceeds maximum range of 50' in str(nge.value) + generator.generate("TEST", "3", 110, 190, mock_dh) + assert "Serial range requested exceeds maximum range of 50" in str(nge.value) def test_include_auth_hash(self, prepare_generator): generator, mock_dh = prepare_generator - result = generator.generate('TEST', '3', 110, 190, mock_dh, False) + result = generator.generate("TEST", "3", 110, 190, mock_dh, False) - assert result == textwrap.dedent(""" + assert ( + result + == textwrap.dedent( + """ %START Version: 3 TEST 110-190 ADD 120 @@ -227,4 +263,6 @@ def test_include_auth_hash(self, prepare_generator): object 2 🌈 - %END TEST""").strip() + %END TEST""" + ).strip() + ) diff --git a/irrd/mirroring/tests/test_nrtm_operation.py b/irrd/mirroring/tests/test_nrtm_operation.py index 59c247000..60579d7fd 100644 --- a/irrd/mirroring/tests/test_nrtm_operation.py +++ b/irrd/mirroring/tests/test_nrtm_operation.py @@ -5,33 +5,37 @@ from irrd.scopefilter.status import ScopeFilterStatus from irrd.scopefilter.validators import ScopeFilterValidator from irrd.storage.models import DatabaseOperation, JournalEntryOrigin -from irrd.utils.rpsl_samples import (SAMPLE_MNTNER, SAMPLE_UNKNOWN_CLASS, - SAMPLE_MALFORMED_EMPTY_LINE, SAMPLE_KEY_CERT, - KEY_CERT_SIGNED_MESSAGE_VALID, SAMPLE_ROUTE) +from irrd.utils.rpsl_samples import ( + KEY_CERT_SIGNED_MESSAGE_VALID, + SAMPLE_KEY_CERT, + SAMPLE_MALFORMED_EMPTY_LINE, + SAMPLE_MNTNER, + SAMPLE_ROUTE, + SAMPLE_UNKNOWN_CLASS, +) + from ..nrtm_operation import NRTMOperation class TestNRTMOperation: - def test_nrtm_add_valid_without_strict_import_keycert(self, monkeypatch, tmp_gpg_dir): mock_dh = Mock() mock_scopefilter = Mock(spec=ScopeFilterValidator) - monkeypatch.setattr('irrd.mirroring.nrtm_operation.ScopeFilterValidator', - lambda: mock_scopefilter) - mock_scopefilter.validate_rpsl_object = lambda obj: (ScopeFilterStatus.in_scope, '') + monkeypatch.setattr("irrd.mirroring.nrtm_operation.ScopeFilterValidator", lambda: mock_scopefilter) + mock_scopefilter.validate_rpsl_object = lambda obj: (ScopeFilterStatus.in_scope, "") operation = NRTMOperation( - source='TEST', + source="TEST", operation=DatabaseOperation.add_or_update, serial=42424242, object_text=SAMPLE_KEY_CERT, strict_validation_key_cert=False, - object_class_filter=['route', 'route6', 'mntner', 'key-cert'], + object_class_filter=["route", 
"route6", "mntner", "key-cert"], ) assert operation.save(database_handler=mock_dh) assert mock_dh.upsert_rpsl_object.call_count == 1 - assert mock_dh.mock_calls[0][1][0].pk() == 'PGPKEY-80F238C6' + assert mock_dh.mock_calls[0][1][0].pk() == "PGPKEY-80F238C6" assert mock_dh.mock_calls[0][1][1] == JournalEntryOrigin.mirror # key-cert should not be imported in the keychain, therefore @@ -42,22 +46,21 @@ def test_nrtm_add_valid_without_strict_import_keycert(self, monkeypatch, tmp_gpg def test_nrtm_add_valid_with_strict_import_keycert(self, monkeypatch, tmp_gpg_dir): mock_dh = Mock() mock_scopefilter = Mock(spec=ScopeFilterValidator) - monkeypatch.setattr('irrd.mirroring.nrtm_operation.ScopeFilterValidator', - lambda: mock_scopefilter) - mock_scopefilter.validate_rpsl_object = lambda obj: (ScopeFilterStatus.in_scope, '') + monkeypatch.setattr("irrd.mirroring.nrtm_operation.ScopeFilterValidator", lambda: mock_scopefilter) + mock_scopefilter.validate_rpsl_object = lambda obj: (ScopeFilterStatus.in_scope, "") operation = NRTMOperation( - source='TEST', + source="TEST", operation=DatabaseOperation.add_or_update, serial=42424242, object_text=SAMPLE_KEY_CERT, strict_validation_key_cert=True, - object_class_filter=['route', 'route6', 'mntner', 'key-cert'], + object_class_filter=["route", "route6", "mntner", "key-cert"], ) assert operation.save(database_handler=mock_dh) assert mock_dh.upsert_rpsl_object.call_count == 1 - assert mock_dh.mock_calls[0][1][0].pk() == 'PGPKEY-80F238C6' + assert mock_dh.mock_calls[0][1][0].pk() == "PGPKEY-80F238C6" assert mock_dh.mock_calls[0][1][1] == JournalEntryOrigin.mirror # key-cert should be imported in the keychain, therefore @@ -68,16 +71,16 @@ def test_nrtm_add_valid_with_strict_import_keycert(self, monkeypatch, tmp_gpg_di def test_nrtm_add_valid_rpki_scopefilter_aware(self, tmp_gpg_dir, monkeypatch): mock_dh = Mock() mock_route_validator = Mock() - monkeypatch.setattr('irrd.mirroring.nrtm_operation.SingleRouteROAValidator', - lambda dh: mock_route_validator) + monkeypatch.setattr( + "irrd.mirroring.nrtm_operation.SingleRouteROAValidator", lambda dh: mock_route_validator + ) mock_scopefilter = Mock(spec=ScopeFilterValidator) - monkeypatch.setattr('irrd.mirroring.nrtm_operation.ScopeFilterValidator', - lambda: mock_scopefilter) + monkeypatch.setattr("irrd.mirroring.nrtm_operation.ScopeFilterValidator", lambda: mock_scopefilter) mock_route_validator.validate_route = lambda prefix, asn, source: RPKIStatus.invalid - mock_scopefilter.validate_rpsl_object = lambda obj: (ScopeFilterStatus.out_scope_prefix, '') + mock_scopefilter.validate_rpsl_object = lambda obj: (ScopeFilterStatus.out_scope_prefix, "") operation = NRTMOperation( - source='TEST', + source="TEST", operation=DatabaseOperation.add_or_update, serial=42424242, object_text=SAMPLE_ROUTE, @@ -87,7 +90,7 @@ def test_nrtm_add_valid_rpki_scopefilter_aware(self, tmp_gpg_dir, monkeypatch): assert operation.save(database_handler=mock_dh) assert mock_dh.upsert_rpsl_object.call_count == 1 - assert mock_dh.mock_calls[0][1][0].pk() == '192.0.2.0/24AS65537' + assert mock_dh.mock_calls[0][1][0].pk() == "192.0.2.0/24AS65537" assert mock_dh.mock_calls[0][1][0].rpki_status == RPKIStatus.invalid assert mock_dh.mock_calls[0][1][0].scopefilter_status == ScopeFilterStatus.out_scope_prefix assert mock_dh.mock_calls[0][1][1] == JournalEntryOrigin.mirror @@ -96,12 +99,12 @@ def test_nrtm_add_valid_ignored_object_class(self): mock_dh = Mock() operation = NRTMOperation( - source='TEST', + source="TEST", 
operation=DatabaseOperation.add_or_update, serial=42424242, object_text=SAMPLE_MNTNER, strict_validation_key_cert=False, - object_class_filter=['route', 'route6'], + object_class_filter=["route", "route6"], ) assert not operation.save(database_handler=mock_dh) assert mock_dh.upsert_rpsl_object.call_count == 0 @@ -110,7 +113,7 @@ def test_nrtm_delete_valid(self): mock_dh = Mock() operation = NRTMOperation( - source='TEST', + source="TEST", operation=DatabaseOperation.delete, serial=42424242, strict_validation_key_cert=False, @@ -119,14 +122,14 @@ def test_nrtm_delete_valid(self): assert operation.save(database_handler=mock_dh) assert mock_dh.delete_rpsl_object.call_count == 1 - assert mock_dh.mock_calls[0][2]['rpsl_object'].pk() == 'TEST-MNT' - assert mock_dh.mock_calls[0][2]['origin'] == JournalEntryOrigin.mirror + assert mock_dh.mock_calls[0][2]["rpsl_object"].pk() == "TEST-MNT" + assert mock_dh.mock_calls[0][2]["origin"] == JournalEntryOrigin.mirror def test_nrtm_add_invalid_unknown_object_class(self): mock_dh = Mock() operation = NRTMOperation( - source='TEST', + source="TEST", operation=DatabaseOperation.add_or_update, serial=42424242, strict_validation_key_cert=False, @@ -139,7 +142,7 @@ def test_nrtm_add_invalid_inconsistent_source(self): mock_dh = Mock() operation = NRTMOperation( - source='NOT-TEST', + source="NOT-TEST", operation=DatabaseOperation.add_or_update, serial=42424242, strict_validation_key_cert=False, @@ -152,7 +155,7 @@ def test_nrtm_add_invalid_rpsl_errors(self): mock_dh = Mock() operation = NRTMOperation( - source='TEST', + source="TEST", operation=DatabaseOperation.add_or_update, serial=42424242, strict_validation_key_cert=False, @@ -166,11 +169,11 @@ def test_nrtm_delete_valid_incomplete_object(self): # a source attribute. However, as the source of the NRTM # stream is known, we can guess this. # This is accepted for deletions only. 
- obj_text = 'route: 192.0.02.0/24\norigin: AS65537' + obj_text = "route: 192.0.02.0/24\norigin: AS65537" mock_dh = Mock() operation = NRTMOperation( - source='TEST', + source="TEST", operation=DatabaseOperation.delete, serial=42424242, object_text=obj_text, @@ -179,17 +182,17 @@ def test_nrtm_delete_valid_incomplete_object(self): assert operation.save(database_handler=mock_dh) assert mock_dh.delete_rpsl_object.call_count == 1 - assert mock_dh.mock_calls[0][2]['rpsl_object'].pk() == '192.0.2.0/24AS65537' - assert mock_dh.mock_calls[0][2]['rpsl_object'].source() == 'TEST' - assert mock_dh.mock_calls[0][2]['origin'] == JournalEntryOrigin.mirror + assert mock_dh.mock_calls[0][2]["rpsl_object"].pk() == "192.0.2.0/24AS65537" + assert mock_dh.mock_calls[0][2]["rpsl_object"].source() == "TEST" + assert mock_dh.mock_calls[0][2]["origin"] == JournalEntryOrigin.mirror def test_nrtm_add_invalid_incomplete_object(self): # Source-less objects are not accepted for add/update - obj_text = 'route: 192.0.02.0/24\norigin: AS65537' + obj_text = "route: 192.0.02.0/24\norigin: AS65537" mock_dh = Mock() operation = NRTMOperation( - source='TEST', + source="TEST", operation=DatabaseOperation.add_or_update, serial=42424242, object_text=obj_text, diff --git a/irrd/mirroring/tests/test_parsers.py b/irrd/mirroring/tests/test_parsers.py index 07e320137..d29127521 100644 --- a/irrd/mirroring/tests/test_parsers.py +++ b/irrd/mirroring/tests/test_parsers.py @@ -10,37 +10,58 @@ from irrd.scopefilter.validators import ScopeFilterValidator from irrd.storage.models import DatabaseOperation, JournalEntryOrigin from irrd.utils.rpsl_samples import ( - SAMPLE_ROUTE, SAMPLE_UNKNOWN_CLASS, SAMPLE_UNKNOWN_ATTRIBUTE, SAMPLE_MALFORMED_PK, - SAMPLE_ROUTE6, SAMPLE_KEY_CERT, KEY_CERT_SIGNED_MESSAGE_VALID, SAMPLE_LEGACY_IRRD_ARTIFACT, - SAMPLE_ROLE, SAMPLE_RTR_SET) + KEY_CERT_SIGNED_MESSAGE_VALID, + SAMPLE_KEY_CERT, + SAMPLE_LEGACY_IRRD_ARTIFACT, + SAMPLE_MALFORMED_PK, + SAMPLE_ROLE, + SAMPLE_ROUTE, + SAMPLE_ROUTE6, + SAMPLE_RTR_SET, + SAMPLE_UNKNOWN_ATTRIBUTE, + SAMPLE_UNKNOWN_CLASS, +) from irrd.utils.test_utils import flatten_mock_calls -from .nrtm_samples import (SAMPLE_NRTM_V3, SAMPLE_NRTM_V1, SAMPLE_NRTM_V1_TOO_MANY_ITEMS, - SAMPLE_NRTM_INVALID_VERSION, SAMPLE_NRTM_V3_NO_END, - SAMPLE_NRTM_V3_SERIAL_GAP, SAMPLE_NRTM_V3_INVALID_MULTIPLE_START_LINES, - SAMPLE_NRTM_INVALID_NO_START_LINE, SAMPLE_NRTM_V3_SERIAL_OUT_OF_ORDER) -from ..parsers import NRTMStreamParser, MirrorFileImportParser, MirrorUpdateFileImportParser + +from ..parsers import ( + MirrorFileImportParser, + MirrorUpdateFileImportParser, + NRTMStreamParser, +) +from .nrtm_samples import ( + SAMPLE_NRTM_INVALID_NO_START_LINE, + SAMPLE_NRTM_INVALID_VERSION, + SAMPLE_NRTM_V1, + SAMPLE_NRTM_V1_TOO_MANY_ITEMS, + SAMPLE_NRTM_V3, + SAMPLE_NRTM_V3_INVALID_MULTIPLE_START_LINES, + SAMPLE_NRTM_V3_NO_END, + SAMPLE_NRTM_V3_SERIAL_GAP, + SAMPLE_NRTM_V3_SERIAL_OUT_OF_ORDER, +) @pytest.fixture def mock_scopefilter(monkeypatch): mock_scopefilter = Mock(spec=ScopeFilterValidator) - monkeypatch.setattr('irrd.mirroring.parsers.ScopeFilterValidator', - lambda: mock_scopefilter) - mock_scopefilter.validate_rpsl_object = lambda obj: (ScopeFilterStatus.in_scope, '') + monkeypatch.setattr("irrd.mirroring.parsers.ScopeFilterValidator", lambda: mock_scopefilter) + mock_scopefilter.validate_rpsl_object = lambda obj: (ScopeFilterStatus.in_scope, "") return mock_scopefilter class TestMirrorFileImportParser: # This test also covers the common parts of MirrorFileImportParserBase def test_parse(self, 
mock_scopefilter, caplog, tmp_gpg_dir, config_override): - config_override({ - 'sources': { - 'TEST': { - 'object_class_filter': ['route', 'key-cert'], - 'strict_import_keycert_objects': True, + config_override( + { + "sources": { + "TEST": { + "object_class_filter": ["route", "key-cert"], + "strict_import_keycert_objects": True, + } } } - }) + ) mock_dh = Mock() mock_roa_validator = Mock(spec=BulkRouteROAValidator) mock_roa_validator.validate_route = lambda ip, length, asn, source: RPKIStatus.invalid @@ -49,18 +70,18 @@ def test_parse(self, mock_scopefilter, caplog, tmp_gpg_dir, config_override): SAMPLE_UNKNOWN_ATTRIBUTE, # valid, because mirror imports are non-strict SAMPLE_ROUTE6, # Valid, excluded by object class filter SAMPLE_KEY_CERT, - SAMPLE_ROUTE.replace('TEST', 'BADSOURCE'), + SAMPLE_ROUTE.replace("TEST", "BADSOURCE"), SAMPLE_UNKNOWN_CLASS, SAMPLE_MALFORMED_PK, SAMPLE_LEGACY_IRRD_ARTIFACT, ] - test_input = '\n\n'.join(test_data) + test_input = "\n\n".join(test_data) with tempfile.NamedTemporaryFile() as fp: - fp.write(test_input.encode('utf-8')) + fp.write(test_input.encode("utf-8")) fp.seek(0) parser = MirrorFileImportParser( - source='TEST', + source="TEST", filename=fp.name, serial=424242, database_handler=mock_dh, @@ -68,210 +89,220 @@ def test_parse(self, mock_scopefilter, caplog, tmp_gpg_dir, config_override): ) parser.run_import() assert len(mock_dh.mock_calls) == 5 - assert mock_dh.mock_calls[0][0] == 'upsert_rpsl_object' - assert mock_dh.mock_calls[0][1][0].pk() == '192.0.2.0/24AS65537' + assert mock_dh.mock_calls[0][0] == "upsert_rpsl_object" + assert mock_dh.mock_calls[0][1][0].pk() == "192.0.2.0/24AS65537" assert mock_dh.mock_calls[0][1][0].rpki_status == RPKIStatus.invalid assert mock_dh.mock_calls[0][1][0].scopefilter_status == ScopeFilterStatus.in_scope - assert mock_dh.mock_calls[1][0] == 'upsert_rpsl_object' - assert mock_dh.mock_calls[1][1][0].pk() == 'PGPKEY-80F238C6' - assert mock_dh.mock_calls[2][0] == 'record_mirror_error' - assert mock_dh.mock_calls[3][0] == 'record_mirror_error' - assert mock_dh.mock_calls[4][0] == 'record_serial_seen' - assert mock_dh.mock_calls[4][1][0] == 'TEST' + assert mock_dh.mock_calls[1][0] == "upsert_rpsl_object" + assert mock_dh.mock_calls[1][1][0].pk() == "PGPKEY-80F238C6" + assert mock_dh.mock_calls[2][0] == "record_mirror_error" + assert mock_dh.mock_calls[3][0] == "record_mirror_error" + assert mock_dh.mock_calls[4][0] == "record_serial_seen" + assert mock_dh.mock_calls[4][1][0] == "TEST" assert mock_dh.mock_calls[4][1][1] == 424242 - assert 'Invalid source BADSOURCE for object' in caplog.text - assert 'Invalid address prefix' in caplog.text - assert 'File import for TEST: 6 objects read, 2 objects inserted, ignored 2 due to errors' in caplog.text - assert 'ignored 1 due to object_class_filter' in caplog.text - assert 'Ignored 1 objects found in file import for TEST due to unknown object classes' in caplog.text + assert "Invalid source BADSOURCE for object" in caplog.text + assert "Invalid address prefix" in caplog.text + assert ( + "File import for TEST: 6 objects read, 2 objects inserted, ignored 2 due to errors" in caplog.text + ) + assert "ignored 1 due to object_class_filter" in caplog.text + assert "Ignored 1 objects found in file import for TEST due to unknown object classes" in caplog.text key_cert_obj = rpsl_object_from_text(SAMPLE_KEY_CERT, strict_validation=False) assert key_cert_obj.verify(KEY_CERT_SIGNED_MESSAGE_VALID) def test_direct_error_return_invalid_source(self, mock_scopefilter, caplog, tmp_gpg_dir, 
config_override): - config_override({ - 'sources': { - 'TEST': {}, + config_override( + { + "sources": { + "TEST": {}, + } } - }) + ) mock_dh = Mock() test_data = [ SAMPLE_UNKNOWN_ATTRIBUTE, # valid, because mirror imports are non-strict - SAMPLE_ROUTE.replace('TEST', 'BADSOURCE'), + SAMPLE_ROUTE.replace("TEST", "BADSOURCE"), ] - test_input = '\n\n'.join(test_data) + test_input = "\n\n".join(test_data) with tempfile.NamedTemporaryFile() as fp: - fp.write(test_input.encode('utf-8')) + fp.write(test_input.encode("utf-8")) fp.seek(0) parser = MirrorFileImportParser( - source='TEST', + source="TEST", filename=fp.name, serial=424242, database_handler=mock_dh, direct_error_return=True, ) error = parser.run_import() - assert error == 'Invalid source BADSOURCE for object 192.0.2.0/24AS65537, expected TEST' + assert error == "Invalid source BADSOURCE for object 192.0.2.0/24AS65537, expected TEST" assert len(mock_dh.mock_calls) == 1 - assert mock_dh.mock_calls[0][0] == 'upsert_rpsl_object' - assert mock_dh.mock_calls[0][1][0].pk() == '192.0.2.0/24AS65537' + assert mock_dh.mock_calls[0][0] == "upsert_rpsl_object" + assert mock_dh.mock_calls[0][1][0].pk() == "192.0.2.0/24AS65537" assert mock_dh.mock_calls[0][1][0].rpki_status == RPKIStatus.not_found - assert 'Invalid source BADSOURCE for object' not in caplog.text - assert 'File import for TEST' not in caplog.text + assert "Invalid source BADSOURCE for object" not in caplog.text + assert "File import for TEST" not in caplog.text def test_direct_error_return_malformed_pk(self, mock_scopefilter, caplog, tmp_gpg_dir, config_override): - config_override({ - 'sources': { - 'TEST': {}, + config_override( + { + "sources": { + "TEST": {}, + } } - }) + ) mock_dh = Mock() with tempfile.NamedTemporaryFile() as fp: - fp.write(SAMPLE_MALFORMED_PK.encode('utf-8')) + fp.write(SAMPLE_MALFORMED_PK.encode("utf-8")) fp.seek(0) parser = MirrorFileImportParser( - source='TEST', + source="TEST", filename=fp.name, serial=424242, database_handler=mock_dh, direct_error_return=True, ) error = parser.run_import() - assert 'Invalid address prefix: not-a-prefix' in error + assert "Invalid address prefix: not-a-prefix" in error assert not len(mock_dh.mock_calls) - assert 'Invalid address prefix: not-a-prefix' not in caplog.text - assert 'File import for TEST' not in caplog.text + assert "Invalid address prefix: not-a-prefix" not in caplog.text + assert "File import for TEST" not in caplog.text def test_direct_error_return_unknown_class(self, mock_scopefilter, caplog, tmp_gpg_dir, config_override): - config_override({ - 'sources': { - 'TEST': {}, + config_override( + { + "sources": { + "TEST": {}, + } } - }) + ) mock_dh = Mock() with tempfile.NamedTemporaryFile() as fp: - fp.write(SAMPLE_UNKNOWN_CLASS.encode('utf-8')) + fp.write(SAMPLE_UNKNOWN_CLASS.encode("utf-8")) fp.seek(0) parser = MirrorFileImportParser( - source='TEST', + source="TEST", filename=fp.name, serial=424242, database_handler=mock_dh, direct_error_return=True, ) error = parser.run_import() - assert error == 'Unknown object class: foo-block' + assert error == "Unknown object class: foo-block" assert not len(mock_dh.mock_calls) - assert 'Unknown object class: foo-block' not in caplog.text - assert 'File import for TEST' not in caplog.text + assert "Unknown object class: foo-block" not in caplog.text + assert "File import for TEST" not in caplog.text class TestMirrorUpdateFileImportParser: def test_parse(self, mock_scopefilter, caplog, config_override): - config_override({ - 'sources': { - 'TEST': { - 
'object_class_filter': ['route', 'route6', 'key-cert', 'role'], + config_override( + { + "sources": { + "TEST": { + "object_class_filter": ["route", "route6", "key-cert", "role"], + } } } - }) + ) mock_dh = Mock() test_data = [ SAMPLE_ROUTE, # Valid retained SAMPLE_ROUTE6, # Valid modified SAMPLE_ROLE, # Valid new object - SAMPLE_ROUTE.replace('TEST', 'BADSOURCE'), + SAMPLE_ROUTE.replace("TEST", "BADSOURCE"), SAMPLE_UNKNOWN_CLASS, SAMPLE_MALFORMED_PK, ] - test_input = '\n\n'.join(test_data) + test_input = "\n\n".join(test_data) - route_with_last_modified = SAMPLE_ROUTE + 'last-modified: 2020-01-01T00:00:00Z\n' + route_with_last_modified = SAMPLE_ROUTE + "last-modified: 2020-01-01T00:00:00Z\n" mock_query_result = [ { # Retained object (with format cleaning) # includes a last-modified which should be ignored in the comparison - 'rpsl_pk': '192.0.2.0/24AS65537', - 'object_class': 'route', - 'object_text': rpsl_object_from_text(route_with_last_modified).render_rpsl_text(), + "rpsl_pk": "192.0.2.0/24AS65537", + "object_class": "route", + "object_text": rpsl_object_from_text(route_with_last_modified).render_rpsl_text(), }, { # Modified object - 'rpsl_pk': '2001:DB8::/48AS65537', - 'object_class': 'route6', - 'object_text': SAMPLE_ROUTE6.replace('test-MNT', 'existing-mnt'), + "rpsl_pk": "2001:DB8::/48AS65537", + "object_class": "route6", + "object_text": SAMPLE_ROUTE6.replace("test-MNT", "existing-mnt"), }, { # Deleted object - 'rpsl_pk': 'rtrs-settest', - 'object_class': 'route-set', - 'object_text': SAMPLE_RTR_SET, + "rpsl_pk": "rtrs-settest", + "object_class": "route-set", + "object_text": SAMPLE_RTR_SET, }, ] mock_dh.execute_query = lambda query: mock_query_result with tempfile.NamedTemporaryFile() as fp: - fp.write(test_input.encode('utf-8')) + fp.write(test_input.encode("utf-8")) fp.seek(0) parser = MirrorUpdateFileImportParser( - source='TEST', + source="TEST", filename=fp.name, database_handler=mock_dh, ) parser.run_import() assert len(mock_dh.mock_calls) == 5 - assert mock_dh.mock_calls[0][0] == 'record_mirror_error' - assert mock_dh.mock_calls[1][0] == 'record_mirror_error' - assert mock_dh.mock_calls[2][0] == 'upsert_rpsl_object' - assert mock_dh.mock_calls[2][1][0].pk() == 'ROLE-TEST' - assert mock_dh.mock_calls[3][0] == 'delete_rpsl_object' - assert mock_dh.mock_calls[3][2]['source'] == 'TEST' - assert mock_dh.mock_calls[3][2]['rpsl_pk'] == 'rtrs-settest' - assert mock_dh.mock_calls[3][2]['object_class'] == 'route-set' - assert mock_dh.mock_calls[3][2]['origin'] == JournalEntryOrigin.synthetic_nrtm - assert mock_dh.mock_calls[4][0] == 'upsert_rpsl_object' - assert mock_dh.mock_calls[4][1][0].pk() == '2001:DB8::/48AS65537' - - assert 'Invalid source BADSOURCE for object' in caplog.text - assert 'Invalid address prefix' in caplog.text - assert 'File update for TEST: 6 objects read, 3 objects processed, 1 objects newly inserted, 1 objects newly deleted, 2 objects retained, of which 1 modified' in caplog.text - assert 'ignored 0 due to object_class_filter' in caplog.text - assert 'Ignored 1 objects found in file import for TEST due to unknown object classes' in caplog.text + assert mock_dh.mock_calls[0][0] == "record_mirror_error" + assert mock_dh.mock_calls[1][0] == "record_mirror_error" + assert mock_dh.mock_calls[2][0] == "upsert_rpsl_object" + assert mock_dh.mock_calls[2][1][0].pk() == "ROLE-TEST" + assert mock_dh.mock_calls[3][0] == "delete_rpsl_object" + assert mock_dh.mock_calls[3][2]["source"] == "TEST" + assert mock_dh.mock_calls[3][2]["rpsl_pk"] == "rtrs-settest" + assert 
mock_dh.mock_calls[3][2]["object_class"] == "route-set" + assert mock_dh.mock_calls[3][2]["origin"] == JournalEntryOrigin.synthetic_nrtm + assert mock_dh.mock_calls[4][0] == "upsert_rpsl_object" + assert mock_dh.mock_calls[4][1][0].pk() == "2001:DB8::/48AS65537" + + assert "Invalid source BADSOURCE for object" in caplog.text + assert "Invalid address prefix" in caplog.text + assert ( + "File update for TEST: 6 objects read, 3 objects processed, 1 objects newly inserted, 1 objects" + " newly deleted, 2 objects retained, of which 1 modified" + in caplog.text + ) + assert "ignored 0 due to object_class_filter" in caplog.text + assert "Ignored 1 objects found in file import for TEST due to unknown object classes" in caplog.text def test_direct_error_return(self, mock_scopefilter, config_override): - config_override({ - 'sources': { - 'TEST': {} - } - }) + config_override({"sources": {"TEST": {}}}) mock_dh = Mock() test_data = [ SAMPLE_UNKNOWN_CLASS, SAMPLE_MALFORMED_PK, ] - test_input = '\n\n'.join(test_data) + test_input = "\n\n".join(test_data) with tempfile.NamedTemporaryFile() as fp: - fp.write(test_input.encode('utf-8')) + fp.write(test_input.encode("utf-8")) fp.seek(0) parser = MirrorUpdateFileImportParser( - source='TEST', + source="TEST", filename=fp.name, database_handler=mock_dh, direct_error_return=True, ) - assert parser.run_import() == 'Unknown object class: foo-block' + assert parser.run_import() == "Unknown object class: foo-block" assert len(mock_dh.mock_calls) == 0 @@ -279,119 +310,121 @@ def test_direct_error_return(self, mock_scopefilter, config_override): class TestNRTMStreamParser: def test_test_parse_nrtm_v3_valid(self): mock_dh = Mock() - parser = NRTMStreamParser('TEST', SAMPLE_NRTM_V3, mock_dh) + parser = NRTMStreamParser("TEST", SAMPLE_NRTM_V3, mock_dh) self._assert_valid(parser) - assert flatten_mock_calls(mock_dh) == [['record_serial_newest_mirror', ('TEST', 11012701), {}]] + assert flatten_mock_calls(mock_dh) == [["record_serial_newest_mirror", ("TEST", 11012701), {}]] def test_test_parse_nrtm_v1_valid(self, config_override): - config_override({ - 'sources': { - 'TEST': { - 'object_class_filter': 'person', - 'strict_import_keycert_objects': True, + config_override( + { + "sources": { + "TEST": { + "object_class_filter": "person", + "strict_import_keycert_objects": True, + } } } - }) + ) mock_dh = Mock() - parser = NRTMStreamParser('TEST', SAMPLE_NRTM_V1, mock_dh) + parser = NRTMStreamParser("TEST", SAMPLE_NRTM_V1, mock_dh) self._assert_valid(parser) - assert flatten_mock_calls(mock_dh) == [['record_serial_newest_mirror', ('TEST', 11012701), {}]] + assert flatten_mock_calls(mock_dh) == [["record_serial_newest_mirror", ("TEST", 11012701), {}]] def test_test_parse_nrtm_v3_valid_serial_gap(self): mock_dh = Mock() - parser = NRTMStreamParser('TEST', SAMPLE_NRTM_V3_SERIAL_GAP, mock_dh) + parser = NRTMStreamParser("TEST", SAMPLE_NRTM_V3_SERIAL_GAP, mock_dh) self._assert_valid(parser) - assert flatten_mock_calls(mock_dh) == [['record_serial_newest_mirror', ('TEST', 11012703), {}]] + assert flatten_mock_calls(mock_dh) == [["record_serial_newest_mirror", ("TEST", 11012703), {}]] def test_test_parse_nrtm_v3_invalid_serial_out_of_order(self): mock_dh = Mock() with pytest.raises(ValueError) as ve: - NRTMStreamParser('TEST', SAMPLE_NRTM_V3_SERIAL_OUT_OF_ORDER, mock_dh) + NRTMStreamParser("TEST", SAMPLE_NRTM_V3_SERIAL_OUT_OF_ORDER, mock_dh) - error_msg = 'expected at least' + error_msg = "expected at least" assert error_msg in str(ve.value) assert len(mock_dh.mock_calls) == 1 - 
assert mock_dh.mock_calls[0][0] == 'record_mirror_error' - assert mock_dh.mock_calls[0][1][0] == 'TEST' + assert mock_dh.mock_calls[0][0] == "record_mirror_error" + assert mock_dh.mock_calls[0][1][0] == "TEST" assert error_msg in mock_dh.mock_calls[0][1][1] def test_test_parse_nrtm_v3_invalid_unexpected_source(self): mock_dh = Mock() with pytest.raises(ValueError) as ve: - NRTMStreamParser('BADSOURCE', SAMPLE_NRTM_V3, mock_dh) + NRTMStreamParser("BADSOURCE", SAMPLE_NRTM_V3, mock_dh) - error_msg = 'Invalid NRTM source in START line: expected BADSOURCE but found TEST ' + error_msg = "Invalid NRTM source in START line: expected BADSOURCE but found TEST " assert error_msg in str(ve.value) assert len(mock_dh.mock_calls) == 1 - assert mock_dh.mock_calls[0][0] == 'record_mirror_error' - assert mock_dh.mock_calls[0][1][0] == 'BADSOURCE' + assert mock_dh.mock_calls[0][0] == "record_mirror_error" + assert mock_dh.mock_calls[0][1][0] == "BADSOURCE" assert error_msg in mock_dh.mock_calls[0][1][1] def test_test_parse_nrtm_v1_invalid_too_many_items(self): mock_dh = Mock() with pytest.raises(ValueError) as ve: - NRTMStreamParser('TEST', SAMPLE_NRTM_V1_TOO_MANY_ITEMS, mock_dh) - error_msg = 'expected operations up to and including' + NRTMStreamParser("TEST", SAMPLE_NRTM_V1_TOO_MANY_ITEMS, mock_dh) + error_msg = "expected operations up to and including" assert error_msg in str(ve.value) assert len(mock_dh.mock_calls) == 1 - assert mock_dh.mock_calls[0][0] == 'record_mirror_error' - assert mock_dh.mock_calls[0][1][0] == 'TEST' + assert mock_dh.mock_calls[0][0] == "record_mirror_error" + assert mock_dh.mock_calls[0][1][0] == "TEST" assert error_msg in mock_dh.mock_calls[0][1][1] def test_test_parse_nrtm_invalid_invalid_version(self): mock_dh = Mock() with pytest.raises(ValueError) as ve: - NRTMStreamParser('TEST', SAMPLE_NRTM_INVALID_VERSION, mock_dh) + NRTMStreamParser("TEST", SAMPLE_NRTM_INVALID_VERSION, mock_dh) - error_msg = 'Invalid NRTM version 99 in START line' + error_msg = "Invalid NRTM version 99 in START line" assert error_msg in str(ve.value) assert len(mock_dh.mock_calls) == 1 - assert mock_dh.mock_calls[0][0] == 'record_mirror_error' - assert mock_dh.mock_calls[0][1][0] == 'TEST' + assert mock_dh.mock_calls[0][0] == "record_mirror_error" + assert mock_dh.mock_calls[0][1][0] == "TEST" assert error_msg in mock_dh.mock_calls[0][1][1] def test_test_parse_nrtm_invalid_multiple_start_lines(self): mock_dh = Mock() with pytest.raises(ValueError) as ve: - NRTMStreamParser('TEST', SAMPLE_NRTM_V3_INVALID_MULTIPLE_START_LINES, mock_dh) + NRTMStreamParser("TEST", SAMPLE_NRTM_V3_INVALID_MULTIPLE_START_LINES, mock_dh) - error_msg = 'Encountered second START line' + error_msg = "Encountered second START line" assert error_msg in str(ve.value) assert len(mock_dh.mock_calls) == 1 - assert mock_dh.mock_calls[0][0] == 'record_mirror_error' - assert mock_dh.mock_calls[0][1][0] == 'TEST' + assert mock_dh.mock_calls[0][0] == "record_mirror_error" + assert mock_dh.mock_calls[0][1][0] == "TEST" assert error_msg in mock_dh.mock_calls[0][1][1] def test_test_parse_nrtm_invalid_no_start_line(self): mock_dh = Mock() with pytest.raises(ValueError) as ve: - NRTMStreamParser('TEST', SAMPLE_NRTM_INVALID_NO_START_LINE, mock_dh) + NRTMStreamParser("TEST", SAMPLE_NRTM_INVALID_NO_START_LINE, mock_dh) - error_msg = 'Encountered operation before valid NRTM START line' + error_msg = "Encountered operation before valid NRTM START line" assert error_msg in str(ve.value) assert len(mock_dh.mock_calls) == 1 - assert 
mock_dh.mock_calls[0][0] == 'record_mirror_error' - assert mock_dh.mock_calls[0][1][0] == 'TEST' + assert mock_dh.mock_calls[0][0] == "record_mirror_error" + assert mock_dh.mock_calls[0][1][0] == "TEST" assert error_msg in mock_dh.mock_calls[0][1][1] def test_test_parse_nrtm_no_end(self): mock_dh = Mock() with pytest.raises(ValueError) as ve: - NRTMStreamParser('TEST', SAMPLE_NRTM_V3_NO_END, mock_dh) + NRTMStreamParser("TEST", SAMPLE_NRTM_V3_NO_END, mock_dh) - error_msg = 'last comment paragraph expected to be' + error_msg = "last comment paragraph expected to be" assert error_msg in str(ve.value) assert len(mock_dh.mock_calls) == 1 - assert mock_dh.mock_calls[0][0] == 'record_mirror_error' - assert mock_dh.mock_calls[0][1][0] == 'TEST' + assert mock_dh.mock_calls[0][0] == "record_mirror_error" + assert mock_dh.mock_calls[0][1][0] == "TEST" assert error_msg in mock_dh.mock_calls[0][1][1] def _assert_valid(self, parser: NRTMStreamParser): assert parser.operations[0].operation == DatabaseOperation.add_or_update assert parser.operations[0].serial == 11012700 - assert parser.operations[0].object_text == 'person: NRTM test\naddress: NowhereLand\nsource: TEST\n' + assert parser.operations[0].object_text == "person: NRTM test\naddress: NowhereLand\nsource: TEST\n" assert parser.operations[1].operation == DatabaseOperation.delete assert parser.operations[1].serial == 11012701 - assert parser.operations[1].object_text == 'inetnum: 192.0.2.0 - 192.0.2.255\nsource: TEST\n' + assert parser.operations[1].object_text == "inetnum: 192.0.2.0 - 192.0.2.255\nsource: TEST\n" diff --git a/irrd/mirroring/tests/test_scheduler.py b/irrd/mirroring/tests/test_scheduler.py index 4ee1793bb..0281d131d 100644 --- a/irrd/mirroring/tests/test_scheduler.py +++ b/irrd/mirroring/tests/test_scheduler.py @@ -1,48 +1,51 @@ -import time - import threading +import time -from ..scheduler import MirrorScheduler, ScheduledTaskProcess, MAX_SIMULTANEOUS_RUNS +from ..scheduler import MAX_SIMULTANEOUS_RUNS, MirrorScheduler, ScheduledTaskProcess thread_run_count = 0 class TestMirrorScheduler: def test_scheduler_database_readonly(self, monkeypatch, config_override): - monkeypatch.setattr('irrd.mirroring.scheduler.ScheduledTaskProcess', MockScheduledTaskProcess) + monkeypatch.setattr("irrd.mirroring.scheduler.ScheduledTaskProcess", MockScheduledTaskProcess) global thread_run_count thread_run_count = 0 - config_override({ - 'database_readonly': True, - 'sources': { - 'TEST': { - 'import_source': 'url', - 'import_timer': 0, - } + config_override( + { + "database_readonly": True, + "sources": { + "TEST": { + "import_source": "url", + "import_timer": 0, + } + }, } - }) + ) - monkeypatch.setattr('irrd.mirroring.scheduler.RPSLMirrorImportUpdateRunner', MockRunner) + monkeypatch.setattr("irrd.mirroring.scheduler.RPSLMirrorImportUpdateRunner", MockRunner) scheduler = MirrorScheduler() scheduler.run() assert thread_run_count == 0 def test_scheduler_runs_rpsl_import(self, monkeypatch, config_override): - monkeypatch.setattr('irrd.mirroring.scheduler.ScheduledTaskProcess', MockScheduledTaskProcess) + monkeypatch.setattr("irrd.mirroring.scheduler.ScheduledTaskProcess", MockScheduledTaskProcess) global thread_run_count thread_run_count = 0 - config_override({ - 'sources': { - 'TEST': { - 'import_source': 'url', - 'import_timer': 0, + config_override( + { + "sources": { + "TEST": { + "import_source": "url", + "import_timer": 0, + } } } - }) + ) - monkeypatch.setattr('irrd.mirroring.scheduler.RPSLMirrorImportUpdateRunner', MockRunner) + 
monkeypatch.setattr("irrd.mirroring.scheduler.RPSLMirrorImportUpdateRunner", MockRunner) MockRunner.run_sleep = True scheduler = MirrorScheduler() @@ -63,32 +66,34 @@ def test_scheduler_runs_rpsl_import(self, monkeypatch, config_override): assert len(scheduler.processes.items()) == 0 def test_scheduler_limits_simultaneous_runs(self, monkeypatch, config_override): - monkeypatch.setattr('irrd.mirroring.scheduler.ScheduledTaskProcess', MockScheduledTaskProcess) + monkeypatch.setattr("irrd.mirroring.scheduler.ScheduledTaskProcess", MockScheduledTaskProcess) global thread_run_count thread_run_count = 0 - config_override({ - 'sources': { - 'TEST': { - 'import_source': 'url', - 'import_timer': 0, - }, - 'TEST2': { - 'import_source': 'url', - 'import_timer': 0, - }, - 'TEST3': { - 'import_source': 'url', - 'import_timer': 0, - }, - 'TEST4': { - 'import_source': 'url', - 'import_timer': 0, - }, + config_override( + { + "sources": { + "TEST": { + "import_source": "url", + "import_timer": 0, + }, + "TEST2": { + "import_source": "url", + "import_timer": 0, + }, + "TEST3": { + "import_source": "url", + "import_timer": 0, + }, + "TEST4": { + "import_source": "url", + "import_timer": 0, + }, + } } - }) + ) - monkeypatch.setattr('irrd.mirroring.scheduler.RPSLMirrorImportUpdateRunner', MockRunner) + monkeypatch.setattr("irrd.mirroring.scheduler.RPSLMirrorImportUpdateRunner", MockRunner) MockRunner.run_sleep = False scheduler = MirrorScheduler() @@ -98,17 +103,13 @@ def test_scheduler_limits_simultaneous_runs(self, monkeypatch, config_override): assert thread_run_count == MAX_SIMULTANEOUS_RUNS def test_scheduler_runs_roa_import(self, monkeypatch, config_override): - monkeypatch.setattr('irrd.mirroring.scheduler.ScheduledTaskProcess', MockScheduledTaskProcess) + monkeypatch.setattr("irrd.mirroring.scheduler.ScheduledTaskProcess", MockScheduledTaskProcess) global thread_run_count thread_run_count = 0 - config_override({ - 'rpki': { - 'roa_source': 'https://example.com/roa.json' - } - }) + config_override({"rpki": {"roa_source": "https://example.com/roa.json"}}) - monkeypatch.setattr('irrd.mirroring.scheduler.ROAImportRunner', MockRunner) + monkeypatch.setattr("irrd.mirroring.scheduler.ROAImportRunner", MockRunner) MockRunner.run_sleep = True scheduler = MirrorScheduler() @@ -120,40 +121,46 @@ def test_scheduler_runs_roa_import(self, monkeypatch, config_override): assert thread_run_count == 1 def test_scheduler_runs_scopefilter(self, monkeypatch, config_override): - monkeypatch.setattr('irrd.mirroring.scheduler.ScheduledTaskProcess', MockScheduledTaskProcess) + monkeypatch.setattr("irrd.mirroring.scheduler.ScheduledTaskProcess", MockScheduledTaskProcess) global thread_run_count thread_run_count = 0 - config_override({ - 'rpki': {'roa_source': None}, - 'scopefilter': { - 'prefixes': ['192.0.2.0/24'], + config_override( + { + "rpki": {"roa_source": None}, + "scopefilter": { + "prefixes": ["192.0.2.0/24"], + }, } - }) + ) - monkeypatch.setattr('irrd.mirroring.scheduler.ScopeFilterUpdateRunner', MockRunner) + monkeypatch.setattr("irrd.mirroring.scheduler.ScopeFilterUpdateRunner", MockRunner) MockRunner.run_sleep = False scheduler = MirrorScheduler() scheduler.run() # Second run will not start the thread, as the config hasn't changed - config_override({ - 'rpki': {'roa_source': None}, - 'scopefilter': { - 'prefixes': ['192.0.2.0/24'], + config_override( + { + "rpki": {"roa_source": None}, + "scopefilter": { + "prefixes": ["192.0.2.0/24"], + }, } - }) + ) scheduler.run() time.sleep(0.2) assert thread_run_count == 1 
- config_override({ - 'rpki': {'roa_source': None}, - 'scopefilter': { - 'asns': [23456], + config_override( + { + "rpki": {"roa_source": None}, + "scopefilter": { + "asns": [23456], + }, } - }) + ) # Should run now, because config has changed scheduler.update_process_state() @@ -161,15 +168,15 @@ def test_scheduler_runs_scopefilter(self, monkeypatch, config_override): time.sleep(0.2) assert thread_run_count == 2 - config_override({ - 'rpki': {'roa_source': None}, - 'scopefilter': { - 'asns': [23456], - }, - 'sources': { - 'TEST': {'scopefilter_excluded': True} - }, - }) + config_override( + { + "rpki": {"roa_source": None}, + "scopefilter": { + "asns": [23456], + }, + "sources": {"TEST": {"scopefilter_excluded": True}}, + } + ) # Should run again, because exclusions have changed scheduler.update_process_state() @@ -178,18 +185,20 @@ def test_scheduler_runs_scopefilter(self, monkeypatch, config_override): assert thread_run_count == 3 def test_scheduler_runs_route_preference(self, monkeypatch, config_override): - monkeypatch.setattr('irrd.mirroring.scheduler.ScheduledTaskProcess', MockScheduledTaskProcess) + monkeypatch.setattr("irrd.mirroring.scheduler.ScheduledTaskProcess", MockScheduledTaskProcess) global thread_run_count thread_run_count = 0 - config_override({ - 'rpki': {'roa_source': None}, - 'sources': { - 'TEST': {"route_object_preference": 200}, + config_override( + { + "rpki": {"roa_source": None}, + "sources": { + "TEST": {"route_object_preference": 200}, + }, } - }) + ) - monkeypatch.setattr('irrd.mirroring.scheduler.RoutePreferenceUpdateRunner', MockRunner) + monkeypatch.setattr("irrd.mirroring.scheduler.RoutePreferenceUpdateRunner", MockRunner) MockRunner.run_sleep = True scheduler = MirrorScheduler() @@ -201,20 +210,22 @@ def test_scheduler_runs_route_preference(self, monkeypatch, config_override): assert thread_run_count == 1 def test_scheduler_import_ignores_timer_not_expired(self, monkeypatch, config_override): - monkeypatch.setattr('irrd.mirroring.scheduler.ScheduledTaskProcess', MockScheduledTaskProcess) + monkeypatch.setattr("irrd.mirroring.scheduler.ScheduledTaskProcess", MockScheduledTaskProcess) global thread_run_count thread_run_count = 0 - config_override({ - 'sources': { - 'TEST': { - 'import_source': 'url', - 'import_timer': 100, + config_override( + { + "sources": { + "TEST": { + "import_source": "url", + "import_timer": 100, + } } } - }) + ) - monkeypatch.setattr('irrd.mirroring.scheduler.RPSLMirrorImportUpdateRunner', MockRunner) + monkeypatch.setattr("irrd.mirroring.scheduler.RPSLMirrorImportUpdateRunner", MockRunner) MockRunner.run_sleep = False scheduler = MirrorScheduler() @@ -228,20 +239,22 @@ def test_scheduler_import_ignores_timer_not_expired(self, monkeypatch, config_ov assert thread_run_count == 1 def test_scheduler_runs_export(self, monkeypatch, config_override): - monkeypatch.setattr('irrd.mirroring.scheduler.ScheduledTaskProcess', MockScheduledTaskProcess) + monkeypatch.setattr("irrd.mirroring.scheduler.ScheduledTaskProcess", MockScheduledTaskProcess) global thread_run_count thread_run_count = 0 - config_override({ - 'sources': { - 'TEST': { - 'export_destination': 'url', - 'export_timer': 0, + config_override( + { + "sources": { + "TEST": { + "export_destination": "url", + "export_timer": 0, + } } } - }) + ) - monkeypatch.setattr('irrd.mirroring.scheduler.SourceExportRunner', MockRunner) + monkeypatch.setattr("irrd.mirroring.scheduler.SourceExportRunner", MockRunner) MockRunner.run_sleep = True scheduler = MirrorScheduler() @@ -253,20 +266,22 @@ 
def test_scheduler_runs_export(self, monkeypatch, config_override): assert thread_run_count == 1 def test_scheduler_export_ignores_timer_not_expired(self, monkeypatch, config_override): - monkeypatch.setattr('irrd.mirroring.scheduler.ScheduledTaskProcess', MockScheduledTaskProcess) + monkeypatch.setattr("irrd.mirroring.scheduler.ScheduledTaskProcess", MockScheduledTaskProcess) global thread_run_count thread_run_count = 0 - config_override({ - 'sources': { - 'TEST': { - 'export_destination': 'url', - 'export_timer': 100, + config_override( + { + "sources": { + "TEST": { + "export_destination": "url", + "export_timer": 100, + } } } - }) + ) - monkeypatch.setattr('irrd.mirroring.scheduler.SourceExportRunner', MockRunner) + monkeypatch.setattr("irrd.mirroring.scheduler.SourceExportRunner", MockRunner) MockRunner.run_sleep = False scheduler = MirrorScheduler() @@ -285,7 +300,7 @@ def test_task(self): global thread_run_count thread_run_count = 0 MockRunner.run_sleep = True - ScheduledTaskProcess(runner=MockRunner('TEST'), name='test').run() + ScheduledTaskProcess(runner=MockRunner("TEST"), name="test").run() assert thread_run_count == 1 @@ -293,7 +308,7 @@ class MockRunner: run_sleep = True def __init__(self, source): - assert source in ['TEST', 'TEST2', 'TEST3', 'TEST4', 'RPKI', 'scopefilter', 'routepref'] + assert source in ["TEST", "TEST2", "TEST3", "TEST4", "RPKI", "scopefilter", "routepref"] def run(self): global thread_run_count diff --git a/irrd/routepref/routepref.py b/irrd/routepref/routepref.py index a08426e00..45ae34ef5 100644 --- a/irrd/routepref/routepref.py +++ b/irrd/routepref/routepref.py @@ -1,5 +1,5 @@ import logging -from typing import List, Dict, Iterable, Tuple, Optional +from typing import Dict, Iterable, List, Optional, Tuple import radix from IPy import IP @@ -8,8 +8,8 @@ from irrd.conf import get_setting from irrd.storage.database_handler import DatabaseHandler from irrd.storage.queries import RPSLDatabaseQuery -from .status import RoutePreferenceStatus +from .status import RoutePreferenceStatus logger = logging.getLogger(__name__) @@ -74,7 +74,9 @@ def validate_known_routes(self) -> Tuple[List[str], List[str]]: to_be_suppressed = [] for evaluated_node in self.rtree: search_args = {"packed": evaluated_node.packed, "masklen": evaluated_node.prefixlen} - overlapping_nodes = self.rtree.search_covered(**search_args) + self.rtree.search_covering(**search_args) + overlapping_nodes = self.rtree.search_covered(**search_args) + self.rtree.search_covering( + **search_args + ) for evaluated_key, (evaluated_preference, current_status) in evaluated_node.data.items(): new_status = self._evaluate_route(evaluated_preference, overlapping_nodes) @@ -85,7 +87,9 @@ def validate_known_routes(self) -> Tuple[List[str], List[str]]: to_be_visible.append(evaluated_key) return to_be_visible, to_be_suppressed - def _evaluate_route(self, route_preference: int, overlapping_nodes: List[RadixNode]) -> RoutePreferenceStatus: + def _evaluate_route( + self, route_preference: int, overlapping_nodes: List[RadixNode] + ) -> RoutePreferenceStatus: """ Given a preference, evaluate the correct state of a route based on a given list of overlapping nodes. 
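A minimal sketch (editor's note, not part of the commit) of the overlap lookup that the reformatted validate_known_routes() hunk above performs: search_covered() collects the more-specific routes at or under a prefix, search_covering() the less-specific ones at or above it, and their union is every route overlapping the evaluated one. Only the py-radix calls already used in the hunk are assumed; the "highest preference wins" rule at the end is a simplified stand-in for _evaluate_route(), not IRRd's exact logic.

# Hedged sketch of the covered/covering lookup pattern, assuming the
# py-radix API used in the hunk above (Radix, add, search_exact,
# search_covered, search_covering).
import radix

rtree = radix.Radix()
for prefix, preference in [
    ("192.0.2.0/24", 100),  # evaluated route
    ("192.0.2.0/25", 200),  # more specific, found by search_covered()
    ("192.0.0.0/16", 150),  # less specific, found by search_covering()
]:
    node = rtree.add(prefix)
    node.data["preference"] = preference

evaluated = rtree.search_exact("192.0.2.0/24")
# Union of more- and less-specific nodes; may include the evaluated
# node itself, which is harmless for the comparison below.
overlapping = rtree.search_covered(evaluated.prefix) + rtree.search_covering(evaluated.prefix)

# Simplified stand-in for _evaluate_route(): suppress the route if any
# overlapping route carries a strictly higher preference value.
best = max(n.data["preference"] for n in overlapping)
status = "visible" if evaluated.data["preference"] >= best else "suppressed"
print(evaluated.prefix, status)  # -> 192.0.2.0/24 suppressed (the /25 outranks it)
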
diff --git a/irrd/routepref/tests/test_routepref.py b/irrd/routepref/tests/test_routepref.py index 46cb97c18..96a8fd49f 100644 --- a/irrd/routepref/tests/test_routepref.py +++ b/irrd/routepref/tests/test_routepref.py @@ -2,6 +2,7 @@ from irrd.storage.queries import RPSLDatabaseQuery from irrd.utils.test_utils import MockDatabaseHandler + from ..routepref import RoutePreferenceValidator, update_route_preference_status from ..status import RoutePreferenceStatus @@ -143,7 +144,9 @@ def test_update_route_preference_status(config_override): mock_dh.query_responses[RPSLDatabaseQuery] = iter(route_objects) update_route_preference_status(mock_dh) assert mock_dh.queries == [ - RPSLDatabaseQuery(column_names=expected_columns, ordered_by_sources=False).object_classes(object_classes), + RPSLDatabaseQuery(column_names=expected_columns, ordered_by_sources=False).object_classes( + object_classes + ), RPSLDatabaseQuery( enrich_columns, enable_ordering=False, diff --git a/irrd/rpki/importer.py b/irrd/rpki/importer.py index 49d3f7351..6a07766fb 100644 --- a/irrd/rpki/importer.py +++ b/irrd/rpki/importer.py @@ -1,21 +1,20 @@ +import logging from collections import defaultdict +from typing import Dict, List, Optional, Set import ujson - -import logging from IPy import IP, IPSet -from typing import List, Optional, Dict, Set from irrd.conf import RPKI_IRR_PSEUDO_SOURCE, get_setting from irrd.rpki.status import RPKIStatus -from irrd.rpsl.parser import RPSLObject, RPSL_ATTRIBUTE_TEXT_WIDTH +from irrd.rpsl.parser import RPSL_ATTRIBUTE_TEXT_WIDTH, RPSLObject from irrd.rpsl.rpsl_objects import RPSL_ROUTE_OBJECT_CLASS_FOR_IP_VERSION from irrd.scopefilter.validators import ScopeFilterValidator from irrd.storage.database_handler import DatabaseHandler from irrd.storage.models import JournalEntryOrigin from irrd.utils.validators import parse_as_number -SLURM_TRUST_ANCHOR = 'SLURM file' +SLURM_TRUST_ANCHOR = "SLURM file" logger = logging.getLogger(__name__) @@ -36,8 +35,8 @@ class ROADataImporter: database_handler.delete_all_roa_objects() database_handler.delete_all_rpsl_objects_with_journal(RPKI_IRR_PSEUDO_SOURCE) """ - def __init__(self, rpki_json_str: str, slurm_json_str: Optional[str], - database_handler: DatabaseHandler): + + def __init__(self, rpki_json_str: str, slurm_json_str: Optional[str], database_handler: DatabaseHandler): self.roa_objs: List[ROA] = [] self._filtered_asns: Set[int] = set() self._filtered_prefixes: IPSet = IPSet() @@ -51,9 +50,9 @@ def __init__(self, rpki_json_str: str, slurm_json_str: Optional[str], for roa_dict in self._roa_dicts: try: - _, asn = parse_as_number(roa_dict['asn'], permit_plain=True) - prefix = IP(roa_dict['prefix']) - ta = roa_dict['ta'] + _, asn = parse_as_number(roa_dict["asn"], permit_plain=True) + prefix = IP(roa_dict["prefix"]) + ta = roa_dict["ta"] if ta != SLURM_TRUST_ANCHOR: if asn in self._filtered_asns: continue @@ -62,13 +61,13 @@ def __init__(self, rpki_json_str: str, slurm_json_str: Optional[str], if any([prefix in self._filtered_combined.get(asn, [])]): continue - roa_obj = ROA(prefix, asn, roa_dict['maxLength'], ta) + roa_obj = ROA(prefix, asn, roa_dict["maxLength"], ta) except KeyError as ke: - msg = f'Unable to parse ROA record: missing key {ke} -- full record: {roa_dict}' + msg = f"Unable to parse ROA record: missing key {ke} -- full record: {roa_dict}" logger.error(msg) raise ROAParserException(msg) except ValueError as ve: - msg = f'Invalid value in ROA or SLURM: {ve}' + msg = f"Invalid value in ROA or SLURM: {ve}" logger.error(msg) raise 
ROAParserException(msg) @@ -78,9 +77,9 @@ def __init__(self, rpki_json_str: str, slurm_json_str: Optional[str], def _load_roa_dicts(self, rpki_json_str: str) -> None: """Load the ROAs from the JSON string into self._roa_dicts""" try: - self._roa_dicts = ujson.loads(rpki_json_str)['roas'] + self._roa_dicts = ujson.loads(rpki_json_str)["roas"] except ValueError as error: - msg = f'Unable to parse ROA input: invalid JSON: {error}' + msg = f"Unable to parse ROA input: invalid JSON: {error}" logger.error(msg) raise ROAParserException(msg) except KeyError: @@ -107,32 +106,34 @@ def _load_slurm(self, slurm_json_str: str): This must be called after _load_roa_dicts() """ slurm = ujson.loads(slurm_json_str) - version = slurm.get('slurmVersion') + version = slurm.get("slurmVersion") if version != 1: - msg = f'SLURM data has invalid version: {version}' + msg = f"SLURM data has invalid version: {version}" logger.error(msg) raise ROAParserException(msg) - filters = slurm.get('validationOutputFilters', {}).get('prefixFilters', []) + filters = slurm.get("validationOutputFilters", {}).get("prefixFilters", []) for item in filters: - if 'asn' in item and 'prefix' not in item: - self._filtered_asns.add(int(item['asn'])) - if 'asn' not in item and 'prefix' in item: - self._filtered_prefixes.add(IP(item['prefix'])) - if 'asn' in item and 'prefix' in item: - self._filtered_combined[int(item['asn'])].add(IP(item['prefix'])) - - assertions = slurm.get('locallyAddedAssertions', {}).get('prefixAssertions', []) + if "asn" in item and "prefix" not in item: + self._filtered_asns.add(int(item["asn"])) + if "asn" not in item and "prefix" in item: + self._filtered_prefixes.add(IP(item["prefix"])) + if "asn" in item and "prefix" in item: + self._filtered_combined[int(item["asn"])].add(IP(item["prefix"])) + + assertions = slurm.get("locallyAddedAssertions", {}).get("prefixAssertions", []) for assertion in assertions: - max_length = assertion.get('maxPrefixLength') + max_length = assertion.get("maxPrefixLength") if max_length is None: - max_length = IP(assertion['prefix']).prefixlen() - self._roa_dicts.append({ - 'asn': 'AS' + str(assertion['asn']), - 'prefix': assertion['prefix'], - 'maxLength': max_length, - 'ta': SLURM_TRUST_ANCHOR, - }) + max_length = IP(assertion["prefix"]).prefixlen() + self._roa_dicts.append( + { + "asn": "AS" + str(assertion["asn"]), + "prefix": assertion["prefix"], + "maxLength": max_length, + "ta": SLURM_TRUST_ANCHOR, + } + ) class ROA: @@ -142,6 +143,7 @@ class ROA: This is used when (re-)importing all ROAs, to save the data to the DB, and by the BulkRouteROAValidator when validating all existing routes. 
""" + def __init__(self, prefix: IP, asn: int, max_length: str, trust_anchor: str): try: self.prefix = prefix @@ -150,13 +152,15 @@ def __init__(self, prefix: IP, asn: int, max_length: str, trust_anchor: str): self.max_length = int(max_length) self.trust_anchor = trust_anchor except ValueError as ve: - msg = f'Invalid value in ROA: {ve}' + msg = f"Invalid value in ROA: {ve}" logger.error(msg) raise ROAParserException(msg) if self.max_length < self.prefix.prefixlen(): - msg = f'Invalid ROA: prefix size {self.prefix.prefixlen()} is smaller than max length {max_length} in ' \ - f'ROA for {self.prefix} / AS{self.asn}' + msg = ( + f"Invalid ROA: prefix size {self.prefix.prefixlen()} is smaller than max length" + f" {max_length} in ROA for {self.prefix} / AS{self.asn}" + ) logger.error(msg) raise ROAParserException(msg) @@ -179,8 +183,9 @@ def save(self, database_handler: DatabaseHandler, scopefilter_validator: ScopeFi trust_anchor=self.trust_anchor, scopefilter_validator=scopefilter_validator, ) - database_handler.upsert_rpsl_object(self._rpsl_object, JournalEntryOrigin.pseudo_irr, - rpsl_guaranteed_no_existing=True) + database_handler.upsert_rpsl_object( + self._rpsl_object, JournalEntryOrigin.pseudo_irr, rpsl_guaranteed_no_existing=True + ) class RPSLObjectFromROA(RPSLObject): @@ -189,9 +194,17 @@ class RPSLObjectFromROA(RPSLObject): an RPKI pseudo-IRR object. It overrides the API in relevant parts. """ + # noinspection PyMissingConstructor - def __init__(self, prefix: IP, prefix_str: str, asn: int, max_length: int, trust_anchor: str, - scopefilter_validator: ScopeFilterValidator): + def __init__( + self, + prefix: IP, + prefix_str: str, + asn: int, + max_length: int, + trust_anchor: str, + scopefilter_validator: ScopeFilterValidator, + ): self.prefix = prefix self.prefix_str = prefix_str self.asn = asn @@ -207,9 +220,9 @@ def __init__(self, prefix: IP, prefix_str: str, asn: int, max_length: int, trust self.rpki_status = RPKIStatus.valid self.parsed_data = { self.rpsl_object_class: self.prefix_str, - 'origin': 'AS' + str(self.asn), - 'source': RPKI_IRR_PSEUDO_SOURCE, - 'rpki_max_length': max_length, + "origin": "AS" + str(self.asn), + "source": RPKI_IRR_PSEUDO_SOURCE, + "rpki_max_length": max_length, } self.scopefilter_status, _ = scopefilter_validator.validate_rpsl_object(self) @@ -217,19 +230,22 @@ def source(self): return RPKI_IRR_PSEUDO_SOURCE def pk(self): - return f'{self.prefix_str}AS{self.asn}/ML{self.max_length}' + return f"{self.prefix_str}AS{self.asn}/ML{self.max_length}" def render_rpsl_text(self, last_modified=None): - object_class_display = f'{self.rpsl_object_class}:'.ljust(RPSL_ATTRIBUTE_TEXT_WIDTH) - remarks_fill = RPSL_ATTRIBUTE_TEXT_WIDTH * ' ' - remarks = get_setting('rpki.pseudo_irr_remarks').replace('\n', '\n' + remarks_fill).strip() + object_class_display = f"{self.rpsl_object_class}:".ljust(RPSL_ATTRIBUTE_TEXT_WIDTH) + remarks_fill = RPSL_ATTRIBUTE_TEXT_WIDTH * " " + remarks = get_setting("rpki.pseudo_irr_remarks").replace("\n", "\n" + remarks_fill).strip() remarks = remarks.format(asn=self.asn, prefix=self.prefix_str) - rpsl_object_text = f""" + rpsl_object_text = ( + f""" {object_class_display}{self.prefix_str} descr: RPKI ROA for {self.prefix_str} / AS{self.asn} remarks: {remarks} max-length: {self.max_length} origin: AS{self.asn} source: {RPKI_IRR_PSEUDO_SOURCE} # Trust Anchor: {self.trust_anchor} -""".strip() + '\n' +""".strip() + + "\n" + ) return rpsl_object_text diff --git a/irrd/rpki/notifications.py b/irrd/rpki/notifications.py index a28db498f..c064744bc 
100644 --- a/irrd/rpki/notifications.py +++ b/irrd/rpki/notifications.py @@ -14,8 +14,9 @@ logger = logging.getLogger(__name__) -def notify_rpki_invalid_owners(database_handler: DatabaseHandler, - rpsl_dicts_now_invalid: List[Dict[str, str]]) -> int: +def notify_rpki_invalid_owners( + database_handler: DatabaseHandler, rpsl_dicts_now_invalid: List[Dict[str, str]] +) -> int: """ Notify the owners/contacts of newly RPKI invalid objects. @@ -24,20 +25,20 @@ def notify_rpki_invalid_owners(database_handler: DatabaseHandler, tech-c or admin-c, of any maintainer of the object. One email is sent per email address. """ - if not get_setting('rpki.notify_invalid_enabled'): + if not get_setting("rpki.notify_invalid_enabled"): return 0 rpsl_objs = [] for obj in rpsl_dicts_now_invalid: - source = obj['source'] - authoritative = get_setting(f'sources.{source}.authoritative') - if authoritative and obj['rpki_status'] == RPKIStatus.invalid: - rpsl_objs.append(rpsl_object_from_text(obj['object_text'])) + source = obj["source"] + authoritative = get_setting(f"sources.{source}.authoritative") + if authoritative and obj["rpki_status"] == RPKIStatus.invalid: + rpsl_objs.append(rpsl_object_from_text(obj["object_text"])) if not rpsl_objs: return 0 - sources = {obj.parsed_data['source'] for obj in rpsl_objs} + sources = {obj.parsed_data["source"] for obj in rpsl_objs} mntner_emails_by_source = {} for source in sources: # For each source, a multi-step process is run to fill this @@ -45,31 +46,44 @@ def notify_rpki_invalid_owners(database_handler: DatabaseHandler, mntner_emails = defaultdict(set) # Step 1: retrieve all relevant maintainers from the DB - mntner_pks = set(itertools.chain(*[ - obj.parsed_data.get('mnt-by', []) - for obj in rpsl_objs - if obj.parsed_data['source'] == source - ])) - query = RPSLDatabaseQuery(['rpsl_pk', 'parsed_data']).sources([source]).rpsl_pks(mntner_pks).object_classes(['mntner']) + mntner_pks = set( + itertools.chain( + *[ + obj.parsed_data.get("mnt-by", []) + for obj in rpsl_objs + if obj.parsed_data["source"] == source + ] + ) + ) + query = ( + RPSLDatabaseQuery(["rpsl_pk", "parsed_data"]) + .sources([source]) + .rpsl_pks(mntner_pks) + .object_classes(["mntner"]) + ) mntners = list(database_handler.execute_query(query)) # Step 2: any mnt-nfy on these maintainers is a contact address for mntner in mntners: - mntner_emails[mntner['rpsl_pk']].update(mntner['parsed_data'].get('mnt-nfy', [])) + mntner_emails[mntner["rpsl_pk"]].update(mntner["parsed_data"].get("mnt-nfy", [])) # Step 3: extract the contact handles for each maintainer mntner_contacts = { - m['rpsl_pk']: m['parsed_data'].get('tech-c', []) + m['parsed_data'].get('admin-c', []) + m["rpsl_pk"]: m["parsed_data"].get("tech-c", []) + m["parsed_data"].get("admin-c", []) for m in mntners } # Step 4: retrieve all these contacts from the DB in bulk, # and extract their e-mail addresses contact_pks = set(itertools.chain(*mntner_contacts.values())) - query = RPSLDatabaseQuery(['rpsl_pk', 'parsed_data']).sources([source]).rpsl_pks(contact_pks).object_classes(['role', 'person']) + query = ( + RPSLDatabaseQuery(["rpsl_pk", "parsed_data"]) + .sources([source]) + .rpsl_pks(contact_pks) + .object_classes(["role", "person"]) + ) contacts = { - r['rpsl_pk']: r['parsed_data'].get('e-mail', []) - for r in database_handler.execute_query(query) + r["rpsl_pk"]: r["parsed_data"].get("e-mail", []) for r in database_handler.execute_query(query) } # Step 5: use the contacts per maintainer, and emails per contact @@ -88,8 +102,8 @@ def 
notify_rpki_invalid_owners(database_handler: DatabaseHandler, # addresses they need to be sent to. objs_per_email: Dict[str, Set[RPSLObject]] = defaultdict(set) for rpsl_obj in rpsl_objs: - mntners = rpsl_obj.parsed_data.get('mnt-by', []) - source = rpsl_obj.parsed_data['source'] + mntners = rpsl_obj.parsed_data.get("mnt-by", []) + source = rpsl_obj.parsed_data["source"] for mntner_pk in mntners: try: for email in mntner_emails_by_source[source][mntner_pk]: @@ -97,20 +111,20 @@ def notify_rpki_invalid_owners(database_handler: DatabaseHandler, except KeyError: # pragma: no cover pass - header_template = get_setting('rpki.notify_invalid_header', '') - subject_template = get_setting('rpki.notify_invalid_subject', '').replace('\n', ' ') + header_template = get_setting("rpki.notify_invalid_header", "") + subject_template = get_setting("rpki.notify_invalid_subject", "").replace("\n", " ") for email, objs in objs_per_email.items(): - sources_str = ', '.join({obj.parsed_data['source'] for obj in objs}) + sources_str = ", ".join({obj.parsed_data["source"] for obj in objs}) subject = subject_template.format(sources_str=sources_str, object_count=len(objs)) body = header_template.format(sources_str=sources_str, object_count=len(objs)) - body += '\nThe following objects are affected:\n' - body += '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n' + body += "\nThe following objects are affected:\n" + body += "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n" for rpsl_obj in objs: - body += rpsl_obj.render_rpsl_text() + '\n' - body += '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~' + body += rpsl_obj.render_rpsl_text() + "\n" + body += "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" try: send_email(email, subject, body) except Exception as e: # pragma: no cover - logger.warning(f'Unable to send RPKI invalid notification to {email}: {e}') + logger.warning(f"Unable to send RPKI invalid notification to {email}: {e}") return len(objs_per_email.keys()) diff --git a/irrd/rpki/status.py b/irrd/rpki/status.py index 0393cab45..d00304a7d 100644 --- a/irrd/rpki/status.py +++ b/irrd/rpki/status.py @@ -3,9 +3,9 @@ @enum.unique class RPKIStatus(enum.Enum): - valid = 'VALID' - invalid = 'INVALID' - not_found = 'NOT_FOUND' + valid = "VALID" + invalid = "INVALID" + not_found = "NOT_FOUND" @classmethod def is_visible(cls, status: "RPKIStatus"): diff --git a/irrd/rpki/tests/test_importer.py b/irrd/rpki/tests/test_importer.py index 0364e1f08..f5b6025a5 100644 --- a/irrd/rpki/tests/test_importer.py +++ b/irrd/rpki/tests/test_importer.py @@ -9,15 +9,15 @@ from irrd.scopefilter.validators import ScopeFilterValidator from irrd.storage.database_handler import DatabaseHandler from irrd.utils.test_utils import flatten_mock_calls + from ..importer import ROADataImporter, ROAParserException @pytest.fixture() def mock_scopefilter(monkeypatch): mock_scopefilter = Mock(spec=ScopeFilterValidator) - monkeypatch.setattr('irrd.rpki.importer.ScopeFilterValidator', - lambda: mock_scopefilter) - mock_scopefilter.validate_rpsl_object = lambda obj: (ScopeFilterStatus.out_scope_as, '') + monkeypatch.setattr("irrd.rpki.importer.ScopeFilterValidator", lambda: mock_scopefilter) + mock_scopefilter.validate_rpsl_object = lambda obj: (ScopeFilterStatus.out_scope_as, "") class TestROAImportProcess: @@ -27,129 +27,178 @@ def test_valid_process(self, monkeypatch, mock_scopefilter): mock_dh = Mock(spec=DatabaseHandler) - rpki_data = ujson.dumps({ - "roas": [{ - "asn": "64496", - "prefix": "192.0.2.0/24", - "maxLength": 26, - 
"ta": "APNIC RPKI Root" - }, { - "asn": "AS64497", - "prefix": "2001:db8::/32", - "maxLength": 40, - "ta": "RIPE NCC RPKI Root" - }, { - # Filtered out by SLURM due to origin - "asn": "64498", - "prefix": "192.0.2.0/24", - "maxLength": 32, - "ta": "APNIC RPKI Root" - }, { - # Filtered out by SLURM due to prefix - "asn": "AS64496", - "prefix": "203.0.113.0/25", - "maxLength": 26, - "ta": "APNIC RPKI Root" - }, { - # Filtered out by SLURM due to prefix - "asn": "AS64497", - "prefix": "203.0.113.0/26", - "maxLength": 26, - "ta": "APNIC RPKI Root" - }, { - # Filtered out by SLURM due to prefix plus origin - "asn": "AS64497", - "prefix": "203.0.113.128/26", - "maxLength": 26, - "ta": "APNIC RPKI Root" - }] - }) - - slurm_data = ujson.dumps({ - "slurmVersion": 1, - "validationOutputFilters": { - "prefixFilters": [ + rpki_data = ujson.dumps( + { + "roas": [ + {"asn": "64496", "prefix": "192.0.2.0/24", "maxLength": 26, "ta": "APNIC RPKI Root"}, { - "prefix": "203.0.113.0/25", - "comment": "All VRPs encompassed by prefix", + "asn": "AS64497", + "prefix": "2001:db8::/32", + "maxLength": 40, + "ta": "RIPE NCC RPKI Root", }, { - "asn": 64498, - "comment": "All VRPs matching ASN", + # Filtered out by SLURM due to origin + "asn": "64498", + "prefix": "192.0.2.0/24", + "maxLength": 32, + "ta": "APNIC RPKI Root", }, { - "prefix": "203.0.113.128/25", - "asn": 64497, - "comment": "All VRPs encompassed by prefix, matching ASN", + # Filtered out by SLURM due to prefix + "asn": "AS64496", + "prefix": "203.0.113.0/25", + "maxLength": 26, + "ta": "APNIC RPKI Root", }, { - # This filters out nothing, the ROA for this prefix has AS 64496 - "prefix": "192.0.2.0/24", - "asn": 64497, - "comment": "All VRPs encompassed by prefix, matching ASN", + # Filtered out by SLURM due to prefix + "asn": "AS64497", + "prefix": "203.0.113.0/26", + "maxLength": 26, + "ta": "APNIC RPKI Root", }, { - # This should not filter out the assertion for 198.51.100/24 - "prefix": "198.51.100.0/24", - "asn": 64496, - "comment": "All VRPs encompassed by prefix, matching ASN", - } - ], - }, - "locallyAddedAssertions": { - "prefixAssertions": [ - { - "asn": 64496, - "prefix": "198.51.100.0/24", - "comment": "My other important route", + # Filtered out by SLURM due to prefix plus origin + "asn": "AS64497", + "prefix": "203.0.113.128/26", + "maxLength": 26, + "ta": "APNIC RPKI Root", }, - { - "asn": 64497, - "prefix": "2001:DB8::/32", - "maxPrefixLength": 48, - "comment": "My other important de-aggregated routes", - } - ], + ] } - }) + ) + + slurm_data = ujson.dumps( + { + "slurmVersion": 1, + "validationOutputFilters": { + "prefixFilters": [ + { + "prefix": "203.0.113.0/25", + "comment": "All VRPs encompassed by prefix", + }, + { + "asn": 64498, + "comment": "All VRPs matching ASN", + }, + { + "prefix": "203.0.113.128/25", + "asn": 64497, + "comment": "All VRPs encompassed by prefix, matching ASN", + }, + { + # This filters out nothing, the ROA for this prefix has AS 64496 + "prefix": "192.0.2.0/24", + "asn": 64497, + "comment": "All VRPs encompassed by prefix, matching ASN", + }, + { + # This should not filter out the assertion for 198.51.100/24 + "prefix": "198.51.100.0/24", + "asn": 64496, + "comment": "All VRPs encompassed by prefix, matching ASN", + }, + ], + }, + "locallyAddedAssertions": { + "prefixAssertions": [ + { + "asn": 64496, + "prefix": "198.51.100.0/24", + "comment": "My other important route", + }, + { + "asn": 64497, + "prefix": "2001:DB8::/32", + "maxPrefixLength": 48, + "comment": "My other important de-aggregated routes", + 
}, + ], + }, + } + ) roa_importer = ROADataImporter(rpki_data, slurm_data, mock_dh) assert flatten_mock_calls(mock_dh, flatten_objects=True) == [ - ['insert_roa_object', (), - {'ip_version': 4, 'prefix_str': '192.0.2.0/24', 'asn': 64496, - 'max_length': 26, 'trust_anchor': 'APNIC RPKI Root'}], - ['upsert_rpsl_object', - ('route/192.0.2.0/24AS64496/ML26/RPKI', 'JournalEntryOrigin.pseudo_irr'), - {'rpsl_guaranteed_no_existing': True}], - ['insert_roa_object', (), - {'ip_version': 6, 'prefix_str': '2001:db8::/32', 'asn': 64497, - 'max_length': 40, 'trust_anchor': 'RIPE NCC RPKI Root'}], - ['upsert_rpsl_object', - ('route6/2001:db8::/32AS64497/ML40/RPKI', 'JournalEntryOrigin.pseudo_irr'), - {'rpsl_guaranteed_no_existing': True}], - ['insert_roa_object', (), - {'ip_version': 4, 'prefix_str': '198.51.100.0/24', 'asn': 64496, - 'max_length': 24, 'trust_anchor': 'SLURM file'}], - ['upsert_rpsl_object', - ('route/198.51.100.0/24AS64496/ML24/RPKI', 'JournalEntryOrigin.pseudo_irr'), - {'rpsl_guaranteed_no_existing': True}], - ['insert_roa_object', (), - {'ip_version': 6, 'prefix_str': '2001:db8::/32', 'asn': 64497, - 'max_length': 48, 'trust_anchor': 'SLURM file'}], - ['upsert_rpsl_object', - ('route6/2001:db8::/32AS64497/ML48/RPKI', 'JournalEntryOrigin.pseudo_irr'), - {'rpsl_guaranteed_no_existing': True}], + [ + "insert_roa_object", + (), + { + "ip_version": 4, + "prefix_str": "192.0.2.0/24", + "asn": 64496, + "max_length": 26, + "trust_anchor": "APNIC RPKI Root", + }, + ], + [ + "upsert_rpsl_object", + ("route/192.0.2.0/24AS64496/ML26/RPKI", "JournalEntryOrigin.pseudo_irr"), + {"rpsl_guaranteed_no_existing": True}, + ], + [ + "insert_roa_object", + (), + { + "ip_version": 6, + "prefix_str": "2001:db8::/32", + "asn": 64497, + "max_length": 40, + "trust_anchor": "RIPE NCC RPKI Root", + }, + ], + [ + "upsert_rpsl_object", + ("route6/2001:db8::/32AS64497/ML40/RPKI", "JournalEntryOrigin.pseudo_irr"), + {"rpsl_guaranteed_no_existing": True}, + ], + [ + "insert_roa_object", + (), + { + "ip_version": 4, + "prefix_str": "198.51.100.0/24", + "asn": 64496, + "max_length": 24, + "trust_anchor": "SLURM file", + }, + ], + [ + "upsert_rpsl_object", + ("route/198.51.100.0/24AS64496/ML24/RPKI", "JournalEntryOrigin.pseudo_irr"), + {"rpsl_guaranteed_no_existing": True}, + ], + [ + "insert_roa_object", + (), + { + "ip_version": 6, + "prefix_str": "2001:db8::/32", + "asn": 64497, + "max_length": 48, + "trust_anchor": "SLURM file", + }, + ], + [ + "upsert_rpsl_object", + ("route6/2001:db8::/32AS64497/ML48/RPKI", "JournalEntryOrigin.pseudo_irr"), + {"rpsl_guaranteed_no_existing": True}, + ], ] assert roa_importer.roa_objs[0]._rpsl_object.scopefilter_status == ScopeFilterStatus.out_scope_as assert roa_importer.roa_objs[0]._rpsl_object.source() == RPKI_IRR_PSEUDO_SOURCE assert roa_importer.roa_objs[0]._rpsl_object.parsed_data == { - 'origin': 'AS64496', - 'route': '192.0.2.0/24', - 'rpki_max_length': 26, - 'source': 'RPKI', + "origin": "AS64496", + "route": "192.0.2.0/24", + "rpki_max_length": 26, + "source": "RPKI", } - assert roa_importer.roa_objs[0]._rpsl_object.render_rpsl_text() == textwrap.dedent(""" + assert ( + roa_importer.roa_objs[0]._rpsl_object.render_rpsl_text() + == textwrap.dedent( + """ route: 192.0.2.0/24 descr: RPKI ROA for 192.0.2.0/24 / AS64496 remarks: This AS64496 route object represents routing data retrieved @@ -158,17 +207,20 @@ def test_valid_process(self, monkeypatch, mock_scopefilter): max-length: 26 origin: AS64496 source: RPKI # Trust Anchor: APNIC RPKI Root - """).strip() + '\n' + """ + 
).strip() + + "\n" + ) def test_invalid_rpki_json(self, monkeypatch, mock_scopefilter): mock_dh = Mock(spec=DatabaseHandler) with pytest.raises(ROAParserException) as rpe: - ROADataImporter('invalid', None, mock_dh) + ROADataImporter("invalid", None, mock_dh) - assert 'Unable to parse ROA input: invalid JSON: Expected object or value' in str(rpe.value) + assert "Unable to parse ROA input: invalid JSON: Expected object or value" in str(rpe.value) - data = ujson.dumps({'invalid root': 42}) + data = ujson.dumps({"invalid root": 42}) with pytest.raises(ROAParserException) as rpe: ROADataImporter(data, None, mock_dh) assert 'Unable to parse ROA input: root key "roas" not found' in str(rpe.value) @@ -178,64 +230,48 @@ def test_invalid_rpki_json(self, monkeypatch, mock_scopefilter): def test_invalid_data_in_roa(self, monkeypatch, mock_scopefilter): mock_dh = Mock(spec=DatabaseHandler) - data = ujson.dumps({ - "roas": [{ - "asn": "AS64496", - "prefix": "192.0.2.999/24", - "maxLength": 26, - "ta": "APNIC RPKI Root" - }] - }) + data = ujson.dumps( + { + "roas": [ + {"asn": "AS64496", "prefix": "192.0.2.999/24", "maxLength": 26, "ta": "APNIC RPKI Root"} + ] + } + ) with pytest.raises(ROAParserException) as rpe: ROADataImporter(data, None, mock_dh) - assert "Invalid value in ROA or SLURM: '192.0.2.999': single byte must be 0 <= byte < 256" in str(rpe.value) - - data = ujson.dumps({ - "roas": [{ - "asn": "ASx", - "prefix": "192.0.2.0/24", - "maxLength": 24, - "ta": "APNIC RPKI Root" - }] - }) + assert "Invalid value in ROA or SLURM: '192.0.2.999': single byte must be 0 <= byte < 256" in str( + rpe.value + ) + + data = ujson.dumps( + {"roas": [{"asn": "ASx", "prefix": "192.0.2.0/24", "maxLength": 24, "ta": "APNIC RPKI Root"}]} + ) with pytest.raises(ROAParserException) as rpe: ROADataImporter(data, None, mock_dh) - assert 'Invalid AS number ASX: number part is not numeric' in str(rpe.value) - - data = ujson.dumps({ - "roas": [{ - "prefix": "192.0.2.0/24", - "maxLength": 24, - "ta": "APNIC RPKI Root" - }] - }) + assert "Invalid AS number ASX: number part is not numeric" in str(rpe.value) + + data = ujson.dumps({"roas": [{"prefix": "192.0.2.0/24", "maxLength": 24, "ta": "APNIC RPKI Root"}]}) with pytest.raises(ROAParserException) as rpe: ROADataImporter(data, None, mock_dh) assert "Unable to parse ROA record: missing key 'asn'" in str(rpe.value) - data = ujson.dumps({ - "roas": [{ - "asn": "AS64496", - "prefix": "192.0.2.0/24", - "maxLength": 22, - "ta": "APNIC RPKI Root" - }] - }) + data = ujson.dumps( + {"roas": [{"asn": "AS64496", "prefix": "192.0.2.0/24", "maxLength": 22, "ta": "APNIC RPKI Root"}]} + ) with pytest.raises(ROAParserException) as rpe: ROADataImporter(data, None, mock_dh) - assert 'Invalid ROA: prefix size 24 is smaller than max length 22' in str(rpe.value) - - data = ujson.dumps({ - "roas": [{ - "asn": "AS64496", - "prefix": "192.0.2.0/24", - "maxLength": 'xx', - "ta": "APNIC RPKI Root" - }] - }) + assert "Invalid ROA: prefix size 24 is smaller than max length 22" in str(rpe.value) + + data = ujson.dumps( + { + "roas": [ + {"asn": "AS64496", "prefix": "192.0.2.0/24", "maxLength": "xx", "ta": "APNIC RPKI Root"} + ] + } + ) with pytest.raises(ROAParserException) as rpe: ROADataImporter(data, None, mock_dh) - assert 'xx' in str(rpe.value) + assert "xx" in str(rpe.value) assert flatten_mock_calls(mock_dh) == [] @@ -245,4 +281,4 @@ def test_invalid_slurm_version(self, monkeypatch, mock_scopefilter): with pytest.raises(ROAParserException) as rpe: ROADataImporter('{"roas": []}', 
'{"slurmVersion": 2}', mock_dh) - assert 'SLURM data has invalid version: 2' in str(rpe.value) + assert "SLURM data has invalid version: 2" in str(rpe.value) diff --git a/irrd/rpki/tests/test_notifications.py b/irrd/rpki/tests/test_notifications.py index 568d44bc4..a37bdbb86 100644 --- a/irrd/rpki/tests/test_notifications.py +++ b/irrd/rpki/tests/test_notifications.py @@ -1,65 +1,91 @@ -# flake8: noqa: W293 import textwrap from unittest.mock import Mock -from ..notifications import notify_rpki_invalid_owners -from irrd.utils.test_utils import flatten_mock_calls from irrd.storage.database_handler import DatabaseHandler from irrd.utils.rpsl_samples import SAMPLE_ROUTE, SAMPLE_ROUTE6 -from ..status import RPKIStatus +from irrd.utils.test_utils import flatten_mock_calls + from ...storage.queries import RPSLDatabaseQuery from ...utils.email import send_email +from ..notifications import notify_rpki_invalid_owners +from ..status import RPKIStatus class TestNotifyRPKIInvalidOwners: def test_notify_regular(self, monkeypatch, config_override): - config_override({ - 'sources': {'TEST': {'authoritative': True}}, - 'rpki': {'notify_invalid_enabled': True}, - }) + config_override( + { + "sources": {"TEST": {"authoritative": True}}, + "rpki": {"notify_invalid_enabled": True}, + } + ) mock_dh = Mock(spec=DatabaseHandler) mock_dq = Mock(spec=RPSLDatabaseQuery) - monkeypatch.setattr('irrd.rpki.notifications.RPSLDatabaseQuery', lambda columns: mock_dq) + monkeypatch.setattr("irrd.rpki.notifications.RPSLDatabaseQuery", lambda columns: mock_dq) mock_email = Mock(spec=send_email) - monkeypatch.setattr('irrd.rpki.notifications.send_email', mock_email) + monkeypatch.setattr("irrd.rpki.notifications.send_email", mock_email) rpsl_dicts_now_invalid = [ - {'source': 'TEST', 'object_text': SAMPLE_ROUTE + 'mnt-by: DOESNOTEXIST-MNT\nMISSING-DATA-MNT\n', 'rpki_status': RPKIStatus.invalid}, - {'source': 'TEST', 'object_text': SAMPLE_ROUTE6, 'rpki_status': RPKIStatus.valid}, # should be ignored - {'source': 'TEST2', 'object_text': SAMPLE_ROUTE6, 'rpki_status': RPKIStatus.invalid}, # should be ignored + { + "source": "TEST", + "object_text": SAMPLE_ROUTE + "mnt-by: DOESNOTEXIST-MNT\nMISSING-DATA-MNT\n", + "rpki_status": RPKIStatus.invalid, + }, + { + "source": "TEST", + "object_text": SAMPLE_ROUTE6, + "rpki_status": RPKIStatus.valid, + }, # should be ignored + { + "source": "TEST2", + "object_text": SAMPLE_ROUTE6, + "rpki_status": RPKIStatus.invalid, + }, # should be ignored ] - query_results = iter([ + query_results = iter( [ - {'rpsl_pk': 'TEST-MNT', 'parsed_data': { - 'mnt-nfy': ['mnt-nfy@example.com'], - 'tech-c': ['PERSON-TEST', 'DOESNOTEXIST-TEST'] - }}, - {'rpsl_pk': 'MISSING-DATA-MNT', 'parsed_data': {}}, - ], - [ - {'rpsl_pk': 'PERSON-TEST', 'parsed_data': {'e-mail': ['person@xample.com', 'person2@example.com']}}, - {'rpsl_pk': 'IGNORED-TEST', 'parsed_data': {'e-mail': ['ignored@xample.com']}}, - ], - ]) + [ + { + "rpsl_pk": "TEST-MNT", + "parsed_data": { + "mnt-nfy": ["mnt-nfy@example.com"], + "tech-c": ["PERSON-TEST", "DOESNOTEXIST-TEST"], + }, + }, + {"rpsl_pk": "MISSING-DATA-MNT", "parsed_data": {}}, + ], + [ + { + "rpsl_pk": "PERSON-TEST", + "parsed_data": {"e-mail": ["person@xample.com", "person2@example.com"]}, + }, + {"rpsl_pk": "IGNORED-TEST", "parsed_data": {"e-mail": ["ignored@xample.com"]}}, + ], + ] + ) mock_dh.execute_query = lambda q: next(query_results) notified = notify_rpki_invalid_owners(mock_dh, rpsl_dicts_now_invalid) assert notified == 3 assert flatten_mock_calls(mock_dq) == [ - ['sources', 
(['TEST'],), {}], - ['rpsl_pks', ({'TEST-MNT', 'DOESNOTEXIST-MNT'},), {}], - ['object_classes', (['mntner'],), {}], - ['sources', (['TEST'],), {}], - ['rpsl_pks', ({'PERSON-TEST', 'DOESNOTEXIST-TEST'},), {}], - ['object_classes', (['role', 'person'],), {}]] + ["sources", (["TEST"],), {}], + ["rpsl_pks", ({"TEST-MNT", "DOESNOTEXIST-MNT"},), {}], + ["object_classes", (["mntner"],), {}], + ["sources", (["TEST"],), {}], + ["rpsl_pks", ({"PERSON-TEST", "DOESNOTEXIST-TEST"},), {}], + ["object_classes", (["role", "person"],), {}], + ] assert len(mock_email.mock_calls) == 3 actual_recipients = {call[1][0] for call in mock_email.mock_calls} - expected_recipients = {'person@xample.com', 'person2@example.com', 'mnt-nfy@example.com'} + expected_recipients = {"person@xample.com", "person2@example.com", "mnt-nfy@example.com"} assert actual_recipients == expected_recipients - assert mock_email.mock_calls[0][1][1] == 'route(6) objects in TEST marked RPKI invalid' - assert mock_email.mock_calls[0][1][2] == textwrap.dedent(""" + assert mock_email.mock_calls[0][1][1] == "route(6) objects in TEST marked RPKI invalid" + assert ( + mock_email.mock_calls[0][1][2] + == textwrap.dedent( + """ This is to notify that 1 route(6) objects for which you are a contact have been marked as RPKI invalid. This concerns objects in the TEST database. @@ -98,19 +124,23 @@ def test_notify_regular(self, monkeypatch, config_override): mnt-by: DOESNOTEXIST-MNT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - """).strip() + """ + ).strip() + ) def test_notify_disabled(self, monkeypatch, config_override): - config_override({ - 'sources': {'TEST': {'authoritative': True}}, - 'rpki': {'notify_invalid_enabled': False}, - }) + config_override( + { + "sources": {"TEST": {"authoritative": True}}, + "rpki": {"notify_invalid_enabled": False}, + } + ) mock_dh = Mock(spec=DatabaseHandler) mock_email = Mock() - monkeypatch.setattr('irrd.rpki.notifications.send_email', mock_email) + monkeypatch.setattr("irrd.rpki.notifications.send_email", mock_email) rpsl_dicts_now_invalid = [ - {'source': 'TEST', 'object_text': SAMPLE_ROUTE6, 'rpki_status': RPKIStatus.invalid}, + {"source": "TEST", "object_text": SAMPLE_ROUTE6, "rpki_status": RPKIStatus.invalid}, ] notified = notify_rpki_invalid_owners(mock_dh, rpsl_dicts_now_invalid) @@ -118,17 +148,19 @@ def test_notify_disabled(self, monkeypatch, config_override): assert len(mock_email.mock_calls) == 0 def test_notify_no_relevant_objects(self, monkeypatch, config_override): - config_override({ - 'sources': {'TEST': {'authoritative': True}}, - 'rpki': {'notify_invalid_enabled': True}, - }) + config_override( + { + "sources": {"TEST": {"authoritative": True}}, + "rpki": {"notify_invalid_enabled": True}, + } + ) mock_dh = Mock(spec=DatabaseHandler) mock_email = Mock() - monkeypatch.setattr('irrd.rpki.notifications.send_email', mock_email) + monkeypatch.setattr("irrd.rpki.notifications.send_email", mock_email) rpsl_dicts_now_invalid = [ # should be ignored - {'source': 'TEST2', 'object_text': SAMPLE_ROUTE6, 'rpki_status': RPKIStatus.invalid}, + {"source": "TEST2", "object_text": SAMPLE_ROUTE6, "rpki_status": RPKIStatus.invalid}, ] notified = notify_rpki_invalid_owners(mock_dh, rpsl_dicts_now_invalid) diff --git a/irrd/rpki/tests/test_validators.py b/irrd/rpki/tests/test_validators.py index 7057cb9cb..aab2e7e4a 100644 --- a/irrd/rpki/tests/test_validators.py +++ b/irrd/rpki/tests/test_validators.py @@ -1,10 +1,12 @@ -from IPy import IP from unittest.mock import Mock +from IPy import IP + from irrd.conf 
import RPKI_IRR_PSEUDO_SOURCE from irrd.storage.database_handler import DatabaseHandler -from irrd.storage.queries import RPSLDatabaseQuery, ROADatabaseObjectQuery +from irrd.storage.queries import ROADatabaseObjectQuery, RPSLDatabaseQuery from irrd.utils.test_utils import flatten_mock_calls + from ..importer import ROA from ..status import RPKIStatus from ..validators import BulkRouteROAValidator, SingleRouteROAValidator @@ -12,317 +14,357 @@ class TestBulkRouteROAValidator: def test_validate_routes_from_roa_objs(self, monkeypatch, config_override): - config_override({ - 'sources': {'TEST1': {}, 'TEST2': {}, RPKI_IRR_PSEUDO_SOURCE: {}, - 'SOURCE-EXCLUDED': {'rpki_excluded': True}} - }) + config_override( + { + "sources": { + "TEST1": {}, + "TEST2": {}, + RPKI_IRR_PSEUDO_SOURCE: {}, + "SOURCE-EXCLUDED": {"rpki_excluded": True}, + } + } + ) mock_dh = Mock(spec=DatabaseHandler) mock_dq = Mock(spec=RPSLDatabaseQuery) - monkeypatch.setattr('irrd.rpki.validators.RPSLDatabaseQuery', - lambda column_names, enable_ordering: mock_dq) + monkeypatch.setattr( + "irrd.rpki.validators.RPSLDatabaseQuery", lambda column_names, enable_ordering: mock_dq + ) - mock_query_result = iter([ + mock_query_result = iter( [ - { - 'pk': 'pk_route_v4_d0_l24', - 'rpsl_pk': 'pk_route_v4_d0_l24', - 'ip_version': 4, - 'ip_first': '192.0.2.0', - 'prefix_length': 24, - 'asn_first': 65546, - 'rpki_status': RPKIStatus.not_found, - 'source': 'TEST1', - }, - { - 'pk': 'pk_route_v4_d0_l25', - 'rpsl_pk': 'pk_route_v4_d0_l25', - 'ip_version': 4, - 'ip_first': '192.0.2.0', - 'prefix_length': 25, - 'asn_first': 65546, - 'rpki_status': RPKIStatus.not_found, - 'source': 'TEST1', - }, - { - # This route is valid, but as the state is already valid, - # it should not be included in the response. - 'pk': 'pk_route_v4_d0_l28', - 'rpsl_pk': 'pk_route_v4_d0_l28', - 'ip_version': 4, - 'ip_first': '192.0.2.0', - 'prefix_length': 27, - 'asn_first': 65546, - 'rpki_status': RPKIStatus.valid, - 'source': 'TEST1', - }, - { - 'pk': 'pk_route_v4_d64_l32', - 'rpsl_pk': 'pk_route_v4_d64_l32', - 'ip_version': 4, - 'ip_first': '192.0.2.64', - 'prefix_length': 32, - 'asn_first': 65546, - 'rpki_status': RPKIStatus.valid, - 'source': 'TEST1', - }, - { - 'pk': 'pk_route_v4_d128_l25', - 'rpsl_pk': 'pk_route_v4_d128_l25', - 'ip_version': 4, - 'ip_first': '192.0.2.128', - 'prefix_length': 25, - 'asn_first': 65547, - 'rpki_status': RPKIStatus.valid, - 'source': 'TEST1', - }, - { - # RPKI invalid, but should be ignored. - 'pk': 'pk_route_v4_d128_l26_rpki', - 'rpsl_pk': 'pk_route_v4_d128_l26', - 'ip_version': 4, - 'ip_first': '192.0.2.128', - 'prefix_length': 26, - 'asn_first': 65547, - 'rpki_status': RPKIStatus.invalid, - 'source': RPKI_IRR_PSEUDO_SOURCE, - }, - { - # RPKI invalid, but should be not_found because of source. 
- 'pk': 'pk_route_v4_d128_l26_excluded', - 'rpsl_pk': 'pk_route_v4_d128_l26_excluded', - 'ip_version': 4, - 'ip_first': '192.0.2.128', - 'prefix_length': 26, - 'asn_first': 65547, - 'rpki_status': RPKIStatus.valid, - 'source': 'SOURCE-EXCLUDED', - }, - { - 'pk': 'pk_route_v6', - 'rpsl_pk': 'pk_route_v6', - 'ip_version': 6, - 'ip_first': '2001:db8::', - 'prefix_length': 32, - 'asn_first': 65547, - 'rpki_status': RPKIStatus.invalid, - 'source': 'TEST1', - }, - { - # Should not match any ROA - ROAs for a subset - # exist, but those should not be included - 'pk': 'pk_route_v4_no_roa', - 'rpsl_pk': 'pk_route_v4_no_roa', - 'ip_version': 4, - 'ip_first': '192.0.2.0', - 'prefix_length': 23, - 'asn_first': 65549, - 'rpki_status': RPKIStatus.valid, - 'source': 'TEST1', - }, - { - 'pk': 'pk_route_v4_roa_as0', - 'rpsl_pk': 'pk_route_v4_roa_as0', - 'ip_version': 4, - 'ip_first': '203.0.113.1', - 'prefix_length': 32, - 'asn_first': 65547, - 'rpki_status': RPKIStatus.not_found, - 'source': 'TEST1', - }, - ], [ - { - 'pk': 'pk_route_v4_d0_l24', - 'object_text': 'object text', - 'object_class': 'route', - }, - { - 'pk': 'pk_route_v4_d0_l25', - 'object_text': 'object text', - 'object_class': 'route', - }, - { - 'pk': 'pk_route_v4_d64_l32', - 'object_text': 'object text', - 'object_class': 'route', - }, - { - 'pk': 'pk_route_v4_d128_l25', - 'object_text': 'object text', - 'object_class': 'route', - }, - { - 'pk': 'pk_route_v4_d128_l26_rpki', - 'object_text': 'object text', - 'object_class': 'route', - }, - { - 'pk': 'pk_route_v4_d128_l26_excluded', - 'object_text': 'object text', - 'object_class': 'route', - }, - { - 'pk': 'pk_route_v6', - 'object_text': 'object text', - 'object_class': 'route', - }, - { - 'pk': 'pk_route_v4_no_roa', - 'object_text': 'object text', - 'object_class': 'route', - }, - { - 'pk': 'pk_route_v4_roa_as0', - 'object_text': 'object text', - 'object_class': 'route', - }, + [ + { + "pk": "pk_route_v4_d0_l24", + "rpsl_pk": "pk_route_v4_d0_l24", + "ip_version": 4, + "ip_first": "192.0.2.0", + "prefix_length": 24, + "asn_first": 65546, + "rpki_status": RPKIStatus.not_found, + "source": "TEST1", + }, + { + "pk": "pk_route_v4_d0_l25", + "rpsl_pk": "pk_route_v4_d0_l25", + "ip_version": 4, + "ip_first": "192.0.2.0", + "prefix_length": 25, + "asn_first": 65546, + "rpki_status": RPKIStatus.not_found, + "source": "TEST1", + }, + { + # This route is valid, but as the state is already valid, + # it should not be included in the response. + "pk": "pk_route_v4_d0_l28", + "rpsl_pk": "pk_route_v4_d0_l28", + "ip_version": 4, + "ip_first": "192.0.2.0", + "prefix_length": 27, + "asn_first": 65546, + "rpki_status": RPKIStatus.valid, + "source": "TEST1", + }, + { + "pk": "pk_route_v4_d64_l32", + "rpsl_pk": "pk_route_v4_d64_l32", + "ip_version": 4, + "ip_first": "192.0.2.64", + "prefix_length": 32, + "asn_first": 65546, + "rpki_status": RPKIStatus.valid, + "source": "TEST1", + }, + { + "pk": "pk_route_v4_d128_l25", + "rpsl_pk": "pk_route_v4_d128_l25", + "ip_version": 4, + "ip_first": "192.0.2.128", + "prefix_length": 25, + "asn_first": 65547, + "rpki_status": RPKIStatus.valid, + "source": "TEST1", + }, + { + # RPKI invalid, but should be ignored. + "pk": "pk_route_v4_d128_l26_rpki", + "rpsl_pk": "pk_route_v4_d128_l26", + "ip_version": 4, + "ip_first": "192.0.2.128", + "prefix_length": 26, + "asn_first": 65547, + "rpki_status": RPKIStatus.invalid, + "source": RPKI_IRR_PSEUDO_SOURCE, + }, + { + # RPKI invalid, but should be not_found because of source. 
+ "pk": "pk_route_v4_d128_l26_excluded", + "rpsl_pk": "pk_route_v4_d128_l26_excluded", + "ip_version": 4, + "ip_first": "192.0.2.128", + "prefix_length": 26, + "asn_first": 65547, + "rpki_status": RPKIStatus.valid, + "source": "SOURCE-EXCLUDED", + }, + { + "pk": "pk_route_v6", + "rpsl_pk": "pk_route_v6", + "ip_version": 6, + "ip_first": "2001:db8::", + "prefix_length": 32, + "asn_first": 65547, + "rpki_status": RPKIStatus.invalid, + "source": "TEST1", + }, + { + # Should not match any ROA - ROAs for a subset + # exist, but those should not be included + "pk": "pk_route_v4_no_roa", + "rpsl_pk": "pk_route_v4_no_roa", + "ip_version": 4, + "ip_first": "192.0.2.0", + "prefix_length": 23, + "asn_first": 65549, + "rpki_status": RPKIStatus.valid, + "source": "TEST1", + }, + { + "pk": "pk_route_v4_roa_as0", + "rpsl_pk": "pk_route_v4_roa_as0", + "ip_version": 4, + "ip_first": "203.0.113.1", + "prefix_length": 32, + "asn_first": 65547, + "rpki_status": RPKIStatus.not_found, + "source": "TEST1", + }, + ], + [ + { + "pk": "pk_route_v4_d0_l24", + "object_text": "object text", + "object_class": "route", + }, + { + "pk": "pk_route_v4_d0_l25", + "object_text": "object text", + "object_class": "route", + }, + { + "pk": "pk_route_v4_d64_l32", + "object_text": "object text", + "object_class": "route", + }, + { + "pk": "pk_route_v4_d128_l25", + "object_text": "object text", + "object_class": "route", + }, + { + "pk": "pk_route_v4_d128_l26_rpki", + "object_text": "object text", + "object_class": "route", + }, + { + "pk": "pk_route_v4_d128_l26_excluded", + "object_text": "object text", + "object_class": "route", + }, + { + "pk": "pk_route_v6", + "object_text": "object text", + "object_class": "route", + }, + { + "pk": "pk_route_v4_no_roa", + "object_text": "object text", + "object_class": "route", + }, + { + "pk": "pk_route_v4_roa_as0", + "object_text": "object text", + "object_class": "route", + }, + ], ] - ]) + ) mock_dh.execute_query = lambda query: next(mock_query_result) roas = [ # Valid for pk_route_v4_d0_l25 and pk_route_v4_d0_l24 # - the others have incorrect origin or are too small. - ROA(IP('192.0.2.0/24'), 65546, '28', 'TEST TA'), + ROA(IP("192.0.2.0/24"), 65546, "28", "TEST TA"), # Matches the origin of pk_route_v4_d128_l25, # but not max_length. - ROA(IP('192.0.2.0/24'), 65547, '24', 'TEST TA'), + ROA(IP("192.0.2.0/24"), 65547, "24", "TEST TA"), # Matches pk_route_v6, but not max_length. - ROA(IP('2001:db8::/30'), 65547, '30', 'TEST TA'), + ROA(IP("2001:db8::/30"), 65547, "30", "TEST TA"), # Matches pk_route_v6, but not on origin. 
- ROA(IP('2001:db8::/32'), 65548, '32', 'TEST TA'), + ROA(IP("2001:db8::/32"), 65548, "32", "TEST TA"), # Matches pk_route_v6 - ROA(IP('2001:db8::/32'), 65547, '64', 'TEST TA'), + ROA(IP("2001:db8::/32"), 65547, "64", "TEST TA"), # Matches no routes, no effect - ROA(IP('203.0.113.0/32'), 65547, '32', 'TEST TA'), + ROA(IP("203.0.113.0/32"), 65547, "32", "TEST TA"), # AS0 can not match - ROA(IP('203.0.113.1/32'), 0, '32', 'TEST TA'), + ROA(IP("203.0.113.1/32"), 0, "32", "TEST TA"), ] - result = BulkRouteROAValidator(mock_dh, roas).validate_all_routes(sources=['TEST1']) + result = BulkRouteROAValidator(mock_dh, roas).validate_all_routes(sources=["TEST1"]) new_valid_objs, new_invalid_objs, new_unknown_objs = result - assert {o['rpsl_pk'] for o in new_valid_objs} == {'pk_route_v6', 'pk_route_v4_d0_l25', 'pk_route_v4_d0_l24'} - assert [o['object_class'] for o in new_valid_objs] == ['route', 'route', 'route'] - assert [o['object_text'] for o in new_valid_objs] == ['object text', 'object text', 'object text'] - assert {o['rpsl_pk'] for o in new_invalid_objs} == {'pk_route_v4_d64_l32', 'pk_route_v4_d128_l25', 'pk_route_v4_roa_as0'} - assert {o['rpsl_pk'] for o in new_unknown_objs} == {'pk_route_v4_no_roa', 'pk_route_v4_d128_l26_excluded'} + assert {o["rpsl_pk"] for o in new_valid_objs} == { + "pk_route_v6", + "pk_route_v4_d0_l25", + "pk_route_v4_d0_l24", + } + assert [o["object_class"] for o in new_valid_objs] == ["route", "route", "route"] + assert [o["object_text"] for o in new_valid_objs] == ["object text", "object text", "object text"] + assert {o["rpsl_pk"] for o in new_invalid_objs} == { + "pk_route_v4_d64_l32", + "pk_route_v4_d128_l25", + "pk_route_v4_roa_as0", + } + assert {o["rpsl_pk"] for o in new_unknown_objs} == { + "pk_route_v4_no_roa", + "pk_route_v4_d128_l26_excluded", + } assert flatten_mock_calls(mock_dq) == [ - ['object_classes', (['route', 'route6'],), {}], - ['sources', (['TEST1'],), {}], - ['pks', (['pk_route_v4_d0_l24', 'pk_route_v4_d0_l25', 'pk_route_v6', 'pk_route_v4_d64_l32', 'pk_route_v4_d128_l25', 'pk_route_v4_roa_as0', 'pk_route_v4_d128_l26_excluded', 'pk_route_v4_no_roa'], ), {}], + ["object_classes", (["route", "route6"],), {}], + ["sources", (["TEST1"],), {}], + [ + "pks", + ( + [ + "pk_route_v4_d0_l24", + "pk_route_v4_d0_l25", + "pk_route_v6", + "pk_route_v4_d64_l32", + "pk_route_v4_d128_l25", + "pk_route_v4_roa_as0", + "pk_route_v4_d128_l26_excluded", + "pk_route_v4_no_roa", + ], + ), + {}, + ], ] def test_validate_routes_with_roa_from_database(self, monkeypatch, config_override): - config_override({ - 'sources': {'TEST1': {}, 'TEST2': {}, RPKI_IRR_PSEUDO_SOURCE: {}} - }) + config_override({"sources": {"TEST1": {}, "TEST2": {}, RPKI_IRR_PSEUDO_SOURCE: {}}}) mock_dh = Mock(spec=DatabaseHandler) mock_dq = Mock(spec=RPSLDatabaseQuery) - monkeypatch.setattr('irrd.rpki.validators.RPSLDatabaseQuery', - lambda column_names, enable_ordering: mock_dq) + monkeypatch.setattr( + "irrd.rpki.validators.RPSLDatabaseQuery", lambda column_names, enable_ordering: mock_dq + ) mock_rq = Mock(spec=ROADatabaseObjectQuery) - monkeypatch.setattr('irrd.rpki.validators.ROADatabaseObjectQuery', - lambda: mock_rq) + monkeypatch.setattr("irrd.rpki.validators.ROADatabaseObjectQuery", lambda: mock_rq) - mock_query_result = iter([ - [ # ROAs: - { - 'prefix': '192.0.2.0/24', - 'asn': 65546, - 'max_length': 25, - 'ip_version': 4, - }, - { - 'prefix': '192.0.2.0/24', - 'asn': 65547, - 'max_length': 24, - 'ip_version': 4, - }, - ], [ # RPSL objects: - { - 'pk': 'pk1', - 'rpsl_pk': 
'pk_route_v4_d0_l25', - 'ip_version': 4, - 'ip_first': '192.0.2.0', - 'prefix_length': 25, - 'asn_first': 65546, - 'rpki_status': RPKIStatus.not_found, - 'source': 'TEST1', - }, - ], [ - { - 'pk': 'pk1', - 'object_class': 'route', - 'object_text': 'object text', - }, + mock_query_result = iter( + [ + [ # ROAs: + { + "prefix": "192.0.2.0/24", + "asn": 65546, + "max_length": 25, + "ip_version": 4, + }, + { + "prefix": "192.0.2.0/24", + "asn": 65547, + "max_length": 24, + "ip_version": 4, + }, + ], + [ # RPSL objects: + { + "pk": "pk1", + "rpsl_pk": "pk_route_v4_d0_l25", + "ip_version": 4, + "ip_first": "192.0.2.0", + "prefix_length": 25, + "asn_first": 65546, + "rpki_status": RPKIStatus.not_found, + "source": "TEST1", + }, + ], + [ + { + "pk": "pk1", + "object_class": "route", + "object_text": "object text", + }, + ], ] - ]) + ) mock_dh.execute_query = lambda query: next(mock_query_result) - result = BulkRouteROAValidator(mock_dh).validate_all_routes(sources=['TEST1']) + result = BulkRouteROAValidator(mock_dh).validate_all_routes(sources=["TEST1"]) new_valid_pks, new_invalid_pks, new_unknown_pks = result - assert {o['rpsl_pk'] for o in new_valid_pks} == {'pk_route_v4_d0_l25'} - assert {o['object_text'] for o in new_valid_pks} == {'object text'} + assert {o["rpsl_pk"] for o in new_valid_pks} == {"pk_route_v4_d0_l25"} + assert {o["object_text"] for o in new_valid_pks} == {"object text"} assert new_invalid_pks == list() assert new_unknown_pks == list() assert flatten_mock_calls(mock_dq) == [ - ['object_classes', (['route', 'route6'],), {}], - ['sources', (['TEST1'],), {}], - ['pks', (['pk1'],), {}], + ["object_classes", (["route", "route6"],), {}], + ["sources", (["TEST1"],), {}], + ["pks", (["pk1"],), {}], ] assert flatten_mock_calls(mock_rq) == [] # No filters applied class TestSingleRouteROAValidator: def test_validator_normal_roa(self, monkeypatch, config_override): - config_override({ - 'sources': {'SOURCE-EXCLUDED': {'rpki_excluded': True}} - }) + config_override({"sources": {"SOURCE-EXCLUDED": {"rpki_excluded": True}}}) mock_dh = Mock(spec=DatabaseHandler) mock_rq = Mock(spec=ROADatabaseObjectQuery) - monkeypatch.setattr('irrd.rpki.validators.ROADatabaseObjectQuery', lambda: mock_rq) + monkeypatch.setattr("irrd.rpki.validators.ROADatabaseObjectQuery", lambda: mock_rq) - roa_response = [{ - 'asn': 65548, - 'max_length': 25, - }] + roa_response = [ + { + "asn": 65548, + "max_length": 25, + } + ] mock_dh.execute_query = lambda q: roa_response validator = SingleRouteROAValidator(mock_dh) - assert validator.validate_route(IP('192.0.2.0/24'), 65548, 'TEST1') == RPKIStatus.valid - assert validator.validate_route(IP('192.0.2.0/24'), 65548, 'SOURCE-EXCLUDED') == RPKIStatus.not_found - assert validator.validate_route(IP('192.0.2.0/24'), 65549, 'TEST1') == RPKIStatus.invalid - assert validator.validate_route(IP('192.0.2.0/24'), 65549, 'SOURCE-EXCLUDED') == RPKIStatus.not_found - assert validator.validate_route(IP('192.0.2.0/26'), 65548, 'TEST1') == RPKIStatus.invalid + assert validator.validate_route(IP("192.0.2.0/24"), 65548, "TEST1") == RPKIStatus.valid + assert validator.validate_route(IP("192.0.2.0/24"), 65548, "SOURCE-EXCLUDED") == RPKIStatus.not_found + assert validator.validate_route(IP("192.0.2.0/24"), 65549, "TEST1") == RPKIStatus.invalid + assert validator.validate_route(IP("192.0.2.0/24"), 65549, "SOURCE-EXCLUDED") == RPKIStatus.not_found + assert validator.validate_route(IP("192.0.2.0/26"), 65548, "TEST1") == RPKIStatus.invalid assert flatten_mock_calls(mock_rq) == [ - 
['ip_less_specific_or_exact', (IP('192.0.2.0/24'),), {}], - ['ip_less_specific_or_exact', (IP('192.0.2.0/24'),), {}], - ['ip_less_specific_or_exact', (IP('192.0.2.0/26'),), {}], + ["ip_less_specific_or_exact", (IP("192.0.2.0/24"),), {}], + ["ip_less_specific_or_exact", (IP("192.0.2.0/24"),), {}], + ["ip_less_specific_or_exact", (IP("192.0.2.0/26"),), {}], ] def test_validator_as0_roa(self, monkeypatch): mock_dh = Mock(spec=DatabaseHandler) mock_rq = Mock(spec=ROADatabaseObjectQuery) - monkeypatch.setattr('irrd.rpki.validators.ROADatabaseObjectQuery', lambda: mock_rq) + monkeypatch.setattr("irrd.rpki.validators.ROADatabaseObjectQuery", lambda: mock_rq) - roa_response = [{ - 'asn': 0, - 'max_length': 25, - }] + roa_response = [ + { + "asn": 0, + "max_length": 25, + } + ] mock_dh.execute_query = lambda q: roa_response validator = SingleRouteROAValidator(mock_dh) - assert validator.validate_route(IP('192.0.2.0/24'), 65548, 'TEST1') == RPKIStatus.invalid + assert validator.validate_route(IP("192.0.2.0/24"), 65548, "TEST1") == RPKIStatus.invalid def test_validator_no_matching_roa(self, monkeypatch): mock_dh = Mock(spec=DatabaseHandler) mock_rq = Mock(spec=ROADatabaseObjectQuery) - monkeypatch.setattr('irrd.rpki.validators.ROADatabaseObjectQuery', lambda: mock_rq) + monkeypatch.setattr("irrd.rpki.validators.ROADatabaseObjectQuery", lambda: mock_rq) mock_dh.execute_query = lambda q: [] validator = SingleRouteROAValidator(mock_dh) - assert validator.validate_route(IP('192.0.2.0/24'), 65548, 'TEST1') == RPKIStatus.not_found - assert validator.validate_route(IP('192.0.2.0/24'), 65549, 'TEST1') == RPKIStatus.not_found - assert validator.validate_route(IP('192.0.2.0/26'), 65548, 'TEST1') == RPKIStatus.not_found + assert validator.validate_route(IP("192.0.2.0/24"), 65548, "TEST1") == RPKIStatus.not_found + assert validator.validate_route(IP("192.0.2.0/24"), 65549, "TEST1") == RPKIStatus.not_found + assert validator.validate_route(IP("192.0.2.0/26"), 65548, "TEST1") == RPKIStatus.not_found diff --git a/irrd/rpki/validators.py b/irrd/rpki/validators.py index 321aca387..bd5b91ac1 100644 --- a/irrd/rpki/validators.py +++ b/irrd/rpki/validators.py @@ -1,14 +1,15 @@ -import datrie -from collections import defaultdict - import codecs import socket +from collections import defaultdict +from typing import Dict, List, Optional, Tuple + +import datrie from IPy import IP -from typing import Optional, List, Tuple, Dict from irrd.conf import RPKI_IRR_PSEUDO_SOURCE, get_setting from irrd.storage.database_handler import DatabaseHandler -from irrd.storage.queries import RPSLDatabaseQuery, ROADatabaseObjectQuery +from irrd.storage.queries import ROADatabaseObjectQuery, RPSLDatabaseQuery + from .importer import ROA from .status import RPKIStatus @@ -58,19 +59,20 @@ def __init__(self, dh: DatabaseHandler, roas: Optional[List[ROA]] = None): self.database_handler = dh self.excluded_sources = [] - for source, settings in get_setting('sources', {}).items(): - if settings.get('rpki_excluded'): + for source, settings in get_setting("sources", {}).items(): + if settings.get("rpki_excluded"): self.excluded_sources.append(source) - self.roa_tree4 = datrie.Trie('01') - self.roa_tree6 = datrie.Trie('01') + self.roa_tree4 = datrie.Trie("01") + self.roa_tree6 = datrie.Trie("01") if roas is None: self._build_roa_tree_from_db() else: self._build_roa_tree_from_roa_objs(roas) - def validate_all_routes(self, sources: Optional[List[str]]=None) -> \ - Tuple[List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]]: + def 
validate_all_routes( + self, sources: Optional[List[str]] = None + ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]]: """ Validate all RPSL route/route6 objects. @@ -85,10 +87,9 @@ def validate_all_routes(self, sources: Optional[List[str]]=None) -> \ Routes where their current validation status in the DB matches the new validation result, are not included in the return value. """ - columns = ['pk', 'rpsl_pk', 'ip_first', 'prefix_length', 'asn_first', 'source', - 'rpki_status'] + columns = ["pk", "rpsl_pk", "ip_first", "prefix_length", "asn_first", "source", "rpki_status"] q = RPSLDatabaseQuery(column_names=columns, enable_ordering=False) - q = q.object_classes(['route', 'route6']) + q = q.object_classes(["route", "route6"]) if sources: q = q.sources(sources) routes = self.database_handler.execute_query(q) @@ -97,27 +98,35 @@ def validate_all_routes(self, sources: Optional[List[str]]=None) -> \ for result in routes: # RPKI_IRR_PSEUDO_SOURCE objects are ROAs, and don't need validation. - if result['source'] == RPKI_IRR_PSEUDO_SOURCE: + if result["source"] == RPKI_IRR_PSEUDO_SOURCE: continue - current_status = result['rpki_status'] - result['old_status'] = current_status - new_status = self.validate_route(result['ip_first'], result['prefix_length'], - result['asn_first'], result['source']) + current_status = result["rpki_status"] + result["old_status"] = current_status + new_status = self.validate_route( + result["ip_first"], result["prefix_length"], result["asn_first"], result["source"] + ) if new_status != current_status: - result['rpki_status'] = new_status + result["rpki_status"] = new_status objs_changed[new_status].append(result) # Object text and class are only retrieved for objects with state changes - pks_to_enrich = [obj['pk'] for objs in objs_changed.values() for obj in objs] - query = RPSLDatabaseQuery(['pk', 'prefix', 'object_text', 'object_class', 'scopefilter_status', 'route_preference_status'], enable_ordering=False).pks(pks_to_enrich) - rows_per_pk = {row['pk']: row for row in self.database_handler.execute_query(query)} + pks_to_enrich = [obj["pk"] for objs in objs_changed.values() for obj in objs] + query = RPSLDatabaseQuery( + ["pk", "prefix", "object_text", "object_class", "scopefilter_status", "route_preference_status"], + enable_ordering=False, + ).pks(pks_to_enrich) + rows_per_pk = {row["pk"]: row for row in self.database_handler.execute_query(query)} for rpsl_objs in objs_changed.values(): for rpsl_obj in rpsl_objs: - rpsl_obj.update(rows_per_pk[rpsl_obj['pk']]) + rpsl_obj.update(rows_per_pk[rpsl_obj["pk"]]) - return objs_changed[RPKIStatus.valid], objs_changed[RPKIStatus.invalid], objs_changed[RPKIStatus.not_found] + return ( + objs_changed[RPKIStatus.valid], + objs_changed[RPKIStatus.invalid], + objs_changed[RPKIStatus.not_found], + ) def validate_route(self, prefix_ip: str, prefix_length: int, prefix_asn: int, source: str) -> RPKIStatus: """ @@ -149,7 +158,7 @@ def _build_roa_tree_from_roa_objs(self, roas: List[ROA]): """ for roa in roas: roa_tree = self.roa_tree6 if roa.prefix.version() == 6 else self.roa_tree4 - key = roa.prefix.strBin()[:roa.prefix.prefixlen()] + key = roa.prefix.strBin()[: roa.prefix.prefixlen()] if key in roa_tree: roa_tree[key].append((roa.prefix_str, roa.asn, roa.max_length)) else: @@ -161,14 +170,14 @@ def _build_roa_tree_from_db(self): """ roas = self.database_handler.execute_query(ROADatabaseObjectQuery()) for roa in roas: - first_ip, length = roa['prefix'].split('/') + first_ip, length = roa["prefix"].split("/") 
ip_version, ip_bin_str = self._ip_to_binary_str(first_ip) - key = ip_bin_str[:int(length)] + key = ip_bin_str[: int(length)] roa_tree = self.roa_tree6 if ip_version == 6 else self.roa_tree4 if key in roa_tree: - roa_tree[key].append((roa['prefix'], roa['asn'], roa['max_length'])) + roa_tree[key].append((roa["prefix"], roa["asn"], roa["max_length"])) else: - roa_tree[key] = [(roa['prefix'], roa['asn'], roa['max_length'])] + roa_tree[key] = [(roa["prefix"], roa["asn"], roa["max_length"])] def _ip_to_binary_str(self, ip: str) -> Tuple[int, str]: """ @@ -176,9 +185,9 @@ def _ip_to_binary_str(self, ip: str) -> Tuple[int, str]: 192.0.2.139 to 11000000000000000000001010001011 and return the IP version. """ - address_family = socket.AF_INET6 if ':' in ip else socket.AF_INET + address_family = socket.AF_INET6 if ":" in ip else socket.AF_INET ip_bin = socket.inet_pton(address_family, ip) - ip_bin_str = ''.join([BYTE_BIN[b] for b in ip_bin]) + '0' + ip_bin_str = "".join([BYTE_BIN[b] for b in ip_bin]) + "0" ip_version = 6 if address_family == socket.AF_INET6 else 4 return ip_version, ip_bin_str @@ -191,7 +200,7 @@ def validate_route(self, route: IP, asn: int, source: str) -> RPKIStatus: """ Validate a route from a particular source. """ - if get_setting(f'sources.{source}.rpki_excluded'): + if get_setting(f"sources.{source}.rpki_excluded"): return RPKIStatus.not_found query = ROADatabaseObjectQuery().ip_less_specific_or_exact(route) @@ -199,6 +208,6 @@ def validate_route(self, route: IP, asn: int, source: str) -> RPKIStatus: if not roas_covering: return RPKIStatus.not_found for roa in roas_covering: - if roa['asn'] != 0 and roa['asn'] == asn and route.prefixlen() <= roa['max_length']: + if roa["asn"] != 0 and roa["asn"] == asn and route.prefixlen() <= roa["max_length"]: return RPKIStatus.valid return RPKIStatus.invalid diff --git a/irrd/rpsl/fields.py b/irrd/rpsl/fields.py index 26c86d2fc..cc1502069 100644 --- a/irrd/rpsl/fields.py +++ b/irrd/rpsl/fields.py @@ -1,14 +1,15 @@ import datetime import re -from typing import List, Type, Optional +from typing import List, Optional, Type from urllib.parse import urlparse from IPy import IP from irrd.utils.text import clean_ip_value_error -from irrd.utils.validators import parse_as_number, ValidationError +from irrd.utils.validators import ValidationError, parse_as_number + +from .parser_state import RPSLFieldParseResult, RPSLParserMessages from .passwords import get_password_hashers -from .parser_state import RPSLParserMessages, RPSLFieldParseResult # The IPv4/IPv6 regexes are for initial screening - not full validators @@ -18,14 +19,39 @@ # This regex is not designed to catch every possible invalid variation, # but rather meant to protect against unintentional mistakes. 
# # Validate local-part @ domain | or IPv4 address | or IPv6 -re_email = re.compile(r"^[A-Z0-9$!#%&\"*+\/=?^_`{|}~\\.-]+@(([A-Z0-9\\.-]+)|(\[\d+\.\d+\.\d+\.\d+\])|(\[[A-f\d:]+\]))$", re.IGNORECASE) +re_email = re.compile( + r"^[A-Z0-9$!#%&\"*+\/=?^_`{|}~\\.-]+@(([A-Z0-9\\.-]+)|(\[\d+\.\d+\.\d+\.\d+\])|(\[[A-f\d:]+\]))$", + re.IGNORECASE, +) re_range_operator = re.compile(r"^(?P\d{1,3})-(?P\d{1,3})$|^(-)$|^(\+)$|^(?P\d{1,3})$") re_pgpkey = re.compile(r"^PGPKEY-[A-F0-9]{8}$") -re_dnsname = re.compile(r"^(([A-Z0-9]|[A-Z0-9][A-Z0-9\-]*[A-Z0-9])\.)*([A-Z0-9]|[A-Z0-9][A-Z0-9\-]*[A-Z0-9])$", re.IGNORECASE) +re_dnsname = re.compile( + r"^(([A-Z0-9]|[A-Z0-9][A-Z0-9\-]*[A-Z0-9])\.)*([A-Z0-9]|[A-Z0-9][A-Z0-9\-]*[A-Z0-9])$", re.IGNORECASE +) re_generic_name = re.compile(r"^[A-Z][A-Z0-9_-]*[A-Z0-9]$", re.IGNORECASE) -reserved_words = ["ANY", "AS-ANY", "RS_ANY", "PEERAS", "AND", "OR", "NOT", "ATOMIC", "FROM", "TO", "AT", "ACTION", - "ACCEPT", "ANNOUNCE", "EXCEPT", "REFINE", "NETWORKS", "INTO", "INBOUND", "OUTBOUND"] +reserved_words = [ + "ANY", + "AS-ANY", + "RS_ANY", + "PEERAS", + "AND", + "OR", + "NOT", + "ATOMIC", + "FROM", + "TO", + "AT", + "ACTION", + "ACCEPT", + "ANNOUNCE", + "EXCEPT", + "REFINE", + "NETWORKS", + "INTO", + "INBOUND", + "OUTBOUND", +] reserved_prefixes = ["AS-", "RS-", "RTRS-", "FLTR-", "PRNG-"] """ @@ -57,16 +83,25 @@ class RPSLTextField: by a field, other than the value. The parser only consider these extractions for primary key or lookup key fields. """ + keep_case = True extracts: List[str] = [] - def __init__(self, optional: bool=False, multiple: bool=False, primary_key: bool=False, lookup_key: bool=False) -> None: + def __init__( + self, + optional: bool = False, + multiple: bool = False, + primary_key: bool = False, + lookup_key: bool = False, + ) -> None: self.optional = optional self.multiple = multiple self.primary_key = primary_key self.lookup_key = lookup_key - def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]: + def parse( + self, value: str, messages: RPSLParserMessages, strict_validation=True + ) -> Optional[RPSLFieldParseResult]: return RPSLFieldParseResult(value) @@ -81,9 +116,12 @@ class RPSLFieldListMixin: class RPSLASNumbersField(RPSLFieldListMixin, RPSLASNumberField): pass """ - def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]: + + def parse( + self, value: str, messages: RPSLParserMessages, strict_validation=True + ) -> Optional[RPSLFieldParseResult]: parse_results = [] - for single_value in value.split(','): + for single_value in value.split(","): single_value = single_value.strip() if single_value: parse_result = super().parse(single_value, messages, strict_validation) # type: ignore @@ -91,66 +129,76 @@ def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True if not all(parse_results): return None values = [result.value for result in parse_results] - return RPSLFieldParseResult(','.join(values), values_list=values) + return RPSLFieldParseResult(",".join(values), values_list=values) class RPSLIPv4PrefixField(RPSLTextField): """Field for a single IPv4 prefix.""" - extracts = ['ip_first', 'ip_last', 'prefix', 'prefix_length'] - def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]: + extracts = ["ip_first", "ip_last", "prefix", "prefix_length"] + + def parse( + self, value: str, messages: RPSLParserMessages, strict_validation=True + ) -> 


 class RPSLIPv6PrefixField(RPSLTextField):
     """Field for a single IPv6 prefix."""
-    extracts = ['ip_first', 'ip_last', 'prefix', 'prefix_length']
-    def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]:
+
+    extracts = ["ip_first", "ip_last", "prefix", "prefix_length"]
+
+    def parse(
+        self, value: str, messages: RPSLParserMessages, strict_validation=True
+    ) -> Optional[RPSLFieldParseResult]:
         if not re_ipv6_prefix.match(value):
-            messages.error(f'Invalid address prefix: {value}')
+            messages.error(f"Invalid address prefix: {value}")
             return None

         try:
             ip = IP(value, ipversion=6)
         except ValueError as ve:
             clean_error = clean_ip_value_error(ve)
-            messages.error(f'Invalid address prefix: {value}: {clean_error}')
+            messages.error(f"Invalid address prefix: {value}: {clean_error}")
             return None

         parsed_ip_str = str(ip)
         if ip.prefixlen() == 128:
-            parsed_ip_str += '/128'
+            parsed_ip_str += "/128"
         if parsed_ip_str != value:
-            messages.info(f'Address prefix {value} was reformatted as {parsed_ip_str}')
-        return RPSLFieldParseResult(parsed_ip_str, ip_first=ip.net(), ip_last=ip.broadcast(),
-                                    prefix=ip, prefix_length=ip.prefixlen())
+            messages.info(f"Address prefix {value} was reformatted as {parsed_ip_str}")
+        return RPSLFieldParseResult(
+            parsed_ip_str, ip_first=ip.net(), ip_last=ip.broadcast(), prefix=ip, prefix_length=ip.prefixlen()
+        )


 class RPSLIPv6PrefixesField(RPSLFieldListMixin, RPSLIPv6PrefixField):
     """Field for a comma-separated list of IPv6 prefixes."""
+
     pass
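
Aside on the prefix fields above: they lean on IPy for canonicalisation and then report an info message when the canonical form differs from the input. A small standalone illustration of that behaviour, using plain IPy without IRRD's message objects (the inputs are made up):

from IPy import IP

for raw in ["2001:DB8::/32", "192.0.2.1", "192.0.2.0/24"]:
    ip = IP(raw)
    canonical = str(ip)
    # Mirror the diff's special case: IPy's str() drops the prefix length
    # for single addresses, so a host route gets its /32 restored.
    if ip.version() == 4 and ip.prefixlen() == 32:
        canonical += "/32"
    if canonical != raw:
        print(f"Address prefix {raw} was reformatted as {canonical}")

This prints a reformat notice for the first two values (case normalisation and the restored /32) and stays silent for the already-canonical /24.
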
""" - extracts = ['ip_first', 'ip_last'] - def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]: - value = value.replace(',', '') # #311, process multiline PK correctly - if '-' in value: - ip1_input, ip2_input = value.split('-', 1) + extracts = ["ip_first", "ip_last"] + + def parse( + self, value: str, messages: RPSLParserMessages, strict_validation=True + ) -> Optional[RPSLFieldParseResult]: + value = value.replace(",", "") # #311, process multiline PK correctly + if "-" in value: + ip1_input, ip2_input = value.split("-", 1) else: ip1_input = ip2_input = value @@ -175,22 +226,22 @@ def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True ip2 = IP(ip2_input) except ValueError as ve: clean_error = clean_ip_value_error(ve) - messages.error(f'Invalid address range: {value}: {clean_error}') + messages.error(f"Invalid address range: {value}: {clean_error}") return None if not ip1.version() == ip2.version() == 4: - messages.error(f'Invalid address range: {value}: IP version mismatch') + messages.error(f"Invalid address range: {value}: IP version mismatch") return None if ip1.int() > ip2.int(): - messages.error(f'Invalid address range: {value}: first IP is higher than second IP') + messages.error(f"Invalid address range: {value}: first IP is higher than second IP") return None - if '-' in value: - parsed_value = f'{ip1} - {ip2}' + if "-" in value: + parsed_value = f"{ip1} - {ip2}" else: parsed_value = str(ip1) if parsed_value != value: - messages.info(f'Address range {value} was reformatted as {parsed_value}') + messages.info(f"Address range {value} was reformatted as {parsed_value}") return RPSLFieldParseResult(parsed_value, ip_first=ip1, ip_last=ip2) @@ -205,40 +256,45 @@ class RPSLRouteSetMemberField(RPSLTextField): - ^[integer] - ^[integer]-[integer] """ + keep_case = True def __init__(self, ip_version: Optional[int], *args, **kwargs) -> None: if ip_version and ip_version not in [4, 6]: - raise ValueError(f'Invalid IP version: {ip_version}') + raise ValueError(f"Invalid IP version: {ip_version}") self.ip_version = ip_version super().__init__(*args, **kwargs) - def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]: - if '^' in value: - address, range_operator = value.split('^', maxsplit=1) + def parse( + self, value: str, messages: RPSLParserMessages, strict_validation=True + ) -> Optional[RPSLFieldParseResult]: + if "^" in value: + address, range_operator = value.split("^", maxsplit=1) if not range_operator: - messages.error(f'Missing range operator in value: {value}') + messages.error(f"Missing range operator in value: {value}") return None else: address = value - range_operator = '' + range_operator = "" parse_set_result_messages = RPSLParserMessages() - parse_set_result = parse_set_name(['RS-', 'AS-'], address, parse_set_result_messages, strict_validation) + parse_set_result = parse_set_name( + ["RS-", "AS-"], address, parse_set_result_messages, strict_validation + ) if parse_set_result and not parse_set_result_messages.errors(): result_value = parse_set_result.value if range_operator: - result_value += '^' + range_operator + result_value += "^" + range_operator if result_value != value: - messages.info(f'Route set member {value} was reformatted as {result_value}') + messages.info(f"Route set member {value} was reformatted as {result_value}") return RPSLFieldParseResult(value=result_value) try: parsed_str, parsed_int = 
             parsed_str, parsed_int = parse_as_number(address)
             result_value = parsed_str
             if range_operator:
-                result_value += '^' + range_operator
+                result_value += "^" + range_operator
             if result_value != value:
-                messages.info(f'Route set member {value} was reformatted as {result_value}')
+                messages.info(f"Route set member {value} was reformatted as {result_value}")
             return RPSLFieldParseResult(value=result_value)
         except ValidationError:
             pass

@@ -248,48 +304,48 @@ def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True
             ip = IP(address, ipversion=ip_version)
         except ValueError as ve:
             clean_error = clean_ip_value_error(ve)
-            messages.error(f'Value is neither a valid set name nor a valid prefix: {address}: {clean_error}')
+            messages.error(f"Value is neither a valid set name nor a valid prefix: {address}: {clean_error}")
             return None

         if range_operator:
             range_operator_match = re_range_operator.match(range_operator)
             if not range_operator_match:
-                messages.error(f'Invalid range operator {range_operator} in value: {value}')
+                messages.error(f"Invalid range operator {range_operator} in value: {value}")
                 return None

-            single_range = range_operator_match.group('single')
+            single_range = range_operator_match.group("single")
             if single_range and int(single_range) < ip.prefixlen():
                 messages.error(
-                    f'Invalid range operator: operator length ({single_range}) must be equal '
-                    f'to or longer than prefix length ({ip.prefixlen()}) {value}'
+                    f"Invalid range operator: operator length ({single_range}) must be equal "
+                    f"to or longer than prefix length ({ip.prefixlen()}) {value}"
                 )
                 return None

-            start_range = range_operator_match.group('start')
-            end_range = range_operator_match.group('end')
+            start_range = range_operator_match.group("start")
+            end_range = range_operator_match.group("end")
             if start_range and int(start_range) < ip.prefixlen():
                 messages.error(
-                    f'Invalid range operator: operator start ({start_range}) must be equal '
-                    f'to or longer than prefix length ({ip.prefixlen()}) {value}'
+                    f"Invalid range operator: operator start ({start_range}) must be equal "
+                    f"to or longer than prefix length ({ip.prefixlen()}) {value}"
                 )
                 return None
             if end_range and int(end_range) < int(start_range):
                 messages.error(
-                    f'Invalid range operator: operator end ({end_range}) must be equal '
-                    f'to or longer than operator start ({start_range}) {value}'
+                    f"Invalid range operator: operator end ({end_range}) must be equal "
+                    f"to or longer than operator start ({start_range}) {value}"
                 )
                 return None

         parsed_ip_str = str(ip)
         if ip.version() == 4 and ip.prefixlen() == 32:
-            parsed_ip_str += '/32'
+            parsed_ip_str += "/32"
         if ip.version() == 6 and ip.prefixlen() == 128:
-            parsed_ip_str += '/128'
+            parsed_ip_str += "/128"
         if range_operator:
-            parsed_ip_str += '^' + range_operator
+            parsed_ip_str += "^" + range_operator
         if parsed_ip_str != value:
-            messages.info(f'Route set member {value} was reformatted as {parsed_ip_str}')
+            messages.info(f"Route set member {value} was reformatted as {parsed_ip_str}")
         return RPSLFieldParseResult(parsed_ip_str)
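
To summarise the range-operator rules checked in this hunk: ^- means all more specifics, ^+ means this prefix plus more specifics, ^n fixes the length, ^n-m gives a window, and neither n nor the window start may be shorter than the prefix itself. A hedged sketch of just that validation logic, reusing the same regex shape as the diff:

import re

# Same named groups the parser above reads back with .group(...).
re_range_operator = re.compile(r"^(?P<start>\d{1,3})-(?P<end>\d{1,3})$|^(-)$|^(\+)$|^(?P<single>\d{1,3})$")

def check_range_operator(operator: str, prefix_len: int) -> bool:
    match = re_range_operator.match(operator)
    if not match:
        return False
    single = match.group("single")
    if single and int(single) < prefix_len:
        return False  # e.g. 192.0.2.0/24^16 makes no sense
    start, end = match.group("start"), match.group("end")
    if start and int(start) < prefix_len:
        return False
    if end and int(end) < int(start):
        return False
    return True

assert check_range_operator("+", 24)
assert check_range_operator("25-28", 24)
assert not check_range_operator("16", 24)
assert not check_range_operator("28-26", 24)
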

@@ -299,29 +355,35 @@ class RPSLRouteSetMembersField(RPSLFieldListMixin, RPSLRouteSetMemberField):


 class RPSLASNumberField(RPSLTextField):
     """Field for a single AS number (in ASxxxx syntax)."""
-    extracts = ['asn']
-    def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]:
+
+    extracts = ["asn"]
+
+    def parse(
+        self, value: str, messages: RPSLParserMessages, strict_validation=True
+    ) -> Optional[RPSLFieldParseResult]:
         try:
             parsed_str, parsed_int = parse_as_number(value)
         except ValidationError as ve:
             messages.error(str(ve))
             return None
         if parsed_str and parsed_str.upper() != value.upper():
-            messages.info(f'AS number {value} was reformatted as {parsed_str}')
+            messages.info(f"AS number {value} was reformatted as {parsed_str}")
         return RPSLFieldParseResult(parsed_str, asn_first=parsed_int, asn_last=parsed_int)


 class RPSLASBlockField(RPSLTextField):
     """Field for a block of AS numbers, e.g. AS1 - AS5."""
-    extracts = ['asn_first', 'asn_last']
-    def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]:
-        if '-' not in value:
-            messages.error(f'Invalid AS range: {value}: does not contain a hyphen')
+
+    extracts = ["asn_first", "asn_last"]
+
+    def parse(
+        self, value: str, messages: RPSLParserMessages, strict_validation=True
+    ) -> Optional[RPSLFieldParseResult]:
+        if "-" not in value:
+            messages.error(f"Invalid AS range: {value}: does not contain a hyphen")
             return None
-        as1_raw, as2_raw = map(str.strip, value.split('-', 1))
+        as1_raw, as2_raw = map(str.strip, value.split("-", 1))

         try:
             as1_str, as1_int = parse_as_number(as1_raw)
@@ -331,12 +393,12 @@ def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True
             return None

         if as1_int > as2_int:  # type: ignore
-            messages.error(f'Invalid AS range: {value}: first AS is higher then second AS')
+            messages.error(f"Invalid AS range: {value}: first AS is higher than second AS")
             return None

-        parsed_value = f'{as1_str} - {as2_str}'
+        parsed_value = f"{as1_str} - {as2_str}"
         if parsed_value != value:
-            messages.info(f'AS range {value} was reformatted as {parsed_value}')
+            messages.info(f"AS range {value} was reformatted as {parsed_value}")
         return RPSLFieldParseResult(parsed_value, asn_first=as1_int, asn_last=as2_int)

@@ -353,64 +415,79 @@ class RPSLSetNameField(RPSLTextField):
     The prefix provided is the expected prefix of the set name, e.g. 'RS'
     for a route-set, or 'AS' for an as-set.
     """
+
     keep_case = False

     def __init__(self, prefix: str, *args, **kwargs) -> None:
-        self.prefix = prefix + '-'
+        self.prefix = prefix + "-"
         super().__init__(*args, **kwargs)

-    def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]:
+    def parse(
+        self, value: str, messages: RPSLParserMessages, strict_validation=True
+    ) -> Optional[RPSLFieldParseResult]:
         return parse_set_name([self.prefix], value, messages, strict_validation)


 class RPSLEmailField(RPSLTextField):
     """Field for an e-mail address. Only performs basic validation."""
-    def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]:
+
+    def parse(
+        self, value: str, messages: RPSLParserMessages, strict_validation=True
+    ) -> Optional[RPSLFieldParseResult]:
         if not re_email.match(value):
-            messages.error(f'Invalid e-mail address: {value}')
+            messages.error(f"Invalid e-mail address: {value}")
             return None
         return RPSLFieldParseResult(value)


 class RPSLChangedField(RPSLTextField):
     """Field for a changed line. Only performs basic validation for email."""
-    def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]:
+
+    def parse(
+        self, value: str, messages: RPSLParserMessages, strict_validation=True
+    ) -> Optional[RPSLFieldParseResult]:
         date: Optional[str]
         try:
-            email, date = value.split(' ')
+            email, date = value.split(" ")
         except ValueError:
             email = value
             date = None

         if not re_email.match(email):
-            messages.error(f'Invalid e-mail address: {email}')
+            messages.error(f"Invalid e-mail address: {email}")
             return None
         if date:
             try:
-                datetime.datetime.strptime(date, '%Y%m%d')
+                datetime.datetime.strptime(date, "%Y%m%d")
             except ValueError as ve:
-                messages.error(f'Invalid changed date: {date}: {ve}')
+                messages.error(f"Invalid changed date: {date}: {ve}")
                 return None
         return RPSLFieldParseResult(value)


 class RPSLDNSNameField(RPSLTextField):
     """Field for a DNS name, as used in e.g. inet-rtr names."""
+
     keep_case = False

-    def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]:
+    def parse(
+        self, value: str, messages: RPSLParserMessages, strict_validation=True
+    ) -> Optional[RPSLFieldParseResult]:
         if not re_dnsname.match(value):
-            messages.error(f'Invalid DNS name: {value}')
+            messages.error(f"Invalid DNS name: {value}")
             return None
         return RPSLFieldParseResult(value)


 class RPSLURLField(RPSLTextField):
     """Field for a URL, as used in e.g. geofeed attribute."""
+
     keep_case = False

-    permitted_schemes = ['http', 'https']
+    permitted_schemes = ["http", "https"]

-    def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]:
+    def parse(
+        self, value: str, messages: RPSLParserMessages, strict_validation=True
+    ) -> Optional[RPSLFieldParseResult]:
         result = urlparse(value)
         if all([result.scheme in self.permitted_schemes, result.netloc]):
             return RPSLFieldParseResult(value)
@@ -432,34 +509,41 @@ class RPSLGenericNameField(RPSLTextField):
     is disabled.
     This is needed on nic-hdl for legacy reasons - see
     https://github.com/irrdnet/irrd/issues/60
     """
+
     keep_case = False

-    def __init__(self, allowed_prefixes: Optional[List[str]]=None, non_strict_allow_any=False, *args, **kwargs) -> None:
+    def __init__(
+        self, allowed_prefixes: Optional[List[str]] = None, non_strict_allow_any=False, *args, **kwargs
+    ) -> None:
         self.non_strict_allow_any = non_strict_allow_any
         if allowed_prefixes:
-            self.allowed_prefixes = [prefix.upper() + '-' for prefix in allowed_prefixes]
+            self.allowed_prefixes = [prefix.upper() + "-" for prefix in allowed_prefixes]
         else:
             self.allowed_prefixes = []
         super().__init__(*args, **kwargs)

-    def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]:
+    def parse(
+        self, value: str, messages: RPSLParserMessages, strict_validation=True
+    ) -> Optional[RPSLFieldParseResult]:
         if not strict_validation and self.non_strict_allow_any:
             return RPSLFieldParseResult(value)

         if strict_validation:
             upper_value = value.upper()
             if upper_value in reserved_words:
-                messages.error(f'Invalid name: {value}: this is a reserved word')
+                messages.error(f"Invalid name: {value}: this is a reserved word")
                 return None
             for prefix in reserved_prefixes:
                 if upper_value.startswith(prefix) and prefix not in self.allowed_prefixes:
-                    messages.error(f'Invalid name: {value}: {prefix} is a reserved prefix')
+                    messages.error(f"Invalid name: {value}: {prefix} is a reserved prefix")
                     return None

         if not re_generic_name.match(value):
-            messages.error(f'Invalid name: {value}: contains invalid characters, does not start with a letter, '
-                           f'or does not end in a letter/digit')
+            messages.error(
+                f"Invalid name: {value}: contains invalid characters, does not start with a letter, "
+                "or does not end in a letter/digit"
+            )
             return None
         return RPSLFieldParseResult(value)

@@ -480,17 +564,21 @@ class RPSLReferenceField(RPSLTextField):
     on updates, i.e. adding an object with a strong reference to another
     object that does not exist, is a validation failure.
     """
+
     keep_case = False

     def __init__(self, referring: List[str], strong=True, *args, **kwargs) -> None:
         from .parser import RPSLObject
+
         self.referring = referring
         self.strong = strong
         self.referring_object_classes: List[Type[RPSLObject]] = []
         self.referring_identifier_fields: List[RPSLTextField] = []
         super().__init__(*args, **kwargs)

-    def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]:
+    def parse(
+        self, value: str, messages: RPSLParserMessages, strict_validation=True
+    ) -> Optional[RPSLFieldParseResult]:
         if not self.referring_identifier_fields:
             self.resolve_references()

@@ -506,9 +594,12 @@ def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True

     def resolve_references(self):
         from .rpsl_objects import OBJECT_CLASS_MAPPING
+
         for ref in self.referring:
             rpsl_object_class = OBJECT_CLASS_MAPPING[ref]
-            pk_field = [field for field in rpsl_object_class.fields.values() if field.primary_key and field.lookup_key][0]
+            pk_field = [
+                field for field in rpsl_object_class.fields.values() if field.primary_key and field.lookup_key
+            ][0]
             self.referring_object_classes.append(rpsl_object_class)
             self.referring_identifier_fields.append(pk_field)

@@ -519,49 +610,61 @@ class RPSLReferenceListField(RPSLFieldListMixin, RPSLReferenceField):
     Optionally, ANY can be allowed as a valid option too, instead of a list.
""" - def __init__(self, allow_kw_any: bool=False, *args, **kwargs) -> None: + + def __init__(self, allow_kw_any: bool = False, *args, **kwargs) -> None: self.allow_kw_any = allow_kw_any super().__init__(*args, **kwargs) - def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]: - if self.allow_kw_any and value.upper() == 'ANY': - return RPSLFieldParseResult('ANY', values_list=['ANY']) + def parse( + self, value: str, messages: RPSLParserMessages, strict_validation=True + ) -> Optional[RPSLFieldParseResult]: + if self.allow_kw_any and value.upper() == "ANY": + return RPSLFieldParseResult("ANY", values_list=["ANY"]) return super().parse(value, messages, strict_validation) class RPSLAuthField(RPSLTextField): """Field for the auth attribute of a mntner.""" - def parse(self, value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]: + + def parse( + self, value: str, messages: RPSLParserMessages, strict_validation=True + ) -> Optional[RPSLFieldParseResult]: hashers = get_password_hashers(permit_legacy=not strict_validation) - valid_beginnings = [hasher + ' ' for hasher in hashers.keys()] + valid_beginnings = [hasher + " " for hasher in hashers.keys()] has_valid_beginning = any(value.upper().startswith(b) for b in valid_beginnings) - is_valid_hash = has_valid_beginning and value.count(' ') == 1 and not value.count(',') + is_valid_hash = has_valid_beginning and value.count(" ") == 1 and not value.count(",") if is_valid_hash or re_pgpkey.match(value.upper()): return RPSLFieldParseResult(value) - hashers = ', '.join(hashers.keys()) - messages.error(f'Invalid auth attribute: {value}: supported options are {hashers} and PGPKEY-xxxxxxxx') + hashers = ", ".join(hashers.keys()) + messages.error( + f"Invalid auth attribute: {value}: supported options are {hashers} and PGPKEY-xxxxxxxx" + ) return None -def parse_set_name(prefixes: List[str], value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]: +def parse_set_name( + prefixes: List[str], value: str, messages: RPSLParserMessages, strict_validation=True +) -> Optional[RPSLFieldParseResult]: assert all([prefix in reserved_prefixes for prefix in prefixes]) - input_components = value.split(':') + input_components = value.split(":") output_components: List[str] = [] - prefix_display = '/'.join(prefixes) + prefix_display = "/".join(prefixes) if strict_validation and len(input_components) > 5: - messages.error('Set names can have a maximum of five components.') + messages.error("Set names can have a maximum of five components.") return None if strict_validation and not any([c.upper().startswith(tuple(prefixes)) for c in input_components]): - messages.error(f'Invalid set name {value}: at least one component must be ' - f'an actual set name (i.e. start with {prefix_display})') + messages.error( + f"Invalid set name {value}: at least one component must be " + f"an actual set name (i.e. 


-def parse_set_name(prefixes: List[str], value: str, messages: RPSLParserMessages, strict_validation=True) -> Optional[RPSLFieldParseResult]:
+def parse_set_name(
+    prefixes: List[str], value: str, messages: RPSLParserMessages, strict_validation=True
+) -> Optional[RPSLFieldParseResult]:
     assert all([prefix in reserved_prefixes for prefix in prefixes])
-    input_components = value.split(':')
+    input_components = value.split(":")
     output_components: List[str] = []
-    prefix_display = '/'.join(prefixes)
+    prefix_display = "/".join(prefixes)

     if strict_validation and len(input_components) > 5:
-        messages.error('Set names can have a maximum of five components.')
+        messages.error("Set names can have a maximum of five components.")
         return None
     if strict_validation and not any([c.upper().startswith(tuple(prefixes)) for c in input_components]):
-        messages.error(f'Invalid set name {value}: at least one component must be '
-                       f'an actual set name (i.e. start with {prefix_display})')
+        messages.error(
+            f"Invalid set name {value}: at least one component must be "
+            f"an actual set name (i.e. start with {prefix_display})"
+        )
         return None

     for component in input_components:
         if strict_validation and component.upper() in reserved_words:
-            messages.error(f'Invalid set name {value}: component {component} is a reserved word')
+            messages.error(f"Invalid set name {value}: component {component} is a reserved word")
             return None

         parsed_as_number = None
@@ -571,12 +674,14 @@ def parse_set_name(prefixes: List[str], value: str, messages: RPSLParserMessages
             pass
         if not re_generic_name.match(component.upper()) and not parsed_as_number:
             messages.error(
-                f'Invalid set {value}: component {component} is not a valid AS number nor a valid set name'
+                f"Invalid set {value}: component {component} is not a valid AS number nor a valid set name"
             )
             return None
         if strict_validation and not parsed_as_number and not component.upper().startswith(tuple(prefixes)):
-            messages.error(f'Invalid set {value}: component {component} is not a valid AS number, '
-                           f'nor does it start with {prefix_display}')
+            messages.error(
+                f"Invalid set {value}: component {component} is not a valid AS number, "
+                f"nor does it start with {prefix_display}"
+            )
             return None

         if parsed_as_number:
@@ -584,7 +689,7 @@ def parse_set_name(prefixes: List[str], value: str, messages: RPSLParserMessages
         else:
             output_components.append(component)

-    parsed_value = ':'.join(output_components)
+    parsed_value = ":".join(output_components)
     if parsed_value != value:
-        messages.info(f'Set name {value} was reformatted as {parsed_value}')
+        messages.info(f"Set name {value} was reformatted as {parsed_value}")
     return RPSLFieldParseResult(parsed_value)
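
To make the set-name rules in parse_set_name concrete: names are colon-separated, at most five components under strict validation, every component must be an AS number or a valid name, and at least one component must carry the class prefix (RS-/AS-/...). A condensed re-implementation of just the strict path, for illustration only:

import re

re_generic_name = re.compile(r"^[A-Z][A-Z0-9_-]*[A-Z0-9]$", re.IGNORECASE)

def looks_like_valid_set_name(value: str, prefixes=("RS-", "AS-")) -> bool:
    components = value.split(":")
    if len(components) > 5:
        return False
    if not any(c.upper().startswith(prefixes) for c in components):
        return False
    for component in components:
        is_asn = re.fullmatch(r"AS\d+", component, re.IGNORECASE) is not None
        if not is_asn and not (re_generic_name.match(component)
                               and component.upper().startswith(prefixes)):
            return False
    return True

assert looks_like_valid_set_name("RS-EXAMPLE")
assert looks_like_valid_set_name("AS65537:RS-CUSTOMERS")
assert not looks_like_valid_set_name("AS65537:CUSTOMERS")
assert not looks_like_valid_set_name("RS-A:RS-B:RS-C:RS-D:RS-E:RS-F")
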
""" + def __init__(cls, name, bases, clsdict): # noqa: N805 super().__init__(name, bases, clsdict) - fields = clsdict.get('fields') + fields = clsdict.get("fields") if fields: cls.rpsl_object_class = list(fields.keys())[0] cls.pk_fields = [field[0] for field in fields.items() if field[1].primary_key] @@ -37,10 +39,20 @@ def __init__(cls, name, bases, clsdict): # noqa: N805 cls.attrs_allowed = [field[0] for field in fields.items()] cls.attrs_required = [field[0] for field in fields.items() if not field[1].optional] cls.attrs_multiple = [field[0] for field in fields.items() if field[1].multiple] - cls.field_extracts = list(itertools.chain( - *[field[1].extracts for field in fields.items() if field[1].primary_key or field[1].lookup_key] - )) - cls.referring_strong_fields = [(field[0], field[1].referring) for field in fields.items() if hasattr(field[1], 'referring') and getattr(field[1], 'strong')] + cls.field_extracts = list( + itertools.chain( + *[ + field[1].extracts + for field in fields.items() + if field[1].primary_key or field[1].lookup_key + ] + ) + ) + cls.referring_strong_fields = [ + (field[0], field[1].referring) + for field in fields.items() + if hasattr(field[1], "referring") and getattr(field[1], "strong") + ] class RPSLObject(metaclass=RPSLObjectMeta): @@ -55,6 +67,7 @@ class RPSLObject(metaclass=RPSLObjectMeta): made for each RPSL type with the appropriate fields defined. Note that any subclasses should also be added to OBJECT_CLASS_MAPPING. """ + fields: Dict[str, RPSLTextField] = OrderedDict() rpsl_object_class: str pk_fields: List[str] = [] @@ -71,7 +84,7 @@ class RPSLObject(metaclass=RPSLObjectMeta): scopefilter_status: ScopeFilterStatus = ScopeFilterStatus.in_scope route_preference_status: RoutePreferenceStatus = RoutePreferenceStatus.visible pk_asn_segment: Optional[str] = None - default_source: Optional[str] = None # noqa: E704 (flake8 bug) + default_source: Optional[str] = None # Shortcut for whether this object is a route-like object, and therefore # should be included in RPKI and route preference status. Enabled for route/route6. is_route = False @@ -79,11 +92,11 @@ class RPSLObject(metaclass=RPSLObjectMeta): discarded_fields: List[str] = [] # Fields that are ignored in validation even # for authoritative objects (see #587 for example). - ignored_validation_fields: List[str] = ['last-modified'] + ignored_validation_fields: List[str] = ["last-modified"] - _re_attr_name = re.compile(r'^[a-z0-9_-]+$') + _re_attr_name = re.compile(r"^[a-z0-9_-]+$") - def __init__(self, from_text: Optional[str]=None, strict_validation=True, default_source=None) -> None: + def __init__(self, from_text: Optional[str] = None, strict_validation=True, default_source=None) -> None: """ Create a new RPSL object, optionally instantiated from a string. 
@@ -55,6 +67,7 @@ class RPSLObject(metaclass=RPSLObjectMeta):
     made for each RPSL type with the appropriate fields defined. Note that
     any subclasses should also be added to OBJECT_CLASS_MAPPING.
     """
+
     fields: Dict[str, RPSLTextField] = OrderedDict()
     rpsl_object_class: str
     pk_fields: List[str] = []
@@ -71,7 +84,7 @@ class RPSLObject(metaclass=RPSLObjectMeta):
     scopefilter_status: ScopeFilterStatus = ScopeFilterStatus.in_scope
     route_preference_status: RoutePreferenceStatus = RoutePreferenceStatus.visible
     pk_asn_segment: Optional[str] = None
-    default_source: Optional[str] = None  # noqa: E704 (flake8 bug)
+    default_source: Optional[str] = None
     # Shortcut for whether this object is a route-like object, and therefore
     # should be included in RPKI and route preference status. Enabled for route/route6.
     is_route = False
@@ -79,11 +92,11 @@ class RPSLObject(metaclass=RPSLObjectMeta):
     discarded_fields: List[str] = []
     # Fields that are ignored in validation even
     # for authoritative objects (see #587 for example).
-    ignored_validation_fields: List[str] = ['last-modified']
+    ignored_validation_fields: List[str] = ["last-modified"]

-    _re_attr_name = re.compile(r'^[a-z0-9_-]+$')
+    _re_attr_name = re.compile(r"^[a-z0-9_-]+$")

-    def __init__(self, from_text: Optional[str]=None, strict_validation=True, default_source=None) -> None:
+    def __init__(self, from_text: Optional[str] = None, strict_validation=True, default_source=None) -> None:
         """
         Create a new RPSL object, optionally instantiated from a string.

@@ -113,14 +126,14 @@ def pk(self) -> str:
         composite_values = []
         for field in self.pk_fields:
             composite_values.append(self.parsed_data.get(field, ""))
-        return ''.join(composite_values).upper()
+        return "".join(composite_values).upper()

     def source(self) -> str:
         """Shortcut to retrieve object source"""
         try:
-            return self.parsed_data['source']
+            return self.parsed_data["source"]
         except KeyError:
-            raise ValueError('RPSL object has no known source')
+            raise ValueError("RPSL object has no known source")

     def ip_version(self) -> Optional[int]:
         """
@@ -161,13 +174,14 @@ def references_strong_inbound(self) -> Set[str]:
         """
         result = set()
         from irrd.rpsl.rpsl_objects import OBJECT_CLASS_MAPPING
+
         for rpsl_object in OBJECT_CLASS_MAPPING.values():
             for field_name, field in rpsl_object.fields.items():
-                if self.rpsl_object_class in getattr(field, 'referring', []) and getattr(field, 'strong'):
+                if self.rpsl_object_class in getattr(field, "referring", []) and getattr(field, "strong"):
                     result.add(field_name)
         return result

-    def render_rpsl_text(self, last_modified: Optional[datetime.datetime]=None) -> str:
+    def render_rpsl_text(self, last_modified: Optional[datetime.datetime] = None) -> str:
         """
         Render the RPSL object as an RPSL string.
         If last_modified is provided, removes existing last-modified:
@@ -175,14 +189,14 @@ def render_rpsl_text(self, last_modified: Optional[datetime.datetime]=None) -> s
         is authoritative.
         """
         output = ""
-        authoritative = get_setting(f'sources.{self.source()}.authoritative')
+        authoritative = get_setting(f"sources.{self.source()}.authoritative")
         for attr, value, continuation_chars in self._object_data:
-            if authoritative and last_modified and attr == 'last-modified':
+            if authoritative and last_modified and attr == "last-modified":
                 continue
-            attr_display = f'{attr}:'.ljust(RPSL_ATTRIBUTE_TEXT_WIDTH)
+            attr_display = f"{attr}:".ljust(RPSL_ATTRIBUTE_TEXT_WIDTH)
             value_lines = list(splitline_unicodesafe(value))
             if not value_lines:
-                output += f'{attr}:\n'
+                output += f"{attr}:\n"
             for idx, line in enumerate(value_lines):
                 if idx == 0:
                     output += attr_display + line
@@ -190,13 +204,13 @@ def render_rpsl_text(self, last_modified: Optional[datetime.datetime]=None) -> s
                     continuation_char = continuation_chars[idx - 1]
                     # Override the continuation char for empty lines #298
                     if not line:
-                        continuation_char = '+'
-                    output += continuation_char + (RPSL_ATTRIBUTE_TEXT_WIDTH - 1) * ' ' + line
-            output += '\n'
+                        continuation_char = "+"
+                    output += continuation_char + (RPSL_ATTRIBUTE_TEXT_WIDTH - 1) * " " + line
+            output += "\n"
         if authoritative and last_modified:
-            output += 'last-modified:'.ljust(RPSL_ATTRIBUTE_TEXT_WIDTH)
-            output += last_modified.replace(microsecond=0).isoformat().replace('+00:00', 'Z')
-            output += '\n'
+            output += "last-modified:".ljust(RPSL_ATTRIBUTE_TEXT_WIDTH)
+            output += last_modified.replace(microsecond=0).isoformat().replace("+00:00", "Z")
+            output += "\n"
         return output

     def generate_template(self):
@@ -204,21 +218,21 @@ def generate_template(self):
         template = ""
         max_name_width = max(len(k) for k in self.fields.keys())
         for name, field in self.fields.items():
-            mandatory = '[optional] ' if field.optional else '[mandatory]'
-            single = '[multiple]' if field.multiple else '[single]  '
+            mandatory = "[optional] " if field.optional else "[mandatory]"
+            single = "[multiple]" if field.multiple else "[single]  "
             metadata = []
             if field.primary_key and field.lookup_key:
-                metadata.append('primary/look-up key')
+                metadata.append("primary/look-up key")
             elif field.primary_key:
-                metadata.append('primary key')
+                metadata.append("primary key")
             elif field.lookup_key:
-                metadata.append('look-up key')
-            if getattr(field, 'referring', []):
-                reference_type = 'strong' if getattr(field, 'strong') else 'weak'
-                metadata.append(f'{reference_type} references ' + '/'.join(field.referring))
-            metadata_str = ', '.join(metadata)
-            name_padding = (max_name_width - len(name)) * ' '
-            template += f'{name}: {name_padding} {mandatory} {single} [{metadata_str}]\n'
+                metadata.append("look-up key")
+            if getattr(field, "referring", []):
+                reference_type = "strong" if getattr(field, "strong") else "weak"
+                metadata.append(f"{reference_type} references " + "/".join(field.referring))
+            metadata_str = ", ".join(metadata)
+            name_padding = (max_name_width - len(name)) * " "
+            template += f"{name}: {name_padding} {mandatory} {single} [{metadata_str}]\n"
         return template

     def clean(self) -> bool:
@@ -246,14 +260,16 @@ def _extract_attributes_values(self, text: str) -> None:
         attribute value, and the continuation characters.
         The continuation characters are needed to reconstruct the original
         object into a string.
         """
-        continuation_chars = (' ', '+', '\t')
+        continuation_chars = (" ", "+", "\t")
         current_attr = None
         current_value = ""
         current_continuation_chars: List[str] = []

         for line_no, line in enumerate(splitline_unicodesafe(text.strip())):
             if not line:
-                self.messages.error(f'Line {line_no+1}: encountered empty line in the middle of object: [{line}]')
+                self.messages.error(
+                    f"Line {line_no+1}: encountered empty line in the middle of object: [{line}]"
+                )
                 return

             if not line.startswith(continuation_chars):
@@ -263,20 +279,24 @@ def _extract_attributes_values(self, text: str) -> None:
                     # the attribute is finished.
                     self._object_data.append((current_attr, current_value, current_continuation_chars))

-                if ':' not in line:
-                    self.messages.error(f'Line {line_no+1}: line is neither continuation nor valid attribute [{line}]')
+                if ":" not in line:
+                    self.messages.error(
+                        f"Line {line_no+1}: line is neither continuation nor valid attribute [{line}]"
+                    )
                     return
-                current_attr, current_value = line.split(':', maxsplit=1)
+                current_attr, current_value = line.split(":", maxsplit=1)
                 current_attr = current_attr.lower()
                 current_value = current_value.strip()
                 current_continuation_chars = []

                 if current_attr not in self.attrs_allowed and not self._re_attr_name.match(current_attr):
-                    self.messages.error(f'Line {line_no+1}: encountered malformed attribute name: [{current_attr}]')
+                    self.messages.error(
+                        f"Line {line_no+1}: encountered malformed attribute name: [{current_attr}]"
+                    )
                     return
             else:
                 # Whitespace between the continuation character and the start of the data is not significant.
-                current_value += '\n' + line[1:].strip()
+                current_value += "\n" + line[1:].strip()
                 current_continuation_chars += line[0]
         if current_attr and current_attr not in self.discarded_fields:
             self._object_data.append((current_attr, current_value, current_continuation_chars))
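
The hunks above cover the RPSL attribute extractor: a line starting with a space, "+" or tab continues the previous attribute, and whitespace after the continuation character is insignificant. A cut-down standalone version of that loop, with no error reporting and assuming well-formed input:

def extract_attributes(text: str):
    continuation_chars = (" ", "+", "\t")
    attributes = []
    current_attr, current_value = None, ""
    for line in text.strip().splitlines():
        if not line.startswith(continuation_chars):
            if current_attr:
                attributes.append((current_attr, current_value))
            current_attr, current_value = line.split(":", maxsplit=1)
            current_attr = current_attr.lower()
            current_value = current_value.strip()
        else:
            # Whitespace after the continuation character is not significant.
            current_value += "\n" + line[1:].strip()
    if current_attr:
        attributes.append((current_attr, current_value))
    return attributes

obj = "remarks:  first line\n+         second line\nsource:   EXAMPLE"
print(extract_attributes(obj))
# [('remarks', 'first line\nsecond line'), ('source', 'EXAMPLE')]
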
@@ -309,11 +329,14 @@ def _validate_attribute_counts(self) -> None:
             if attr_name in self.ignored_validation_fields:
                 continue
             if attr_name not in self.attrs_allowed:
-                self.messages.error(f'Unrecognised attribute {attr_name} on object {self.rpsl_object_class}')
+                self.messages.error(
+                    f"Unrecognised attribute {attr_name} on object {self.rpsl_object_class}"
+                )
             if count > 1 and attr_name not in self.attrs_multiple:
                 self.messages.error(
-                    f'Attribute "{attr_name}" on object {self.rpsl_object_class} occurs multiple times, but is '
-                    f'only allowed once')
+                    f'Attribute "{attr_name}" on object {self.rpsl_object_class} occurs multiple times,'
+                    " but is only allowed once"
+                )
         for attr_required in self.attrs_required:
             if attr_required not in attrs_present:
                 self.messages.error(
@@ -322,7 +345,7 @@ def _validate_attribute_counts(self) -> None:
         else:
             required_fields = self.pk_fields
             if not self.default_source:
-                required_fields = required_fields + ['source']
+                required_fields = required_fields + ["source"]
             for attr_pk in required_fields:
                 if attr_pk not in attrs_present:
                     self.messages.error(
@@ -353,7 +376,9 @@ def _parse_attribute_data(self, allow_invalid_metadata=False) -> None:
             # the source field. In all other cases, the field parsing is best effort.
             # In all these other cases we pass a new parser messages object to the
            # field parser, so that we basically discard any errors.
-            raise_errors = self.strict_validation or field.primary_key or field.lookup_key or attr_name == 'source'
+            raise_errors = (
+                self.strict_validation or field.primary_key or field.lookup_key or attr_name == "source"
+            )
             field_messages = self.messages if raise_errors else RPSLParserMessages()
             parsed_value = field.parse(normalised_value, field_messages, self.strict_validation)

@@ -383,24 +408,26 @@ def _parse_attribute_data(self, allow_invalid_metadata=False) -> None:
                     self.parsed_data[attr_name] = [parsed_value_str]
             else:
                 if attr_name in self.parsed_data:
-                    self.parsed_data[attr_name] = '\n' + parsed_value_str
+                    self.parsed_data[attr_name] = "\n" + parsed_value_str
                 else:
                     self.parsed_data[attr_name] = parsed_value_str

             # Some fields provide additional metadata about the resources to
             # which this object pertains.
             if field.primary_key or field.lookup_key:
-                for attr in 'ip_first', 'ip_last', 'asn_first', 'asn_last', 'prefix', 'prefix_length':
+                for attr in "ip_first", "ip_last", "asn_first", "asn_last", "prefix", "prefix_length":
                     attr_value = getattr(parsed_value, attr, None)
                     if attr_value:
                         existing_attr_value = getattr(self, attr, None)
                         if existing_attr_value and not allow_invalid_metadata:  # pragma: no cover
-                            raise ValueError(f'Parsing of {parsed_value.value} reads {attr_value} for {attr},'
-                                             f'but value {existing_attr_value} is already set.')
+                            raise ValueError(
+                                f"Parsing of {parsed_value.value} reads {attr_value} for {attr}, "
+                                f"but value {existing_attr_value} is already set."
+                            )
                         setattr(self, attr, attr_value)

-        if 'source' not in self.parsed_data and self.default_source:
-            self.parsed_data['source'] = self.default_source
+        if "source" not in self.parsed_data and self.default_source:
+            self.parsed_data["source"] = self.default_source

     def _normalise_rpsl_value(self, value: str) -> str:
         """
@@ -421,15 +448,15 @@ def _normalise_rpsl_value(self, value: str) -> str:
         normalized_lines = []
         # The shortcuts below are functionally inconsequential, but significantly improve performance,
         # as most values are single line without comments, and this method is called extremely often.
-        if '\n' not in value:
-            if '#' in value:
-                return value.split('#')[0].strip()
+        if "\n" not in value:
+            if "#" in value:
+                return value.split("#")[0].strip()
             return value.strip()
         for line in splitline_unicodesafe(value):
-            parsed_line = line.split('#')[0].strip('\n\t, ')
+            parsed_line = line.split("#")[0].strip("\n\t, ")
             if parsed_line:
                 normalized_lines.append(parsed_line)
-        return ','.join(normalized_lines)
+        return ",".join(normalized_lines)

     def _update_attribute_value(self, attribute, new_values):
         """
@@ -444,7 +471,7 @@ def _update_attribute_value(self, attribute, new_values):
         """
         if isinstance(new_values, str):
             new_values = [new_values]
-        self.parsed_data[attribute] = '\n'.join(new_values)
+        self.parsed_data[attribute] = "\n".join(new_values)
         self._object_data = list(filter(lambda a: a[0] != attribute, self._object_data))

         insert_idx = 1
@@ -453,8 +480,8 @@ def _update_attribute_value(self, attribute, new_values):
             insert_idx += 1

     def __repr__(self):
-        source = self.parsed_data.get('source', '')
-        return f'{self.rpsl_object_class}/{self.pk()}/{source}'
+        source = self.parsed_data.get("source", "")
+        return f"{self.rpsl_object_class}/{self.pk()}/{source}"

     def __key(self):
         return self.rpsl_object_class, self.pk(), json.dumps(self.parsed_data, sort_keys=True)
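
_normalise_rpsl_value, reformatted in the hunks above, strips #-comments and joins multi-line values into a comma-separated string. Its behaviour in isolation, as a hedged standalone copy of the same logic:

def normalise_rpsl_value(value: str) -> str:
    # Fast path for the common single-line case.
    if "\n" not in value:
        if "#" in value:
            return value.split("#")[0].strip()
        return value.strip()
    normalised_lines = []
    for line in value.splitlines():
        parsed_line = line.split("#")[0].strip("\n\t, ")
        if parsed_line:
            normalised_lines.append(parsed_line)
    return ",".join(normalised_lines)

print(normalise_rpsl_value("AS65537  # our own AS"))      # 'AS65537'
print(normalise_rpsl_value("AS65537,\nAS65538  # peer"))  # 'AS65537,AS65538'
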
diff --git a/irrd/rpsl/parser_state.py b/irrd/rpsl/parser_state.py
index be705c659..9988c7cf9 100644
--- a/irrd/rpsl/parser_state.py
+++ b/irrd/rpsl/parser_state.py
@@ -1,34 +1,34 @@
-from typing import Optional, TypeVar, List
+from typing import List, Optional, TypeVar

 from IPy import IP

-RPSLParserMessagesType = TypeVar('RPSLParserMessagesType', bound='RPSLParserMessages')
+RPSLParserMessagesType = TypeVar("RPSLParserMessagesType", bound="RPSLParserMessages")


 class RPSLParserMessages:
-    levels = ['INFO', 'ERROR']
+    levels = ["INFO", "ERROR"]

     def __init__(self) -> None:
         self._messages: List[tuple] = []

     def __str__(self) -> str:
-        messages_str = [f'{msg[0]}: {msg[1]}' for msg in self._messages]
-        return '\n'.join(messages_str)
+        messages_str = [f"{msg[0]}: {msg[1]}" for msg in self._messages]
+        return "\n".join(messages_str)

     def messages(self) -> List[str]:
         return [msg[1] for msg in self._messages]

     def infos(self) -> List[str]:
-        return [msg[1] for msg in self._messages if msg[0] == 'INFO']
+        return [msg[1] for msg in self._messages if msg[0] == "INFO"]

     def errors(self) -> List[str]:
-        return [msg[1] for msg in self._messages if msg[0] == 'ERROR']
+        return [msg[1] for msg in self._messages if msg[0] == "ERROR"]

     def info(self, msg: str) -> None:
-        self._message('INFO', msg)
+        self._message("INFO", msg)

     def error(self, msg: str) -> None:
-        self._message('ERROR', msg)
+        self._message("ERROR", msg)

     def merge_messages(self, other_messages: RPSLParserMessagesType) -> None:
         self._messages += other_messages._messages
@@ -38,8 +38,17 @@ def _message(self, level: str, message: str) -> None:


 class RPSLFieldParseResult:
-    def __init__(self, value: str, values_list: Optional[List[str]]=None, ip_first: Optional[IP]=None, ip_last: Optional[IP]=None,
-                 prefix: Optional[IP]=None, prefix_length: Optional[int]=None, asn_first: Optional[int]=None, asn_last: Optional[int]=None) -> None:
+    def __init__(
+        self,
+        value: str,
+        values_list: Optional[List[str]] = None,
+        ip_first: Optional[IP] = None,
+        ip_last: Optional[IP] = None,
+        prefix: Optional[IP] = None,
+        prefix_length: Optional[int] = None,
+        asn_first: Optional[int] = None,
+        asn_last: Optional[int] = None,
+    ) -> None:
         self.value = value
         self.values_list = values_list
         self.ip_first = ip_first
diff --git a/irrd/rpsl/passwords.py b/irrd/rpsl/passwords.py
index 8c9933d42..481e3e9d8 100644
--- a/irrd/rpsl/passwords.py
+++ b/irrd/rpsl/passwords.py
@@ -1,19 +1,21 @@
-from enum import unique, Enum
+from enum import Enum, unique
+
+from passlib.hash import bcrypt, des_crypt, md5_crypt
+
 from irrd.conf import get_setting
-from passlib.hash import des_crypt, md5_crypt, bcrypt


 @unique
 class PasswordHasherAvailability(Enum):
-    ENABLED = 'enabled'
-    LEGACY = 'legacy'
-    DISABLED = 'disabled'
+    ENABLED = "enabled"
+    LEGACY = "legacy"
+    DISABLED = "disabled"


 PASSWORD_HASHERS_ALL = {
-    'CRYPT-PW': des_crypt,
-    'MD5-PW': md5_crypt,
-    'BCRYPT-PW': bcrypt,
+    "CRYPT-PW": des_crypt,
+    "MD5-PW": md5_crypt,
+    "BCRYPT-PW": bcrypt,
 }


@@ -28,7 +30,7 @@ def get_password_hashers(permit_legacy=True):
         included_availabilities.add(PasswordHasherAvailability.LEGACY)

     for hasher_name, hasher_function in PASSWORD_HASHERS_ALL.items():
-        setting = get_setting(f'auth.password_hashers.{hasher_name.lower()}')
+        setting = get_setting(f"auth.password_hashers.{hasher_name.lower()}")
         availability = getattr(PasswordHasherAvailability, setting.upper())

         if availability in included_availabilities:
@@ -36,4 +38,4 @@ def get_password_hashers(permit_legacy=True):
     return hashers


-PASSWORD_REPLACEMENT_HASH = ('BCRYPT-PW', bcrypt)
+PASSWORD_REPLACEMENT_HASH = ("BCRYPT-PW", bcrypt)
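
The function in the passwords.py hunk filters PASSWORD_HASHERS_ALL down by per-hasher availability settings, where "legacy" hashers are only usable when permit_legacy is set. Roughly, substituting a plain settings dict for IRRD's get_setting (the settings values here are invented):

from passlib.hash import bcrypt, des_crypt, md5_crypt

PASSWORD_HASHERS_ALL = {"CRYPT-PW": des_crypt, "MD5-PW": md5_crypt, "BCRYPT-PW": bcrypt}

# Illustrative settings; IRRD reads these from auth.password_hashers.*
settings = {"crypt-pw": "disabled", "md5-pw": "legacy", "bcrypt-pw": "enabled"}

def get_password_hashers(permit_legacy=True):
    included = {"enabled", "legacy"} if permit_legacy else {"enabled"}
    return {
        name: hasher
        for name, hasher in PASSWORD_HASHERS_ALL.items()
        if settings[name.lower()] in included
    }

print(list(get_password_hashers()))                     # ['MD5-PW', 'BCRYPT-PW']
print(list(get_password_hashers(permit_legacy=False)))  # ['BCRYPT-PW']
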
diff --git a/irrd/rpsl/rpsl_objects.py b/irrd/rpsl/rpsl_objects.py
index 8ced2e378..f0ec820c9 100644
--- a/irrd/rpsl/rpsl_objects.py
+++ b/irrd/rpsl/rpsl_objects.py
@@ -1,219 +1,335 @@
 from collections import OrderedDict
+from typing import List, Optional, Set, Union

-from typing import Set, List, Optional, Union
-
-from irrd.conf import AUTH_SET_CREATION_COMMON_KEY, PASSWORD_HASH_DUMMY_VALUE, get_setting
+from irrd.conf import (
+    AUTH_SET_CREATION_COMMON_KEY,
+    PASSWORD_HASH_DUMMY_VALUE,
+    get_setting,
+)
 from irrd.utils.pgp import get_gpg_instance
-from .passwords import PASSWORD_REPLACEMENT_HASH, get_password_hashers
-from .fields import (RPSLTextField, RPSLIPv4PrefixField, RPSLIPv4PrefixesField, RPSLIPv6PrefixField,
-                     RPSLIPv6PrefixesField, RPSLIPv4AddressRangeField, RPSLASNumberField,
-                     RPSLASBlockField,
-                     RPSLSetNameField, RPSLEmailField, RPSLDNSNameField, RPSLGenericNameField,
-                     RPSLReferenceField,
-                     RPSLReferenceListField, RPSLAuthField, RPSLRouteSetMembersField,
-                     RPSLChangedField, RPSLURLField)
+
+from ..utils.validators import ValidationError, parse_as_number
+from .fields import (
+    RPSLASBlockField,
+    RPSLASNumberField,
+    RPSLAuthField,
+    RPSLChangedField,
+    RPSLDNSNameField,
+    RPSLEmailField,
+    RPSLGenericNameField,
+    RPSLIPv4AddressRangeField,
+    RPSLIPv4PrefixesField,
+    RPSLIPv4PrefixField,
+    RPSLIPv6PrefixesField,
+    RPSLIPv6PrefixField,
+    RPSLReferenceField,
+    RPSLReferenceListField,
+    RPSLRouteSetMembersField,
+    RPSLSetNameField,
+    RPSLTextField,
+    RPSLURLField,
+)
 from .parser import RPSLObject, UnknownRPSLObjectClassException
-from ..utils.validators import parse_as_number, ValidationError
+from .passwords import PASSWORD_REPLACEMENT_HASH, get_password_hashers

 RPSL_ROUTE_OBJECT_CLASS_FOR_IP_VERSION = {
-    4: 'route',
-    6: 'route6',
+    4: "route",
+    6: "route6",
 }


-def rpsl_object_from_text(text, strict_validation=True, default_source: Optional[str]=None) -> RPSLObject:
-    rpsl_object_class = text.split(':', maxsplit=1)[0].strip()
+def rpsl_object_from_text(text, strict_validation=True, default_source: Optional[str] = None) -> RPSLObject:
+    rpsl_object_class = text.split(":", maxsplit=1)[0].strip()
     try:
         klass = OBJECT_CLASS_MAPPING[rpsl_object_class]
     except KeyError:
-        raise UnknownRPSLObjectClassException(f'unknown object class: {rpsl_object_class}',
-                                              rpsl_object_class=rpsl_object_class)
+        raise UnknownRPSLObjectClassException(
+            f"unknown object class: {rpsl_object_class}", rpsl_object_class=rpsl_object_class
+        )
     return klass(from_text=text, strict_validation=strict_validation, default_source=default_source)
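
rpsl_object_from_text, shown above, dispatches purely on the first attribute name of the object text. A toy version of the same dispatch, with two fake classes standing in for the real OBJECT_CLASS_MAPPING:

class Route:
    def __init__(self, from_text):
        self.text = from_text

class Mntner(Route):
    pass

OBJECT_CLASS_MAPPING = {"route": Route, "mntner": Mntner}

def object_from_text(text: str):
    # The class name is whatever precedes the first colon.
    object_class = text.split(":", maxsplit=1)[0].strip()
    try:
        klass = OBJECT_CLASS_MAPPING[object_class]
    except KeyError:
        raise ValueError(f"unknown object class: {object_class}")
    return klass(from_text=text)

obj = object_from_text("mntner: EXAMPLE-MNT\nsource: EXAMPLE")
print(type(obj).__name__)  # Mntner
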
"AS65537:{self.pk_asn_segment}": {str(ve)}' + ) return False class RPSLAsBlock(RPSLObject): - fields = OrderedDict([ - ('as-block', RPSLASBlockField(primary_key=True, lookup_key=True)), - ('descr', RPSLTextField(multiple=True, optional=True)), - ('admin-c', RPSLReferenceField(lookup_key=True, multiple=True, referring=['role', 'person'])), - ('tech-c', RPSLReferenceField(lookup_key=True, multiple=True, referring=['role', 'person'])), - ('remarks', RPSLTextField(optional=True, multiple=True)), - ('notify', RPSLEmailField(optional=True, multiple=True)), - ('mnt-by', RPSLReferenceListField(lookup_key=True, multiple=True, referring=['mntner'])), - ('changed', RPSLChangedField(optional=True, multiple=True)), - ('source', RPSLGenericNameField()), - ]) + fields = OrderedDict( + [ + ("as-block", RPSLASBlockField(primary_key=True, lookup_key=True)), + ("descr", RPSLTextField(multiple=True, optional=True)), + ("admin-c", RPSLReferenceField(lookup_key=True, multiple=True, referring=["role", "person"])), + ("tech-c", RPSLReferenceField(lookup_key=True, multiple=True, referring=["role", "person"])), + ("remarks", RPSLTextField(optional=True, multiple=True)), + ("notify", RPSLEmailField(optional=True, multiple=True)), + ("mnt-by", RPSLReferenceListField(lookup_key=True, multiple=True, referring=["mntner"])), + ("changed", RPSLChangedField(optional=True, multiple=True)), + ("source", RPSLGenericNameField()), + ] + ) class RPSLAsSet(RPSLSet): - fields = OrderedDict([ - ('as-set', RPSLSetNameField(primary_key=True, lookup_key=True, prefix='AS')), - ('descr', RPSLTextField(multiple=True, optional=True)), - ('members', RPSLReferenceListField(lookup_key=True, optional=True, multiple=True, referring=['aut-num', 'as-set'], strong=False)), - ('mbrs-by-ref', RPSLReferenceListField(lookup_key=True, optional=True, multiple=True, referring=['mntner'], allow_kw_any=True, strong=False)), - ('admin-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])), - ('tech-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])), - ('remarks', RPSLTextField(optional=True, multiple=True)), - ('notify', RPSLEmailField(optional=True, multiple=True)), - ('mnt-by', RPSLReferenceListField(lookup_key=True, multiple=True, referring=['mntner'])), - ('changed', RPSLChangedField(optional=True, multiple=True)), - ('source', RPSLGenericNameField()), - ]) + fields = OrderedDict( + [ + ("as-set", RPSLSetNameField(primary_key=True, lookup_key=True, prefix="AS")), + ("descr", RPSLTextField(multiple=True, optional=True)), + ( + "members", + RPSLReferenceListField( + lookup_key=True, + optional=True, + multiple=True, + referring=["aut-num", "as-set"], + strong=False, + ), + ), + ( + "mbrs-by-ref", + RPSLReferenceListField( + lookup_key=True, + optional=True, + multiple=True, + referring=["mntner"], + allow_kw_any=True, + strong=False, + ), + ), + ( + "admin-c", + RPSLReferenceField( + lookup_key=True, optional=True, multiple=True, referring=["role", "person"] + ), + ), + ( + "tech-c", + RPSLReferenceField( + lookup_key=True, optional=True, multiple=True, referring=["role", "person"] + ), + ), + ("remarks", RPSLTextField(optional=True, multiple=True)), + ("notify", RPSLEmailField(optional=True, multiple=True)), + ("mnt-by", RPSLReferenceListField(lookup_key=True, multiple=True, referring=["mntner"])), + ("changed", RPSLChangedField(optional=True, multiple=True)), + ("source", RPSLGenericNameField()), + ] + ) class RPSLAutNum(RPSLObject): - fields = OrderedDict([ 
-    fields = OrderedDict([
-        ('aut-num', RPSLASNumberField(primary_key=True, lookup_key=True)),
-        ('as-name', RPSLGenericNameField(allowed_prefixes=['AS'])),
-        ('descr', RPSLTextField(multiple=True, optional=True)),
-        ('member-of', RPSLReferenceListField(lookup_key=True, optional=True, multiple=True, referring=['as-set'], strong=False)),
-        ('import', RPSLTextField(optional=True, multiple=True)),
-        ('mp-import', RPSLTextField(optional=True, multiple=True)),
-        ('import-via', RPSLTextField(optional=True, multiple=True)),
-        ('export', RPSLTextField(optional=True, multiple=True)),
-        ('mp-export', RPSLTextField(optional=True, multiple=True)),
-        ('export-via', RPSLTextField(optional=True, multiple=True)),
-        ('default', RPSLTextField(optional=True, multiple=True)),
-        ('mp-default', RPSLTextField(optional=True, multiple=True)),
-        ('admin-c', RPSLReferenceField(lookup_key=True, multiple=True, referring=['role', 'person'])),
-        ('tech-c', RPSLReferenceField(lookup_key=True, multiple=True, referring=['role', 'person'])),
-        ('remarks', RPSLTextField(optional=True, multiple=True)),
-        ('notify', RPSLEmailField(optional=True, multiple=True)),
-        ('mnt-by', RPSLReferenceListField(lookup_key=True, optional=True, multiple=True, referring=['mntner'])),
-        ('changed', RPSLChangedField(optional=True, multiple=True)),
-        ('source', RPSLGenericNameField()),
-    ])
+    fields = OrderedDict(
+        [
+            ("aut-num", RPSLASNumberField(primary_key=True, lookup_key=True)),
+            ("as-name", RPSLGenericNameField(allowed_prefixes=["AS"])),
+            ("descr", RPSLTextField(multiple=True, optional=True)),
+            (
+                "member-of",
+                RPSLReferenceListField(
+                    lookup_key=True, optional=True, multiple=True, referring=["as-set"], strong=False
+                ),
+            ),
+            ("import", RPSLTextField(optional=True, multiple=True)),
+            ("mp-import", RPSLTextField(optional=True, multiple=True)),
+            ("import-via", RPSLTextField(optional=True, multiple=True)),
+            ("export", RPSLTextField(optional=True, multiple=True)),
+            ("mp-export", RPSLTextField(optional=True, multiple=True)),
+            ("export-via", RPSLTextField(optional=True, multiple=True)),
+            ("default", RPSLTextField(optional=True, multiple=True)),
+            ("mp-default", RPSLTextField(optional=True, multiple=True)),
+            ("admin-c", RPSLReferenceField(lookup_key=True, multiple=True, referring=["role", "person"])),
+            ("tech-c", RPSLReferenceField(lookup_key=True, multiple=True, referring=["role", "person"])),
+            ("remarks", RPSLTextField(optional=True, multiple=True)),
+            ("notify", RPSLEmailField(optional=True, multiple=True)),
+            (
+                "mnt-by",
+                RPSLReferenceListField(lookup_key=True, optional=True, multiple=True, referring=["mntner"]),
+            ),
+            ("changed", RPSLChangedField(optional=True, multiple=True)),
+            ("source", RPSLGenericNameField()),
+        ]
+    )


 class RPSLDomain(RPSLObject):
-    fields = OrderedDict([
-        ('domain', RPSLTextField(primary_key=True, lookup_key=True)),  # reverse delegation address (range), v4/v6/enum
-        ('descr', RPSLTextField(multiple=True, optional=True)),
-        ('admin-c', RPSLReferenceField(lookup_key=True, multiple=True, referring=['role', 'person'])),
-        ('tech-c', RPSLReferenceField(lookup_key=True, multiple=True, referring=['role', 'person'])),
-        ('zone-c', RPSLReferenceField(lookup_key=True, multiple=True, referring=['role', 'person'])),
-        ('nserver', RPSLTextField(optional=True, multiple=True)),  # DNS name, possibly followed v4/v6
-        ('sub-dom', RPSLTextField(optional=True, multiple=True)),
-        ('dom-net', RPSLTextField(optional=True, multiple=True)),
-        ('refer', RPSLTextField(optional=True)),  # ???
-        ('remarks', RPSLTextField(optional=True, multiple=True)),
-        ('notify', RPSLEmailField(optional=True, multiple=True)),
-        ('mnt-by', RPSLReferenceListField(lookup_key=True, optional=True, multiple=True, referring=['mntner'])),
-        ('changed', RPSLChangedField(optional=True, multiple=True)),
-        ('source', RPSLGenericNameField()),
-    ])
+    fields = OrderedDict(
+        [
+            (
+                "domain",
+                RPSLTextField(primary_key=True, lookup_key=True),
+            ),  # reverse delegation address (range), v4/v6/enum
+            ("descr", RPSLTextField(multiple=True, optional=True)),
+            ("admin-c", RPSLReferenceField(lookup_key=True, multiple=True, referring=["role", "person"])),
+            ("tech-c", RPSLReferenceField(lookup_key=True, multiple=True, referring=["role", "person"])),
+            ("zone-c", RPSLReferenceField(lookup_key=True, multiple=True, referring=["role", "person"])),
+            ("nserver", RPSLTextField(optional=True, multiple=True)),  # DNS name, possibly followed v4/v6
+            ("sub-dom", RPSLTextField(optional=True, multiple=True)),
+            ("dom-net", RPSLTextField(optional=True, multiple=True)),
+            ("refer", RPSLTextField(optional=True)),  # ???
+            ("remarks", RPSLTextField(optional=True, multiple=True)),
+            ("notify", RPSLEmailField(optional=True, multiple=True)),
+            (
+                "mnt-by",
+                RPSLReferenceListField(lookup_key=True, optional=True, multiple=True, referring=["mntner"]),
+            ),
+            ("changed", RPSLChangedField(optional=True, multiple=True)),
+            ("source", RPSLGenericNameField()),
+        ]
+    )


 class RPSLFilterSet(RPSLSet):
-    fields = OrderedDict([
-        ('filter-set', RPSLSetNameField(primary_key=True, lookup_key=True, prefix='FLTR')),
-        ('descr', RPSLTextField(multiple=True, optional=True)),
-        ('filter', RPSLTextField()),
-        ('mp-filter', RPSLTextField(optional=True)),
-        ('admin-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])),
-        ('tech-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])),
-        ('remarks', RPSLTextField(optional=True, multiple=True)),
-        ('notify', RPSLEmailField(optional=True, multiple=True)),
-        ('mnt-by', RPSLReferenceListField(lookup_key=True, multiple=True, referring=['mntner'])),
-        ('changed', RPSLChangedField(optional=True, multiple=True)),
-        ('source', RPSLGenericNameField()),
-    ])
+    fields = OrderedDict(
+        [
+            ("filter-set", RPSLSetNameField(primary_key=True, lookup_key=True, prefix="FLTR")),
+            ("descr", RPSLTextField(multiple=True, optional=True)),
+            ("filter", RPSLTextField()),
+            ("mp-filter", RPSLTextField(optional=True)),
+            (
+                "admin-c",
+                RPSLReferenceField(
+                    lookup_key=True, optional=True, multiple=True, referring=["role", "person"]
+                ),
+            ),
+            (
+                "tech-c",
+                RPSLReferenceField(
+                    lookup_key=True, optional=True, multiple=True, referring=["role", "person"]
+                ),
+            ),
+            ("remarks", RPSLTextField(optional=True, multiple=True)),
+            ("notify", RPSLEmailField(optional=True, multiple=True)),
+            ("mnt-by", RPSLReferenceListField(lookup_key=True, multiple=True, referring=["mntner"])),
+            ("changed", RPSLChangedField(optional=True, multiple=True)),
+            ("source", RPSLGenericNameField()),
+        ]
+    )


 class RPSLInetRtr(RPSLObject):
-    fields = OrderedDict([
-        ('inet-rtr', RPSLDNSNameField(primary_key=True, lookup_key=True)),
-        ('descr', RPSLTextField(multiple=True, optional=True)),
-        ('alias', RPSLDNSNameField(optional=True, multiple=True)),
-        ('local-as', RPSLASNumberField()),
-        ('ifaddr', RPSLTextField(optional=True, multiple=True)),
-        ('interface', RPSLTextField(optional=True, multiple=True)),
-        ('peer', RPSLTextField(optional=True, multiple=True)),
-        ('mp-peer', RPSLTextField(optional=True, multiple=True)),
-        ('member-of', RPSLReferenceListField(lookup_key=True, optional=True, multiple=True, referring=['rtr-set'], strong=False)),
-        ('rs-in', RPSLTextField(optional=True)),
-        ('rs-out', RPSLTextField(optional=True)),
-        ('admin-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])),
-        ('tech-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])),
-        ('remarks', RPSLTextField(optional=True, multiple=True)),
-        ('notify', RPSLEmailField(optional=True, multiple=True)),
-        ('mnt-by', RPSLReferenceListField(lookup_key=True, multiple=True, referring=['mntner'])),
-        ('changed', RPSLChangedField(optional=True, multiple=True)),
-        ('source', RPSLGenericNameField()),
-    ])
+    fields = OrderedDict(
+        [
+            ("inet-rtr", RPSLDNSNameField(primary_key=True, lookup_key=True)),
+            ("descr", RPSLTextField(multiple=True, optional=True)),
+            ("alias", RPSLDNSNameField(optional=True, multiple=True)),
+            ("local-as", RPSLASNumberField()),
+            ("ifaddr", RPSLTextField(optional=True, multiple=True)),
+            ("interface", RPSLTextField(optional=True, multiple=True)),
+            ("peer", RPSLTextField(optional=True, multiple=True)),
+            ("mp-peer", RPSLTextField(optional=True, multiple=True)),
+            (
+                "member-of",
+                RPSLReferenceListField(
+                    lookup_key=True, optional=True, multiple=True, referring=["rtr-set"], strong=False
+                ),
+            ),
+            ("rs-in", RPSLTextField(optional=True)),
+            ("rs-out", RPSLTextField(optional=True)),
+            (
+                "admin-c",
+                RPSLReferenceField(
+                    lookup_key=True, optional=True, multiple=True, referring=["role", "person"]
+                ),
+            ),
+            (
+                "tech-c",
+                RPSLReferenceField(
+                    lookup_key=True, optional=True, multiple=True, referring=["role", "person"]
+                ),
+            ),
+            ("remarks", RPSLTextField(optional=True, multiple=True)),
+            ("notify", RPSLEmailField(optional=True, multiple=True)),
+            ("mnt-by", RPSLReferenceListField(lookup_key=True, multiple=True, referring=["mntner"])),
+            ("changed", RPSLChangedField(optional=True, multiple=True)),
+            ("source", RPSLGenericNameField()),
+        ]
+    )


 class RPSLInet6Num(RPSLObject):
-    fields = OrderedDict([
-        ('inet6num', RPSLIPv6PrefixField(primary_key=True, lookup_key=True)),
-        ('netname', RPSLTextField()),
-        ('descr', RPSLTextField(multiple=True, optional=True)),
-        ('country', RPSLTextField(multiple=True)),
-        ('admin-c', RPSLReferenceField(lookup_key=True, multiple=True, referring=['role', 'person'])),
-        ('tech-c', RPSLReferenceField(lookup_key=True, multiple=True, referring=['role', 'person'])),
-        ('rev-srv', RPSLTextField(optional=True, multiple=True)),
-        ('status', RPSLTextField()),
-        ('geofeed', RPSLURLField(optional=True)),
-        ('remarks', RPSLTextField(optional=True, multiple=True)),
-        ('notify', RPSLEmailField(optional=True, multiple=True)),
-        ('mnt-by', RPSLReferenceListField(lookup_key=True, multiple=True, referring=['mntner'])),
-        ('changed', RPSLChangedField(optional=True, multiple=True)),
-        ('source', RPSLGenericNameField()),
-    ])
+    fields = OrderedDict(
+        [
+            ("inet6num", RPSLIPv6PrefixField(primary_key=True, lookup_key=True)),
+            ("netname", RPSLTextField()),
+            ("descr", RPSLTextField(multiple=True, optional=True)),
+            ("country", RPSLTextField(multiple=True)),
+            ("admin-c", RPSLReferenceField(lookup_key=True, multiple=True, referring=["role", "person"])),
+            ("tech-c", RPSLReferenceField(lookup_key=True, multiple=True, referring=["role", "person"])),
+            ("rev-srv", RPSLTextField(optional=True, multiple=True)),
+            ("status", RPSLTextField()),
+            ("geofeed", RPSLURLField(optional=True)),
RPSLURLField(optional=True)), + ("remarks", RPSLTextField(optional=True, multiple=True)), + ("notify", RPSLEmailField(optional=True, multiple=True)), + ("mnt-by", RPSLReferenceListField(lookup_key=True, multiple=True, referring=["mntner"])), + ("changed", RPSLChangedField(optional=True, multiple=True)), + ("source", RPSLGenericNameField()), + ] + ) class RPSLInetnum(RPSLObject): - fields = OrderedDict([ - ('inetnum', RPSLIPv4AddressRangeField(primary_key=True, lookup_key=True)), - ('netname', RPSLTextField()), - ('descr', RPSLTextField(multiple=True, optional=True)), - ('country', RPSLTextField(multiple=True)), - ('admin-c', RPSLReferenceField(lookup_key=True, multiple=True, referring=['role', 'person'])), - ('tech-c', RPSLReferenceField(lookup_key=True, multiple=True, referring=['role', 'person'])), - ('rev-srv', RPSLTextField(optional=True, multiple=True)), - ('status', RPSLTextField()), - ('geofeed', RPSLURLField(optional=True)), - ('remarks', RPSLTextField(optional=True, multiple=True)), - ('notify', RPSLEmailField(optional=True, multiple=True)), - ('mnt-by', RPSLReferenceListField(lookup_key=True, multiple=True, referring=['mntner'])), - ('changed', RPSLChangedField(optional=True, multiple=True)), - ('source', RPSLGenericNameField()), - ]) + fields = OrderedDict( + [ + ("inetnum", RPSLIPv4AddressRangeField(primary_key=True, lookup_key=True)), + ("netname", RPSLTextField()), + ("descr", RPSLTextField(multiple=True, optional=True)), + ("country", RPSLTextField(multiple=True)), + ("admin-c", RPSLReferenceField(lookup_key=True, multiple=True, referring=["role", "person"])), + ("tech-c", RPSLReferenceField(lookup_key=True, multiple=True, referring=["role", "person"])), + ("rev-srv", RPSLTextField(optional=True, multiple=True)), + ("status", RPSLTextField()), + ("geofeed", RPSLURLField(optional=True)), + ("remarks", RPSLTextField(optional=True, multiple=True)), + ("notify", RPSLEmailField(optional=True, multiple=True)), + ("mnt-by", RPSLReferenceListField(lookup_key=True, multiple=True, referring=["mntner"])), + ("changed", RPSLChangedField(optional=True, multiple=True)), + ("source", RPSLGenericNameField()), + ] + ) class RPSLKeyCert(RPSLObject): - fields = OrderedDict([ - ('key-cert', RPSLGenericNameField(primary_key=True, lookup_key=True)), - ('method', RPSLTextField(optional=True)), # Fixed to PGP - ('owner', RPSLTextField(optional=True, multiple=True)), # key owner, autogenerate - ('fingerpr', RPSLTextField(optional=True)), # fingerprint, autogenerate - ('certif', RPSLTextField(multiple=True)), # Actual key - ('remarks', RPSLTextField(optional=True, multiple=True)), - ('admin-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])), - ('tech-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])), - ('notify', RPSLEmailField(optional=True, multiple=True)), - ('mnt-by', RPSLReferenceListField(lookup_key=True, multiple=True, referring=['mntner'])), - ('changed', RPSLChangedField(optional=True, multiple=True)), - ('source', RPSLGenericNameField()), - ]) + fields = OrderedDict( + [ + ("key-cert", RPSLGenericNameField(primary_key=True, lookup_key=True)), + ("method", RPSLTextField(optional=True)), # Fixed to PGP + ("owner", RPSLTextField(optional=True, multiple=True)), # key owner, autogenerate + ("fingerpr", RPSLTextField(optional=True)), # fingerprint, autogenerate + ("certif", RPSLTextField(multiple=True)), # Actual key + ("remarks", RPSLTextField(optional=True, multiple=True)), + ( + "admin-c", + 
RPSLReferenceField( + lookup_key=True, optional=True, multiple=True, referring=["role", "person"] + ), + ), + ( + "tech-c", + RPSLReferenceField( + lookup_key=True, optional=True, multiple=True, referring=["role", "person"] + ), + ), + ("notify", RPSLEmailField(optional=True, multiple=True)), + ("mnt-by", RPSLReferenceListField(lookup_key=True, multiple=True, referring=["mntner"])), + ("changed", RPSLChangedField(optional=True, multiple=True)), + ("source", RPSLGenericNameField()), + ] + ) def clean(self) -> bool: """ @@ -231,31 +347,31 @@ def clean(self) -> bool: return False # pragma: no cover gpg = get_gpg_instance() - certif_data = '\n'.join(self.parsed_data.get('certif', [])).replace(',', '\n') + certif_data = "\n".join(self.parsed_data.get("certif", [])).replace(",", "\n") result = gpg.import_keys(certif_data) if len(result.fingerprints) != 1: - msg = 'Unable to read public PGP key: key corrupt or multiple keys provided' + msg = "Unable to read public PGP key: key corrupt or multiple keys provided" if result.results: msg = f'{msg}: {result.results[0]["text"]}' self.messages.error(msg) return False self.fingerprint = result.fingerprints[0] - expected_object_name = 'PGPKEY-' + self.fingerprint[-8:] - actual_object_name = self.parsed_data['key-cert'].upper() + expected_object_name = "PGPKEY-" + self.fingerprint[-8:] + actual_object_name = self.parsed_data["key-cert"].upper() fingerprint_formatted = self.format_fingerprint(self.fingerprint) if expected_object_name != actual_object_name: self.messages.error( - f'Invalid object name {actual_object_name}: does not match key fingerprint {fingerprint_formatted}, ' - f'expected object name {expected_object_name}' + f"Invalid object name {actual_object_name}: does not match key fingerprint" + f" {fingerprint_formatted}, expected object name {expected_object_name}" ) return False - self._update_attribute_value('fingerpr', fingerprint_formatted) - self._update_attribute_value('owner', gpg.list_keys(keys=self.fingerprint)[0]['uids']) - self._update_attribute_value('method', 'PGP') + self._update_attribute_value("fingerpr", fingerprint_formatted) + self._update_attribute_value("owner", gpg.list_keys(keys=self.fingerprint)[0]["uids"]) + self._update_attribute_value("method", "PGP") return True @@ -267,35 +383,45 @@ def clean(self) -> bool: def verify(self, message: str) -> bool: gpg = get_gpg_instance() result = gpg.verify(message) - return result.valid and result.key_status is None and \ - self.format_fingerprint(result.fingerprint) == self.parsed_data['fingerpr'] + return ( + result.valid + and result.key_status is None + and self.format_fingerprint(result.fingerprint) == self.parsed_data["fingerpr"] + ) @staticmethod def format_fingerprint(fingerprint: str) -> str: """Format a PGP fingerprint into sections of 4 characters, separated by spaces.""" string_parts = [] for idx in range(0, 40, 4): - string_parts.append(fingerprint[idx:idx + 4]) + string_parts.append(fingerprint[idx : idx + 4]) if idx == 16: - string_parts.append('') - return ' '.join(string_parts) + string_parts.append("") + return " ".join(string_parts) class RPSLMntner(RPSLObject): - fields = OrderedDict([ - ('mntner', RPSLGenericNameField(primary_key=True, lookup_key=True)), - ('descr', RPSLTextField(multiple=True, optional=True)), - ('admin-c', RPSLReferenceField(lookup_key=True, multiple=True, referring=['role', 'person'])), - ('tech-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])), - ('upd-to', RPSLEmailField(multiple=True)), - 
('mnt-nfy', RPSLEmailField(optional=True, multiple=True)), - ('auth', RPSLAuthField(multiple=True)), - ('remarks', RPSLTextField(optional=True, multiple=True)), - ('notify', RPSLEmailField(optional=True, multiple=True)), - ('mnt-by', RPSLReferenceListField(lookup_key=True, multiple=True, referring=['mntner'])), - ('changed', RPSLChangedField(optional=True, multiple=True)), - ('source', RPSLGenericNameField()), - ]) + fields = OrderedDict( + [ + ("mntner", RPSLGenericNameField(primary_key=True, lookup_key=True)), + ("descr", RPSLTextField(multiple=True, optional=True)), + ("admin-c", RPSLReferenceField(lookup_key=True, multiple=True, referring=["role", "person"])), + ( + "tech-c", + RPSLReferenceField( + lookup_key=True, optional=True, multiple=True, referring=["role", "person"] + ), + ), + ("upd-to", RPSLEmailField(multiple=True)), + ("mnt-nfy", RPSLEmailField(optional=True, multiple=True)), + ("auth", RPSLAuthField(multiple=True)), + ("remarks", RPSLTextField(optional=True, multiple=True)), + ("notify", RPSLEmailField(optional=True, multiple=True)), + ("mnt-by", RPSLReferenceListField(lookup_key=True, multiple=True, referring=["mntner"])), + ("changed", RPSLChangedField(optional=True, multiple=True)), + ("source", RPSLGenericNameField()), + ] + ) def clean(self): """Check whether either all hash values are dummy hashes, or none.""" @@ -304,21 +430,23 @@ def clean(self): dummy_matches = [auth[1] == PASSWORD_HASH_DUMMY_VALUE for auth in self._auth_lines(True)] if any(dummy_matches) and not all(dummy_matches): - self.messages.error('Either all password auth hashes in a submitted mntner must be dummy objects, or none.') + self.messages.error( + "Either all password auth hashes in a submitted mntner must be dummy objects, or none." + ) - def verify_auth(self, passwords: List[str], keycert_obj_pk: Optional[str]=None) -> bool: + def verify_auth(self, passwords: List[str], keycert_obj_pk: Optional[str] = None) -> bool: """ Verify whether one of a given list of passwords matches any of the auth hashes in this object, or match the keycert object PK. """ hashers = get_password_hashers(permit_legacy=True) - for auth in self.parsed_data.get('auth', []): + for auth in self.parsed_data.get("auth", []): if keycert_obj_pk and auth.upper() == keycert_obj_pk.upper(): return True - if ' ' not in auth: + if " " not in auth: continue - scheme, hash = auth.split(' ', 1) + scheme, hash = auth.split(" ", 1) hasher = hashers.get(scheme.upper()) if hasher: for password in passwords: @@ -344,10 +472,10 @@ def force_single_new_password(self, password) -> None: Retains other methods, i.e. PGPKEY. """ hash_key, hash_function = PASSWORD_REPLACEMENT_HASH - hash = hash_key + ' ' + hash_function.hash(password) + hash = hash_key + " " + hash_function.hash(password) auths = self._auth_lines(password_hashes=False) auths.append(hash) - self._update_attribute_value('auth', auths) + self._update_attribute_value("auth", auths) def _auth_lines(self, password_hashes=True) -> List[Union[str, List[str]]]: """ @@ -356,179 +484,305 @@ def _auth_lines(self, password_hashes=True) -> List[Union[str, List[str]]]: If password_hashes=True, returns a list of lists, each inner list containing the hash method and the hash. 
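Example with illustrative values: for auth lines "MD5-pw $1$abc$def" and "PGPKEY-AABB0011", password_hashes=True returns [["MD5-pw", "$1$abc$def"]], while password_hashes=False returns ["PGPKEY-AABB0011"].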
""" - lines = self.parsed_data.get('auth', []) + lines = self.parsed_data.get("auth", []) if password_hashes is True: - return [auth.split(' ', 1) for auth in lines if ' ' in auth] - return [auth for auth in lines if ' ' not in auth] + return [auth.split(" ", 1) for auth in lines if " " in auth] + return [auth for auth in lines if " " not in auth] class RPSLPeeringSet(RPSLSet): - fields = OrderedDict([ - ('peering-set', RPSLSetNameField(primary_key=True, lookup_key=True, prefix='PRNG')), - ('descr', RPSLTextField(multiple=True, optional=True)), - ('peering', RPSLTextField(optional=True, multiple=True)), - ('mp-peering', RPSLTextField(optional=True, multiple=True)), - ('admin-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])), - ('tech-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])), - ('remarks', RPSLTextField(optional=True, multiple=True)), - ('notify', RPSLEmailField(optional=True, multiple=True)), - ('mnt-by', RPSLReferenceListField(lookup_key=True, multiple=True, referring=['mntner'])), - ('changed', RPSLChangedField(optional=True, multiple=True)), - ('source', RPSLGenericNameField()), - ]) + fields = OrderedDict( + [ + ("peering-set", RPSLSetNameField(primary_key=True, lookup_key=True, prefix="PRNG")), + ("descr", RPSLTextField(multiple=True, optional=True)), + ("peering", RPSLTextField(optional=True, multiple=True)), + ("mp-peering", RPSLTextField(optional=True, multiple=True)), + ( + "admin-c", + RPSLReferenceField( + lookup_key=True, optional=True, multiple=True, referring=["role", "person"] + ), + ), + ( + "tech-c", + RPSLReferenceField( + lookup_key=True, optional=True, multiple=True, referring=["role", "person"] + ), + ), + ("remarks", RPSLTextField(optional=True, multiple=True)), + ("notify", RPSLEmailField(optional=True, multiple=True)), + ("mnt-by", RPSLReferenceListField(lookup_key=True, multiple=True, referring=["mntner"])), + ("changed", RPSLChangedField(optional=True, multiple=True)), + ("source", RPSLGenericNameField()), + ] + ) class RPSLPerson(RPSLObject): - fields = OrderedDict([ - ('person', RPSLTextField(lookup_key=True)), - ('address', RPSLTextField(multiple=True)), - ('phone', RPSLTextField(multiple=True)), - ('fax-no', RPSLTextField(optional=True, multiple=True)), - ('e-mail', RPSLEmailField(multiple=True)), - ('nic-hdl', RPSLGenericNameField(primary_key=True, lookup_key=True, non_strict_allow_any=True)), - ('remarks', RPSLTextField(optional=True, multiple=True)), - ('notify', RPSLEmailField(optional=True, multiple=True)), - ('mnt-by', RPSLReferenceListField(lookup_key=True, multiple=True, referring=['mntner'])), - ('changed', RPSLChangedField(optional=True, multiple=True)), - ('source', RPSLGenericNameField()), - ]) + fields = OrderedDict( + [ + ("person", RPSLTextField(lookup_key=True)), + ("address", RPSLTextField(multiple=True)), + ("phone", RPSLTextField(multiple=True)), + ("fax-no", RPSLTextField(optional=True, multiple=True)), + ("e-mail", RPSLEmailField(multiple=True)), + ("nic-hdl", RPSLGenericNameField(primary_key=True, lookup_key=True, non_strict_allow_any=True)), + ("remarks", RPSLTextField(optional=True, multiple=True)), + ("notify", RPSLEmailField(optional=True, multiple=True)), + ("mnt-by", RPSLReferenceListField(lookup_key=True, multiple=True, referring=["mntner"])), + ("changed", RPSLChangedField(optional=True, multiple=True)), + ("source", RPSLGenericNameField()), + ] + ) class RPSLRole(RPSLObject): - fields = OrderedDict([ - ('role', 
RPSLTextField(lookup_key=True)), - ('trouble', RPSLTextField(optional=True, multiple=True)), - ('address', RPSLTextField(multiple=True)), - ('phone', RPSLTextField(multiple=True)), - ('fax-no', RPSLTextField(optional=True, multiple=True)), - ('e-mail', RPSLEmailField(multiple=True)), - ('admin-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])), - ('tech-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])), - ('nic-hdl', RPSLGenericNameField(primary_key=True, lookup_key=True, non_strict_allow_any=True)), - ('remarks', RPSLTextField(optional=True, multiple=True)), - ('notify', RPSLEmailField(optional=True, multiple=True)), - ('mnt-by', RPSLReferenceListField(lookup_key=True, multiple=True, referring=['mntner'])), - ('changed', RPSLChangedField(optional=True, multiple=True)), - ('source', RPSLGenericNameField()), - ]) + fields = OrderedDict( + [ + ("role", RPSLTextField(lookup_key=True)), + ("trouble", RPSLTextField(optional=True, multiple=True)), + ("address", RPSLTextField(multiple=True)), + ("phone", RPSLTextField(multiple=True)), + ("fax-no", RPSLTextField(optional=True, multiple=True)), + ("e-mail", RPSLEmailField(multiple=True)), + ( + "admin-c", + RPSLReferenceField( + lookup_key=True, optional=True, multiple=True, referring=["role", "person"] + ), + ), + ( + "tech-c", + RPSLReferenceField( + lookup_key=True, optional=True, multiple=True, referring=["role", "person"] + ), + ), + ("nic-hdl", RPSLGenericNameField(primary_key=True, lookup_key=True, non_strict_allow_any=True)), + ("remarks", RPSLTextField(optional=True, multiple=True)), + ("notify", RPSLEmailField(optional=True, multiple=True)), + ("mnt-by", RPSLReferenceListField(lookup_key=True, multiple=True, referring=["mntner"])), + ("changed", RPSLChangedField(optional=True, multiple=True)), + ("source", RPSLGenericNameField()), + ] + ) class RPSLRoute(RPSLObject): is_route = True - discarded_fields = ['rpki-ov-state'] - fields = OrderedDict([ - ('route', RPSLIPv4PrefixField(primary_key=True, lookup_key=True)), - ('descr', RPSLTextField(multiple=True, optional=True)), - ('origin', RPSLASNumberField(primary_key=True)), - ('holes', RPSLIPv4PrefixesField(optional=True, multiple=True)), - ('member-of', RPSLReferenceListField(lookup_key=True, optional=True, multiple=True, referring=['route-set'], strong=False)), - ('inject', RPSLTextField(optional=True, multiple=True)), - ('aggr-bndry', RPSLTextField(optional=True)), - ('aggr-mtd', RPSLTextField(optional=True)), - ('export-comps', RPSLTextField(optional=True)), - ('components', RPSLTextField(optional=True)), - ('admin-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])), - ('tech-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])), - ('geoidx', RPSLTextField(optional=True, multiple=True)), - ('roa-uri', RPSLTextField(optional=True)), - ('remarks', RPSLTextField(optional=True, multiple=True)), - ('notify', RPSLEmailField(optional=True, multiple=True)), - ('mnt-by', RPSLReferenceListField(lookup_key=True, multiple=True, referring=['mntner'])), - ('changed', RPSLChangedField(optional=True, multiple=True)), - ('source', RPSLGenericNameField()), - ]) + discarded_fields = ["rpki-ov-state"] + fields = OrderedDict( + [ + ("route", RPSLIPv4PrefixField(primary_key=True, lookup_key=True)), + ("descr", RPSLTextField(multiple=True, optional=True)), + ("origin", RPSLASNumberField(primary_key=True)), + ("holes", 
RPSLIPv4PrefixesField(optional=True, multiple=True)), + ( + "member-of", + RPSLReferenceListField( + lookup_key=True, optional=True, multiple=True, referring=["route-set"], strong=False + ), + ), + ("inject", RPSLTextField(optional=True, multiple=True)), + ("aggr-bndry", RPSLTextField(optional=True)), + ("aggr-mtd", RPSLTextField(optional=True)), + ("export-comps", RPSLTextField(optional=True)), + ("components", RPSLTextField(optional=True)), + ( + "admin-c", + RPSLReferenceField( + lookup_key=True, optional=True, multiple=True, referring=["role", "person"] + ), + ), + ( + "tech-c", + RPSLReferenceField( + lookup_key=True, optional=True, multiple=True, referring=["role", "person"] + ), + ), + ("geoidx", RPSLTextField(optional=True, multiple=True)), + ("roa-uri", RPSLTextField(optional=True)), + ("remarks", RPSLTextField(optional=True, multiple=True)), + ("notify", RPSLEmailField(optional=True, multiple=True)), + ("mnt-by", RPSLReferenceListField(lookup_key=True, multiple=True, referring=["mntner"])), + ("changed", RPSLChangedField(optional=True, multiple=True)), + ("source", RPSLGenericNameField()), + ] + ) class RPSLRouteSet(RPSLSet): - fields = OrderedDict([ - ('route-set', RPSLSetNameField(primary_key=True, lookup_key=True, prefix='RS')), - ('members', RPSLRouteSetMembersField(ip_version=4, lookup_key=True, optional=True, multiple=True)), - ('mp-members', RPSLRouteSetMembersField(ip_version=None, lookup_key=True, optional=True, multiple=True)), - ('mbrs-by-ref', RPSLReferenceListField(lookup_key=True, optional=True, multiple=True, referring=['mntner'], allow_kw_any=True, strong=False)), - ('descr', RPSLTextField(multiple=True, optional=True)), - ('admin-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])), - ('tech-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])), - ('remarks', RPSLTextField(optional=True, multiple=True)), - ('notify', RPSLEmailField(optional=True, multiple=True)), - ('mnt-by', RPSLReferenceListField(lookup_key=True, multiple=True, referring=['mntner'])), - ('changed', RPSLChangedField(optional=True, multiple=True)), - ('source', RPSLGenericNameField()), - ]) + fields = OrderedDict( + [ + ("route-set", RPSLSetNameField(primary_key=True, lookup_key=True, prefix="RS")), + ( + "members", + RPSLRouteSetMembersField(ip_version=4, lookup_key=True, optional=True, multiple=True), + ), + ( + "mp-members", + RPSLRouteSetMembersField(ip_version=None, lookup_key=True, optional=True, multiple=True), + ), + ( + "mbrs-by-ref", + RPSLReferenceListField( + lookup_key=True, + optional=True, + multiple=True, + referring=["mntner"], + allow_kw_any=True, + strong=False, + ), + ), + ("descr", RPSLTextField(multiple=True, optional=True)), + ( + "admin-c", + RPSLReferenceField( + lookup_key=True, optional=True, multiple=True, referring=["role", "person"] + ), + ), + ( + "tech-c", + RPSLReferenceField( + lookup_key=True, optional=True, multiple=True, referring=["role", "person"] + ), + ), + ("remarks", RPSLTextField(optional=True, multiple=True)), + ("notify", RPSLEmailField(optional=True, multiple=True)), + ("mnt-by", RPSLReferenceListField(lookup_key=True, multiple=True, referring=["mntner"])), + ("changed", RPSLChangedField(optional=True, multiple=True)), + ("source", RPSLGenericNameField()), + ] + ) class RPSLRoute6(RPSLObject): is_route = True - discarded_fields = ['rpki-ov-state'] - fields = OrderedDict([ - ('route6', RPSLIPv6PrefixField(primary_key=True, lookup_key=True)), - ('descr', 
RPSLTextField(multiple=True, optional=True)), - ('origin', RPSLASNumberField(primary_key=True)), - ('holes', RPSLIPv6PrefixesField(optional=True, multiple=True)), - ('member-of', RPSLReferenceListField(lookup_key=True, optional=True, multiple=True, referring=['route-set'], strong=False)), - ('inject', RPSLTextField(optional=True, multiple=True)), - ('aggr-bndry', RPSLTextField(optional=True)), - ('aggr-mtd', RPSLTextField(optional=True)), - ('export-comps', RPSLTextField(optional=True)), - ('components', RPSLTextField(optional=True)), - ('admin-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])), - ('tech-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])), - ('geoidx', RPSLTextField(optional=True, multiple=True)), - ('roa-uri', RPSLTextField(optional=True)), - ('remarks', RPSLTextField(optional=True, multiple=True)), - ('notify', RPSLEmailField(optional=True, multiple=True)), - ('mnt-by', RPSLReferenceListField(lookup_key=True, multiple=True, referring=['mntner'])), - ('changed', RPSLChangedField(optional=True, multiple=True)), - ('source', RPSLGenericNameField()), - ]) + discarded_fields = ["rpki-ov-state"] + fields = OrderedDict( + [ + ("route6", RPSLIPv6PrefixField(primary_key=True, lookup_key=True)), + ("descr", RPSLTextField(multiple=True, optional=True)), + ("origin", RPSLASNumberField(primary_key=True)), + ("holes", RPSLIPv6PrefixesField(optional=True, multiple=True)), + ( + "member-of", + RPSLReferenceListField( + lookup_key=True, optional=True, multiple=True, referring=["route-set"], strong=False + ), + ), + ("inject", RPSLTextField(optional=True, multiple=True)), + ("aggr-bndry", RPSLTextField(optional=True)), + ("aggr-mtd", RPSLTextField(optional=True)), + ("export-comps", RPSLTextField(optional=True)), + ("components", RPSLTextField(optional=True)), + ( + "admin-c", + RPSLReferenceField( + lookup_key=True, optional=True, multiple=True, referring=["role", "person"] + ), + ), + ( + "tech-c", + RPSLReferenceField( + lookup_key=True, optional=True, multiple=True, referring=["role", "person"] + ), + ), + ("geoidx", RPSLTextField(optional=True, multiple=True)), + ("roa-uri", RPSLTextField(optional=True)), + ("remarks", RPSLTextField(optional=True, multiple=True)), + ("notify", RPSLEmailField(optional=True, multiple=True)), + ("mnt-by", RPSLReferenceListField(lookup_key=True, multiple=True, referring=["mntner"])), + ("changed", RPSLChangedField(optional=True, multiple=True)), + ("source", RPSLGenericNameField()), + ] + ) class RPSLRtrSet(RPSLSet): - fields = OrderedDict([ - ('rtr-set', RPSLSetNameField(primary_key=True, lookup_key=True, prefix='RTRS')), - ('descr', RPSLTextField(multiple=True, optional=True)), - ('members', RPSLReferenceListField(lookup_key=True, optional=True, multiple=True, referring=['inet-rtr', 'rtr-set'], strong=False)), - ('mp-members', RPSLReferenceListField(lookup_key=True, optional=True, multiple=True, referring=['inet-rtr', 'rtr-set'], strong=False)), - ('mbrs-by-ref', RPSLReferenceListField(lookup_key=True, optional=True, multiple=True, referring=['mntner'], allow_kw_any=True, strong=False)), - ('admin-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])), - ('tech-c', RPSLReferenceField(lookup_key=True, optional=True, multiple=True, referring=['role', 'person'])), - ('remarks', RPSLTextField(optional=True, multiple=True)), - ('notify', RPSLEmailField(optional=True, multiple=True)), - ('mnt-by', 
RPSLReferenceListField(lookup_key=True, multiple=True, referring=['mntner'])), - ('changed', RPSLChangedField(optional=True, multiple=True)), - ('source', RPSLGenericNameField()), - ]) + fields = OrderedDict( + [ + ("rtr-set", RPSLSetNameField(primary_key=True, lookup_key=True, prefix="RTRS")), + ("descr", RPSLTextField(multiple=True, optional=True)), + ( + "members", + RPSLReferenceListField( + lookup_key=True, + optional=True, + multiple=True, + referring=["inet-rtr", "rtr-set"], + strong=False, + ), + ), + ( + "mp-members", + RPSLReferenceListField( + lookup_key=True, + optional=True, + multiple=True, + referring=["inet-rtr", "rtr-set"], + strong=False, + ), + ), + ( + "mbrs-by-ref", + RPSLReferenceListField( + lookup_key=True, + optional=True, + multiple=True, + referring=["mntner"], + allow_kw_any=True, + strong=False, + ), + ), + ( + "admin-c", + RPSLReferenceField( + lookup_key=True, optional=True, multiple=True, referring=["role", "person"] + ), + ), + ( + "tech-c", + RPSLReferenceField( + lookup_key=True, optional=True, multiple=True, referring=["role", "person"] + ), + ), + ("remarks", RPSLTextField(optional=True, multiple=True)), + ("notify", RPSLEmailField(optional=True, multiple=True)), + ("mnt-by", RPSLReferenceListField(lookup_key=True, multiple=True, referring=["mntner"])), + ("changed", RPSLChangedField(optional=True, multiple=True)), + ("source", RPSLGenericNameField()), + ] + ) OBJECT_CLASS_MAPPING = { - 'as-block': RPSLAsBlock, - 'as-set': RPSLAsSet, - 'aut-num': RPSLAutNum, - 'domain': RPSLDomain, - 'filter-set': RPSLFilterSet, - 'inet-rtr': RPSLInetRtr, - 'inet6num': RPSLInet6Num, - 'inetnum': RPSLInetnum, - 'key-cert': RPSLKeyCert, - 'mntner': RPSLMntner, - 'peering-set': RPSLPeeringSet, - 'person': RPSLPerson, - 'role': RPSLRole, - 'route': RPSLRoute, - 'route-set': RPSLRouteSet, - 'route6': RPSLRoute6, - 'rtr-set': RPSLRtrSet, + "as-block": RPSLAsBlock, + "as-set": RPSLAsSet, + "aut-num": RPSLAutNum, + "domain": RPSLDomain, + "filter-set": RPSLFilterSet, + "inet-rtr": RPSLInetRtr, + "inet6num": RPSLInet6Num, + "inetnum": RPSLInetnum, + "key-cert": RPSLKeyCert, + "mntner": RPSLMntner, + "peering-set": RPSLPeeringSet, + "person": RPSLPerson, + "role": RPSLRole, + "route": RPSLRoute, + "route-set": RPSLRouteSet, + "route6": RPSLRoute6, + "rtr-set": RPSLRtrSet, } RPKI_RELEVANT_OBJECT_CLASSES = [ - rpsl_object.rpsl_object_class - for rpsl_object in OBJECT_CLASS_MAPPING.values() - if rpsl_object.is_route + rpsl_object.rpsl_object_class for rpsl_object in OBJECT_CLASS_MAPPING.values() if rpsl_object.is_route ] def lookup_field_names() -> Set[str]: """Return all unique names of all lookup keys in all objects, plus 'origin'.""" - names = {'origin'} + names = {"origin"} for object_class in OBJECT_CLASS_MAPPING.values(): names.update([f for f in object_class.lookup_fields if f not in object_class.pk_fields]) return names diff --git a/irrd/rpsl/tests/test_fields.py b/irrd/rpsl/tests/test_fields.py index e495bb3ee..eda7ba5a7 100644 --- a/irrd/rpsl/tests/test_fields.py +++ b/irrd/rpsl/tests/test_fields.py @@ -1,13 +1,26 @@ from IPy import IP from pytest import raises -from ..fields import (RPSLIPv4PrefixField, RPSLIPv4PrefixesField, RPSLIPv6PrefixField, - RPSLIPv6PrefixesField, RPSLIPv4AddressRangeField, RPSLASNumberField, - RPSLASBlockField, - RPSLSetNameField, RPSLEmailField, RPSLDNSNameField, RPSLGenericNameField, - RPSLReferenceField, - RPSLReferenceListField, RPSLTextField, RPSLAuthField, RPSLRouteSetMemberField, - RPSLChangedField, RPSLURLField) +from ..fields import ( + 
RPSLASBlockField, + RPSLASNumberField, + RPSLAuthField, + RPSLChangedField, + RPSLDNSNameField, + RPSLEmailField, + RPSLGenericNameField, + RPSLIPv4AddressRangeField, + RPSLIPv4PrefixesField, + RPSLIPv4PrefixField, + RPSLIPv6PrefixesField, + RPSLIPv6PrefixField, + RPSLReferenceField, + RPSLReferenceListField, + RPSLRouteSetMemberField, + RPSLSetNameField, + RPSLTextField, + RPSLURLField, +) from ..parser_state import RPSLParserMessages @@ -31,14 +44,14 @@ def assert_validation_err(expected_errors, callable, *args, **kwargs): expected_errors = [e for e in expected_errors if e not in matched_expected_errors] errors = [e for e in errors if e not in matched_errors] - assert len(errors) == 0, f'unexpected error messages in: {messages.errors()}' - assert len(expected_errors) == 0, f'did not find error messages: {expected_errors}' + assert len(errors) == 0, f"unexpected error messages in: {messages.errors()}" + assert len(expected_errors) == 0, f"did not find error messages: {expected_errors}" def test_rpsl_text_field(): field = RPSLTextField() messages = RPSLParserMessages() - assert field.parse('AS-FOO$', messages).value, 'AS-FOO$' + assert field.parse("AS-FOO$", messages).value == "AS-FOO$" assert not messages.errors() @@ -46,112 +59,115 @@ def test_ipv4_prefix_field(): field = RPSLIPv4PrefixField() messages = RPSLParserMessages() - parse_result = field.parse('192.0.2.0/24', messages) - assert parse_result.value == '192.0.2.0/24' - assert parse_result.ip_first == IP('192.0.2.0') - assert parse_result.ip_last == IP('192.0.2.255') + parse_result = field.parse("192.0.2.0/24", messages) + assert parse_result.value == "192.0.2.0/24" + assert parse_result.ip_first == IP("192.0.2.0") + assert parse_result.ip_last == IP("192.0.2.255") assert parse_result.prefix_length == 24 - assert field.parse('192.00.02.0/25', messages).value == '192.0.2.0/25' - assert field.parse('192.0.2.0/32', messages).value == '192.0.2.0/32' + assert field.parse("192.00.02.0/25", messages).value == "192.0.2.0/25" + assert field.parse("192.0.2.0/32", messages).value == "192.0.2.0/32" assert not messages.errors() - assert messages.infos() == ['Address prefix 192.00.02.0/25 was reformatted as 192.0.2.0/25'] + assert messages.infos() == ["Address prefix 192.00.02.0/25 was reformatted as 192.0.2.0/25"] # 192.0.2/24 is generally seen as a valid prefix, but RFC 2622 does not allow this notation. - assert_validation_err('Invalid address prefix', field.parse, '192.0.2/24') - assert_validation_err('Invalid address prefix', field.parse, '555.555.555.555/24') - assert_validation_err('Invalid address prefix', field.parse, 'foo') - assert_validation_err('Invalid address prefix', field.parse, '2001::/32') - assert_validation_err('Invalid address prefix', field.parse, '192.0.2.0/16') + assert_validation_err("Invalid address prefix", field.parse, "192.0.2/24") + assert_validation_err("Invalid address prefix", field.parse, "555.555.555.555/24") + assert_validation_err("Invalid address prefix", field.parse, "foo") + assert_validation_err("Invalid address prefix", field.parse, "2001::/32") + assert_validation_err("Invalid address prefix", field.parse, "192.0.2.0/16") def test_ipv4_prefixes_field(): field = RPSLIPv4PrefixesField() messages = RPSLParserMessages() - assert field.parse('192.0.2.0/24', messages).value == '192.0.2.0/24' + assert field.parse("192.0.2.0/24", messages).value == "192.0.2.0/24" # Technically, the trailing comma is not RFC-compliant.
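# (Per RFC 2622, list members are separated by commas with no trailing separator, so a value like "192.0.2.0/24, 192.00.02.0/25, " below is, strictly, malformed.)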
# However, it's used in some cases when the list is broken over # multiple lines, and accepting it is harmless. - parse_result = field.parse('192.0.2.0/24, 192.00.02.0/25, ', messages) - assert parse_result.value == '192.0.2.0/24,192.0.2.0/25' - assert parse_result.values_list == ['192.0.2.0/24', '192.0.2.0/25'] + parse_result = field.parse("192.0.2.0/24, 192.00.02.0/25, ", messages) + assert parse_result.value == "192.0.2.0/24,192.0.2.0/25" + assert parse_result.values_list == ["192.0.2.0/24", "192.0.2.0/25"] assert not messages.errors() - assert messages.infos() == ['Address prefix 192.00.02.0/25 was reformatted as 192.0.2.0/25'] + assert messages.infos() == ["Address prefix 192.00.02.0/25 was reformatted as 192.0.2.0/25"] - assert_validation_err('Invalid address prefix', field.parse, '192.0.2.0/24, 192.0.2/16') + assert_validation_err("Invalid address prefix", field.parse, "192.0.2.0/24, 192.0.2/16") def test_ipv6_prefix_field(): field = RPSLIPv6PrefixField() messages = RPSLParserMessages() - parse_result = field.parse('12AB:0000:0000:CD30:0000:0000:0000:0000/60', messages) - assert parse_result.value == '12ab:0:0:cd30::/60' - assert parse_result.ip_first == IP('12ab:0:0:cd30::') - assert parse_result.ip_last == IP('12ab::cd3f:ffff:ffff:ffff:ffff') + parse_result = field.parse("12AB:0000:0000:CD30:0000:0000:0000:0000/60", messages) + assert parse_result.value == "12ab:0:0:cd30::/60" + assert parse_result.ip_first == IP("12ab:0:0:cd30::") + assert parse_result.ip_last == IP("12ab::cd3f:ffff:ffff:ffff:ffff") assert parse_result.prefix_length == 60 - assert field.parse('12ab::cd30:0:0:0:0/60', messages).value == '12ab:0:0:cd30::/60' - assert field.parse('12AB:0:0:CD30::/60', messages).value == '12ab:0:0:cd30::/60' - assert field.parse('12ab:0:0:cd30::/128', messages).value == '12ab:0:0:cd30::/128' + assert field.parse("12ab::cd30:0:0:0:0/60", messages).value == "12ab:0:0:cd30::/60" + assert field.parse("12AB:0:0:CD30::/60", messages).value == "12ab:0:0:cd30::/60" + assert field.parse("12ab:0:0:cd30::/128", messages).value == "12ab:0:0:cd30::/128" assert not messages.errors() assert messages.infos() == [ - 'Address prefix 12AB:0000:0000:CD30:0000:0000:0000:0000/60 was reformatted as 12ab:0:0:cd30::/60', - 'Address prefix 12ab::cd30:0:0:0:0/60 was reformatted as 12ab:0:0:cd30::/60', - 'Address prefix 12AB:0:0:CD30::/60 was reformatted as 12ab:0:0:cd30::/60', + "Address prefix 12AB:0000:0000:CD30:0000:0000:0000:0000/60 was reformatted as 12ab:0:0:cd30::/60", + "Address prefix 12ab::cd30:0:0:0:0/60 was reformatted as 12ab:0:0:cd30::/60", + "Address prefix 12AB:0:0:CD30::/60 was reformatted as 12ab:0:0:cd30::/60", ] - assert_validation_err('Invalid address prefix', field.parse, 'foo') - assert_validation_err('Invalid address prefix', field.parse, 'foo/bar') - assert_validation_err('invalid hexlet', field.parse, '2001525::/32') - assert_validation_err('should have 8 hextets', field.parse, '12AB:0:0:CD3/60') - assert_validation_err('Invalid address prefix', field.parse, '12AB::CD30/60') - assert_validation_err('Invalid address prefix', field.parse, '12AB::CD3/60') - assert_validation_err('Invalid address prefix', field.parse, '192.0.2.0/16') + assert_validation_err("Invalid address prefix", field.parse, "foo") + assert_validation_err("Invalid address prefix", field.parse, "foo/bar") + assert_validation_err("invalid hexlet", field.parse, "2001525::/32") + assert_validation_err("should have 8 hextets", field.parse, "12AB:0:0:CD3/60") + assert_validation_err("Invalid address prefix", field.parse, 
"12AB::CD30/60") + assert_validation_err("Invalid address prefix", field.parse, "12AB::CD3/60") + assert_validation_err("Invalid address prefix", field.parse, "192.0.2.0/16") def test_ipv6_prefixes_field(): field = RPSLIPv6PrefixesField() messages = RPSLParserMessages() - assert field.parse('12AB:0:0:CD30::/60', messages).value == '12ab:0:0:cd30::/60' - assert field.parse('12AB:0:0:CD30::/60, 2001:DB8::0/64', messages).value == '12ab:0:0:cd30::/60,2001:db8::/64' + assert field.parse("12AB:0:0:CD30::/60", messages).value == "12ab:0:0:cd30::/60" + assert ( + field.parse("12AB:0:0:CD30::/60, 2001:DB8::0/64", messages).value + == "12ab:0:0:cd30::/60,2001:db8::/64" + ) assert not messages.errors() - assert_validation_err('Invalid address prefix', field.parse, 'foo') - assert_validation_err('Invalid address prefix', field.parse, 'foo/bar') - assert_validation_err('invalid hexlet', field.parse, '2001:db8::/32, 2001525::/32') - assert_validation_err('should have 8 hextets', field.parse, '12AB:0:0:CD3/60') - assert_validation_err('Invalid address prefix', field.parse, '12AB::CD30/60') - assert_validation_err('Invalid address prefix', field.parse, '12AB::CD3/60') - assert_validation_err('Invalid address prefix', field.parse, '192.0.2.0/16') + assert_validation_err("Invalid address prefix", field.parse, "foo") + assert_validation_err("Invalid address prefix", field.parse, "foo/bar") + assert_validation_err("invalid hexlet", field.parse, "2001:db8::/32, 2001525::/32") + assert_validation_err("should have 8 hextets", field.parse, "12AB:0:0:CD3/60") + assert_validation_err("Invalid address prefix", field.parse, "12AB::CD30/60") + assert_validation_err("Invalid address prefix", field.parse, "12AB::CD3/60") + assert_validation_err("Invalid address prefix", field.parse, "192.0.2.0/16") def test_ipv4_address_range_field(): field = RPSLIPv4AddressRangeField() messages = RPSLParserMessages() - parse_result = field.parse('192.0.02.0', messages) - assert parse_result.value == '192.0.2.0' - assert parse_result.ip_first == IP('192.0.2.0') - assert parse_result.ip_last == IP('192.0.2.0') + parse_result = field.parse("192.0.02.0", messages) + assert parse_result.value == "192.0.2.0" + assert parse_result.ip_first == IP("192.0.2.0") + assert parse_result.ip_last == IP("192.0.2.0") - parse_result = field.parse('192.0.2.0 - 192.0.2.126', messages) - assert parse_result.value == '192.0.2.0 - 192.0.2.126' + parse_result = field.parse("192.0.2.0 - 192.0.2.126", messages) + assert parse_result.value == "192.0.2.0 - 192.0.2.126" - parse_result = field.parse('192.0.2.0 -192.0.02.126', messages) - assert parse_result.value == '192.0.2.0 - 192.0.2.126' - assert parse_result.ip_first == IP('192.0.2.0') - assert parse_result.ip_last == IP('192.0.2.126') + parse_result = field.parse("192.0.2.0 -192.0.02.126", messages) + assert parse_result.value == "192.0.2.0 - 192.0.2.126" + assert parse_result.ip_first == IP("192.0.2.0") + assert parse_result.ip_last == IP("192.0.2.126") assert not messages.errors() assert messages.infos() == [ - 'Address range 192.0.02.0 was reformatted as 192.0.2.0', - 'Address range 192.0.2.0 -192.0.02.126 was reformatted as 192.0.2.0 - 192.0.2.126', + "Address range 192.0.02.0 was reformatted as 192.0.2.0", + "Address range 192.0.2.0 -192.0.02.126 was reformatted as 192.0.2.0 - 192.0.2.126", ] - assert_validation_err('Invalid address', field.parse, '192.0.1.5555 - 192.0.2.0') - assert_validation_err('IP version mismatch', field.parse, '192.0.2.0 - 2001:db8::') - assert_validation_err('first IP is higher', 
field.parse, '192.0.2.1 - 192.0.2.0') - assert_validation_err('IP version mismatch', field.parse, '2001:db8::0 - 2001:db8::1') + assert_validation_err("Invalid address", field.parse, "192.0.1.5555 - 192.0.2.0") + assert_validation_err("IP version mismatch", field.parse, "192.0.2.0 - 2001:db8::") + assert_validation_err("first IP is higher", field.parse, "192.0.2.1 - 192.0.2.0") + assert_validation_err("IP version mismatch", field.parse, "2001:db8::0 - 2001:db8::1") def test_route_set_members_field(): @@ -161,261 +177,271 @@ def test_route_set_members_field(): field = RPSLRouteSetMemberField(ip_version=4) messages = RPSLParserMessages() - assert field.parse('192.0.2.0/24^24-25', messages).value == '192.0.2.0/24^24-25' - assert field.parse('AS065537:RS-TEST^32', messages).value == 'AS65537:RS-TEST^32' - assert field.parse('AS065537^32', messages).value == 'AS65537^32' - assert field.parse('192.0.2.0/25^+', messages).value == '192.0.2.0/25^+' - assert field.parse('192.0.2.0/25^32', messages).value == '192.0.2.0/25^32' - assert field.parse('192.00.02.0/25^-', messages).value == '192.0.2.0/25^-' - assert field.parse('192.0.02.0/32', messages).value == '192.0.2.0/32' + assert field.parse("192.0.2.0/24^24-25", messages).value == "192.0.2.0/24^24-25" + assert field.parse("AS065537:RS-TEST^32", messages).value == "AS65537:RS-TEST^32" + assert field.parse("AS065537^32", messages).value == "AS65537^32" + assert field.parse("192.0.2.0/25^+", messages).value == "192.0.2.0/25^+" + assert field.parse("192.0.2.0/25^32", messages).value == "192.0.2.0/25^32" + assert field.parse("192.00.02.0/25^-", messages).value == "192.0.2.0/25^-" + assert field.parse("192.0.02.0/32", messages).value == "192.0.2.0/32" assert not messages.errors() assert messages.infos() == [ - 'Route set member AS065537:RS-TEST^32 was reformatted as AS65537:RS-TEST^32', - 'Route set member AS065537^32 was reformatted as AS65537^32', - 'Route set member 192.00.02.0/25^- was reformatted as 192.0.2.0/25^-', - 'Route set member 192.0.02.0/32 was reformatted as 192.0.2.0/32', + "Route set member AS065537:RS-TEST^32 was reformatted as AS65537:RS-TEST^32", + "Route set member AS065537^32 was reformatted as AS65537^32", + "Route set member 192.00.02.0/25^- was reformatted as 192.0.2.0/25^-", + "Route set member 192.0.02.0/32 was reformatted as 192.0.2.0/32", ] - assert_validation_err('Value is neither a valid set name nor a valid prefix', field.parse, 'AS65537:TEST') - assert_validation_err('Missing range operator', field.parse, '192.0.2.0/24^') - assert_validation_err('Invalid range operator', field.parse, '192.0.2.0/24^x') - assert_validation_err('Invalid range operator', field.parse, '192.0.2.0/24^-32') - assert_validation_err('Invalid range operator', field.parse, '192.0.2.0/24^32-') - assert_validation_err('Invalid range operator', field.parse, '192.0.2.0/24^24+32') - assert_validation_err('operator length (23) must be equal ', field.parse, '192.0.2.0/24^23') - assert_validation_err('operator start (23) must be equal ', field.parse, '192.0.2.0/24^23-32') - assert_validation_err('operator end (30) must be equal', field.parse, '192.0.2.0/24^32-30') + assert_validation_err("Value is neither a valid set name nor a valid prefix", field.parse, "AS65537:TEST") + assert_validation_err("Missing range operator", field.parse, "192.0.2.0/24^") + assert_validation_err("Invalid range operator", field.parse, "192.0.2.0/24^x") + assert_validation_err("Invalid range operator", field.parse, "192.0.2.0/24^-32") + assert_validation_err("Invalid range 
operator", field.parse, "192.0.2.0/24^32-") + assert_validation_err("Invalid range operator", field.parse, "192.0.2.0/24^24+32") + assert_validation_err("operator length (23) must be equal ", field.parse, "192.0.2.0/24^23") + assert_validation_err("operator start (23) must be equal ", field.parse, "192.0.2.0/24^23-32") + assert_validation_err("operator end (30) must be equal", field.parse, "192.0.2.0/24^32-30") field = RPSLRouteSetMemberField(ip_version=None) messages = RPSLParserMessages() - assert field.parse('192.0.2.0/24^24-25', messages).value == '192.0.2.0/24^24-25' - assert field.parse('12ab:0:0:cd30::/128', messages).value == '12ab:0:0:cd30::/128' - assert field.parse('12ab:0:0:cd30::/64^120-128', messages).value == '12ab:0:0:cd30::/64^120-128' - assert field.parse('AS65537:RS-TEST', messages).value == 'AS65537:RS-TEST' + assert field.parse("192.0.2.0/24^24-25", messages).value == "192.0.2.0/24^24-25" + assert field.parse("12ab:0:0:cd30::/128", messages).value == "12ab:0:0:cd30::/128" + assert field.parse("12ab:0:0:cd30::/64^120-128", messages).value == "12ab:0:0:cd30::/64^120-128" + assert field.parse("AS65537:RS-TEST", messages).value == "AS65537:RS-TEST" - assert field.parse('192.0.2.0/25^+', messages).value == '192.0.2.0/25^+' - assert field.parse('192.0.2.0/25^32', messages).value == '192.0.2.0/25^32' - assert field.parse('12ab:00:0:cd30::/60^-', messages).value == '12ab:0:0:cd30::/60^-' - assert field.parse('12ab:0:0:cd30::/60', messages).value == '12ab:0:0:cd30::/60' + assert field.parse("192.0.2.0/25^+", messages).value == "192.0.2.0/25^+" + assert field.parse("192.0.2.0/25^32", messages).value == "192.0.2.0/25^32" + assert field.parse("12ab:00:0:cd30::/60^-", messages).value == "12ab:0:0:cd30::/60^-" + assert field.parse("12ab:0:0:cd30::/60", messages).value == "12ab:0:0:cd30::/60" assert not messages.errors() assert messages.infos() == [ - 'Route set member 12ab:00:0:cd30::/60^- was reformatted as 12ab:0:0:cd30::/60^-', + "Route set member 12ab:00:0:cd30::/60^- was reformatted as 12ab:0:0:cd30::/60^-", ] - assert_validation_err('Invalid range operator', field.parse, '192.0.2.0/32^24+32') - assert_validation_err('Invalid range operator', field.parse, '12ab:0:0:cd30::/60^24+32') + assert_validation_err("Invalid range operator", field.parse, "192.0.2.0/32^24+32") + assert_validation_err("Invalid range operator", field.parse, "12ab:0:0:cd30::/60^24+32") def test_validate_as_number_field(): field = RPSLASNumberField() messages = RPSLParserMessages() - parse_result = field.parse('AS065537', messages) - assert parse_result.value == 'AS65537' + parse_result = field.parse("AS065537", messages) + assert parse_result.value == "AS65537" assert parse_result.asn_first == 65537 assert parse_result.asn_last == 65537 assert not messages.errors() - assert messages.infos() == ['AS number AS065537 was reformatted as AS65537'] + assert messages.infos() == ["AS number AS065537 was reformatted as AS65537"] - assert_validation_err('not numeric', field.parse, 'ASxxxx') - assert_validation_err('not numeric', field.parse, 'AS2345💩') - assert_validation_err('must start with', field.parse, '💩AS2345') + assert_validation_err("not numeric", field.parse, "ASxxxx") + assert_validation_err("not numeric", field.parse, "AS2345💩") + assert_validation_err("must start with", field.parse, "💩AS2345") def test_validate_as_block_field(): field = RPSLASBlockField() messages = RPSLParserMessages() - parse_result = field.parse('AS001- AS200', messages) - assert parse_result.value == 'AS1 - AS200' + parse_result = 
field.parse("AS001- AS200", messages) + assert parse_result.value == "AS1 - AS200" assert parse_result.asn_first == 1 assert parse_result.asn_last == 200 - assert field.parse('AS200-AS0200', messages).value == 'AS200 - AS200' + assert field.parse("AS200-AS0200", messages).value == "AS200 - AS200" assert not messages.errors() assert messages.infos() == [ - 'AS range AS001- AS200 was reformatted as AS1 - AS200', - 'AS range AS200-AS0200 was reformatted as AS200 - AS200' + "AS range AS001- AS200 was reformatted as AS1 - AS200", + "AS range AS200-AS0200 was reformatted as AS200 - AS200", ] - assert_validation_err('does not contain a hyphen', field.parse, 'AS65537') - assert_validation_err('number part is not numeric', field.parse, 'ASxxxx - ASyyyy') - assert_validation_err('Invalid AS number', field.parse, 'AS-FOO - AS-BAR') - assert_validation_err('Invalid AS range', field.parse, 'AS300 - AS200') + assert_validation_err("does not contain a hyphen", field.parse, "AS65537") + assert_validation_err("number part is not numeric", field.parse, "ASxxxx - ASyyyy") + assert_validation_err("Invalid AS number", field.parse, "AS-FOO - AS-BAR") + assert_validation_err("Invalid AS range", field.parse, "AS300 - AS200") def test_validate_set_name_field(): - field = RPSLSetNameField(prefix='AS') + field = RPSLSetNameField(prefix="AS") messages = RPSLParserMessages() - assert field.parse('AS-FOO', messages).value == 'AS-FOO' - assert field.parse('AS01:AS-FOO', messages).value == 'AS1:AS-FOO' - assert field.parse('AS1:AS-FOO:AS3', messages).value == 'AS1:AS-FOO:AS3' - assert field.parse('AS01:AS-3', messages).value == 'AS1:AS-3' + assert field.parse("AS-FOO", messages).value == "AS-FOO" + assert field.parse("AS01:AS-FOO", messages).value == "AS1:AS-FOO" + assert field.parse("AS1:AS-FOO:AS3", messages).value == "AS1:AS-FOO:AS3" + assert field.parse("AS01:AS-3", messages).value == "AS1:AS-3" assert not messages.errors() assert messages.infos() == [ - 'Set name AS01:AS-FOO was reformatted as AS1:AS-FOO', - 'Set name AS01:AS-3 was reformatted as AS1:AS-3' + "Set name AS01:AS-FOO was reformatted as AS1:AS-FOO", + "Set name AS01:AS-3 was reformatted as AS1:AS-3", ] - long_set = 'AS1:AS-B:AS-C:AS-D:AS-E:AS-F' - assert_validation_err('at least one component must be an actual set name', field.parse, 'AS1',) - assert_validation_err('at least one component must be an actual set name', field.parse, 'AS1:AS3') - assert_validation_err('not a valid AS number, nor does it start with AS-', field.parse, 'AS1:AS-FOO:RS-FORBIDDEN') - assert_validation_err('not a valid AS number nor a valid set name', field.parse, ':AS-FOO') - assert_validation_err('not a valid AS number nor a valid set name', field.parse, 'AS-FOO:') - assert_validation_err('can have a maximum of five components', field.parse, long_set) - assert_validation_err('reserved word', field.parse, 'AS1:AS-ANY') - - assert field.parse('AS-ANY', messages, strict_validation=False).value == 'AS-ANY' + long_set = "AS1:AS-B:AS-C:AS-D:AS-E:AS-F" + assert_validation_err( + "at least one component must be an actual set name", + field.parse, + "AS1", + ) + assert_validation_err("at least one component must be an actual set name", field.parse, "AS1:AS3") + assert_validation_err( + "not a valid AS number, nor does it start with AS-", field.parse, "AS1:AS-FOO:RS-FORBIDDEN" + ) + assert_validation_err("not a valid AS number nor a valid set name", field.parse, ":AS-FOO") + assert_validation_err("not a valid AS number nor a valid set name", field.parse, "AS-FOO:") + 
assert_validation_err("can have a maximum of five components", field.parse, long_set) + assert_validation_err("reserved word", field.parse, "AS1:AS-ANY") + + assert field.parse("AS-ANY", messages, strict_validation=False).value == "AS-ANY" assert field.parse(long_set, messages, strict_validation=False).value == long_set - field = RPSLSetNameField(prefix='RS') + field = RPSLSetNameField(prefix="RS") messages = RPSLParserMessages() - assert field.parse('RS-FOO', messages).value == 'RS-FOO' - assert field.parse('AS1:RS-FOO', messages).value == 'AS1:RS-FOO' - assert field.parse('AS1:RS-FOO:AS3', messages).value == 'AS1:RS-FOO:AS3' - assert field.parse('AS1:RS-3', messages).value == 'AS1:RS-3' + assert field.parse("RS-FOO", messages).value == "RS-FOO" + assert field.parse("AS1:RS-FOO", messages).value == "AS1:RS-FOO" + assert field.parse("AS1:RS-FOO:AS3", messages).value == "AS1:RS-FOO:AS3" + assert field.parse("AS1:RS-3", messages).value == "AS1:RS-3" assert not messages.errors() - assert_validation_err('at least one component must be an actual set name', field.parse, 'AS1:AS-FOO') + assert_validation_err("at least one component must be an actual set name", field.parse, "AS1:AS-FOO") def test_validate_email_field(): field = RPSLEmailField() messages = RPSLParserMessages() - assert field.parse('foo.bar@example.asia', messages).value == 'foo.bar@example.asia' - assert field.parse('foo.bar@[192.0.2.1]', messages).value == 'foo.bar@[192.0.2.1]' - assert field.parse('foo.bar@[2001:db8::1]', messages).value == 'foo.bar@[2001:db8::1]' + assert field.parse("foo.bar@example.asia", messages).value == "foo.bar@example.asia" + assert field.parse("foo.bar@[192.0.2.1]", messages).value == "foo.bar@[192.0.2.1]" + assert field.parse("foo.bar@[2001:db8::1]", messages).value == "foo.bar@[2001:db8::1]" assert not messages.errors() - assert_validation_err('Invalid e-mail', field.parse, 'foo.bar+baz@') - assert_validation_err('Invalid e-mail', field.parse, 'a§§@example.com') - assert_validation_err('Invalid e-mail', field.parse, 'a@[192.0.2.2.2]') + assert_validation_err("Invalid e-mail", field.parse, "foo.bar+baz@") + assert_validation_err("Invalid e-mail", field.parse, "a§§@example.com") + assert_validation_err("Invalid e-mail", field.parse, "a@[192.0.2.2.2]") def test_validate_changed_field(): field = RPSLChangedField() messages = RPSLParserMessages() - assert field.parse('foo.bar@example.asia', messages).value == 'foo.bar@example.asia' - assert field.parse('foo.bar@[192.0.2.1] 20190701', messages).value == 'foo.bar@[192.0.2.1] 20190701' - assert field.parse('foo.bar@[2001:db8::1] 19980101', messages).value == 'foo.bar@[2001:db8::1] 19980101' + assert field.parse("foo.bar@example.asia", messages).value == "foo.bar@example.asia" + assert field.parse("foo.bar@[192.0.2.1] 20190701", messages).value == "foo.bar@[192.0.2.1] 20190701" + assert field.parse("foo.bar@[2001:db8::1] 19980101", messages).value == "foo.bar@[2001:db8::1] 19980101" assert not messages.errors() - assert_validation_err('Invalid e-mail', field.parse, 'foo.bar+baz@') - assert_validation_err('Invalid changed date', field.parse, 'foo.bar@example.com 20191301') - assert_validation_err('Invalid e-mail', field.parse, '\nfoo.bar@example.com \n20190701') - assert_validation_err('Invalid changed date', field.parse, 'foo.bar@example.com \n20190701') - assert_validation_err('Invalid changed date', field.parse, 'foo.bar@example.com 20190701\n') + assert_validation_err("Invalid e-mail", field.parse, "foo.bar+baz@") + assert_validation_err("Invalid changed 
date", field.parse, "foo.bar@example.com 20191301") + assert_validation_err("Invalid e-mail", field.parse, "\nfoo.bar@example.com \n20190701") + assert_validation_err("Invalid changed date", field.parse, "foo.bar@example.com \n20190701") + assert_validation_err("Invalid changed date", field.parse, "foo.bar@example.com 20190701\n") def test_validate_dns_name_field(): field = RPSLDNSNameField() messages = RPSLParserMessages() - assert field.parse('foo.bar.baz', messages).value == 'foo.bar.baz' + assert field.parse("foo.bar.baz", messages).value == "foo.bar.baz" assert not messages.errors() - assert_validation_err('Invalid DNS name', field.parse, 'foo.bar+baz@') + assert_validation_err("Invalid DNS name", field.parse, "foo.bar+baz@") def test_validate_url_field(): field = RPSLURLField() messages = RPSLParserMessages() - assert field.parse('http://example.com', messages).value == 'http://example.com' - assert field.parse('https://example.com', messages).value == 'https://example.com' + assert field.parse("http://example.com", messages).value == "http://example.com" + assert field.parse("https://example.com", messages).value == "https://example.com" assert not messages.errors() - assert_validation_err('Invalid http/https URL', field.parse, 'ftp://test') - assert_validation_err('Invalid http/https URL', field.parse, 'test') - assert_validation_err('Invalid http/https URL', field.parse, 'test') + assert_validation_err("Invalid http/https URL", field.parse, "ftp://test") + assert_validation_err("Invalid http/https URL", field.parse, "test") + assert_validation_err("Invalid http/https URL", field.parse, "test") def test_validate_generic_name_field(): field = RPSLGenericNameField() messages = RPSLParserMessages() - assert field.parse('MAINT-FOO', messages).value == 'MAINT-FOO' - assert field.parse('FOO-MNT', messages).value == 'FOO-MNT' - assert field.parse('FOO-MN_T2', messages).value == 'FOO-MN_T2' + assert field.parse("MAINT-FOO", messages).value == "MAINT-FOO" + assert field.parse("FOO-MNT", messages).value == "FOO-MNT" + assert field.parse("FOO-MN_T2", messages).value == "FOO-MN_T2" assert not messages.errors() - assert_validation_err('reserved word', field.parse, 'any') - assert_validation_err('reserved prefix', field.parse, 'As-FOO') - assert_validation_err('invalid character', field.parse, 'FoO$BAR') - assert_validation_err('invalid character', field.parse, 'FOOBAR-') - assert_validation_err('invalid character', field.parse, 'FOO💩BAR') + assert_validation_err("reserved word", field.parse, "any") + assert_validation_err("reserved prefix", field.parse, "As-FOO") + assert_validation_err("invalid character", field.parse, "FoO$BAR") + assert_validation_err("invalid character", field.parse, "FOOBAR-") + assert_validation_err("invalid character", field.parse, "FOO💩BAR") - assert field.parse('AS-FOO', messages, strict_validation=False).value == 'AS-FOO' - assert field.parse('FOO BAR', messages, strict_validation=False) is None + assert field.parse("AS-FOO", messages, strict_validation=False).value == "AS-FOO" + assert field.parse("FOO BAR", messages, strict_validation=False) is None - field = RPSLGenericNameField(allowed_prefixes=['as']) + field = RPSLGenericNameField(allowed_prefixes=["as"]) messages = RPSLParserMessages() - assert field.parse('As-FOO', messages).value == 'As-FOO' + assert field.parse("As-FOO", messages).value == "As-FOO" assert not messages.errors() - assert_validation_err('reserved prefix', field.parse, 'FLTr-FOO') + assert_validation_err("reserved prefix", field.parse, 
"FLTr-FOO") field = RPSLGenericNameField(non_strict_allow_any=True) - assert field.parse('FOO BAR', messages, strict_validation=False).value == 'FOO BAR' - assert_validation_err('invalid character', field.parse, 'FOO BAR') + assert field.parse("FOO BAR", messages, strict_validation=False).value == "FOO BAR" + assert_validation_err("invalid character", field.parse, "FOO BAR") def test_rpsl_reference_field(): - field = RPSLReferenceField(referring=['person']) + field = RPSLReferenceField(referring=["person"]) messages = RPSLParserMessages() - assert field.parse('SR123-NTT', messages).value == 'SR123-NTT' + assert field.parse("SR123-NTT", messages).value == "SR123-NTT" assert not messages.errors() - assert_validation_err('RS- is a reserved prefix', field.parse, 'RS-1234') - assert_validation_err('Invalid name', field.parse, 'foo$$') + assert_validation_err("RS- is a reserved prefix", field.parse, "RS-1234") + assert_validation_err("Invalid name", field.parse, "foo$$") - field = RPSLReferenceField(referring=['aut-num', 'as-set']) + field = RPSLReferenceField(referring=["aut-num", "as-set"]) messages = RPSLParserMessages() - assert field.parse('AS01234', messages).value == 'AS1234' - assert field.parse('AS-FOO', messages).value == 'AS-FOO' + assert field.parse("AS01234", messages).value == "AS1234" + assert field.parse("AS-FOO", messages).value == "AS-FOO" assert not messages.errors() - assert_validation_err(['Invalid AS number', 'start with AS-'], field.parse, 'RS-1234') - assert_validation_err(['Invalid AS number', 'start with AS-'], field.parse, 'RS-1234') - assert_validation_err(['Invalid AS number', 'at least one component must be an actual set name (i.e. start with AS-'], field.parse, 'FOOBAR') + assert_validation_err(["Invalid AS number", "start with AS-"], field.parse, "RS-1234") + assert_validation_err(["Invalid AS number", "start with AS-"], field.parse, "RS-1234") + assert_validation_err( + ["Invalid AS number", "at least one component must be an actual set name (i.e. 
start with AS-"], + field.parse, + "FOOBAR", + ) def test_rpsl_references_field(): - field = RPSLReferenceListField(referring=['aut-num']) + field = RPSLReferenceListField(referring=["aut-num"]) messages = RPSLParserMessages() - assert field.parse('AS1234', messages).value == 'AS1234' - assert field.parse('AS01234, AS04567', messages).value == 'AS1234,AS4567' + assert field.parse("AS1234", messages).value == "AS1234" + assert field.parse("AS01234, AS04567", messages).value == "AS1234,AS4567" assert not messages.errors() - assert_validation_err('Invalid AS number', field.parse, 'ANY') + assert_validation_err("Invalid AS number", field.parse, "ANY") - field = RPSLReferenceListField(referring=['aut-num'], allow_kw_any=True) + field = RPSLReferenceListField(referring=["aut-num"], allow_kw_any=True) messages = RPSLParserMessages() - assert field.parse('AS1234', messages).value == 'AS1234' - assert field.parse('AS01234, AS04567', messages).value == 'AS1234,AS4567' - assert field.parse('any', messages).value == 'ANY' + assert field.parse("AS1234", messages).value == "AS1234" + assert field.parse("AS01234, AS04567", messages).value == "AS1234,AS4567" + assert field.parse("any", messages).value == "ANY" assert not messages.errors() - assert_validation_err('Invalid AS number', field.parse, 'AS1234, any') + assert_validation_err("Invalid AS number", field.parse, "AS1234, any") def test_rpsl_auth_field(config_override): field = RPSLAuthField() messages = RPSLParserMessages() - assert field.parse('MD5-pw hashhash', messages).value == 'MD5-pw hashhash' - assert field.parse('bcrypt-pw hashhash', messages).value == 'bcrypt-pw hashhash' - assert field.parse('PGPKEY-AABB0011', messages).value == 'PGPKEY-AABB0011' + assert field.parse("MD5-pw hashhash", messages).value == "MD5-pw hashhash" + assert field.parse("bcrypt-pw hashhash", messages).value == "bcrypt-pw hashhash" + assert field.parse("PGPKEY-AABB0011", messages).value == "PGPKEY-AABB0011" assert not messages.errors() - assert_validation_err('Invalid auth attribute', field.parse, 'PGPKEY-XX') - assert_validation_err('Invalid auth attribute', field.parse, 'PGPKEY-AABB00112233') - assert_validation_err('Invalid auth attribute', field.parse, 'ARGON-PW hashhash') - assert_validation_err('Invalid auth attribute', field.parse, 'BCRYPT-PWhashhash') + assert_validation_err("Invalid auth attribute", field.parse, "PGPKEY-XX") + assert_validation_err("Invalid auth attribute", field.parse, "PGPKEY-AABB00112233") + assert_validation_err("Invalid auth attribute", field.parse, "ARGON-PW hashhash") + assert_validation_err("Invalid auth attribute", field.parse, "BCRYPT-PWhashhash") - assert_validation_err('Invalid auth attribute', field.parse, 'CRYPT-PW hashhash') - assert field.parse('CRYPT-PW hashhash', messages, strict_validation=False).value == 'CRYPT-PW hashhash' + assert_validation_err("Invalid auth attribute", field.parse, "CRYPT-PW hashhash") + assert field.parse("CRYPT-PW hashhash", messages, strict_validation=False).value == "CRYPT-PW hashhash" - config_override({'auth': {'password_hashers': {'crypt-pw': 'enabled'}}}) - assert field.parse('CRYPT-PW hashhash', messages).value == 'CRYPT-PW hashhash' + config_override({"auth": {"password_hashers": {"crypt-pw": "enabled"}}}) + assert field.parse("CRYPT-PW hashhash", messages).value == "CRYPT-PW hashhash" - config_override({'auth': {'password_hashers': {'crypt-pw': 'disabled'}}}) - assert field.parse('CRYPT-PW hashhash', messages, strict_validation=False) is None + config_override({"auth": {"password_hashers": 
{"crypt-pw": "disabled"}}}) + assert field.parse("CRYPT-PW hashhash", messages, strict_validation=False) is None diff --git a/irrd/rpsl/tests/test_rpsl_objects.py b/irrd/rpsl/tests/test_rpsl_objects.py index d47c36993..a432eb574 100644 --- a/irrd/rpsl/tests/test_rpsl_objects.py +++ b/irrd/rpsl/tests/test_rpsl_objects.py @@ -5,26 +5,50 @@ from pytest import raises from pytz import timezone -from irrd.conf import PASSWORD_HASH_DUMMY_VALUE, AUTH_SET_CREATION_COMMON_KEY -from irrd.utils.rpsl_samples import (object_sample_mapping, SAMPLE_MALFORMED_EMPTY_LINE, - SAMPLE_MALFORMED_ATTRIBUTE_NAME, - SAMPLE_UNKNOWN_CLASS, SAMPLE_MISSING_MANDATORY_ATTRIBUTE, - SAMPLE_MALFORMED_SOURCE, - SAMPLE_MALFORMED_PK, SAMPLE_UNKNOWN_ATTRIBUTE, - SAMPLE_INVALID_MULTIPLE_ATTRIBUTE, - KEY_CERT_SIGNED_MESSAGE_VALID, KEY_CERT_SIGNED_MESSAGE_INVALID, - KEY_CERT_SIGNED_MESSAGE_CORRUPT, - KEY_CERT_SIGNED_MESSAGE_WRONG_KEY, - TEMPLATE_ROUTE_OBJECT, - TEMPLATE_PERSON_OBJECT, - SAMPLE_LINE_NEITHER_CONTINUATION_NOR_ATTR, - SAMPLE_MISSING_SOURCE, SAMPLE_ROUTE) +from irrd.conf import AUTH_SET_CREATION_COMMON_KEY, PASSWORD_HASH_DUMMY_VALUE +from irrd.utils.rpsl_samples import ( + KEY_CERT_SIGNED_MESSAGE_CORRUPT, + KEY_CERT_SIGNED_MESSAGE_INVALID, + KEY_CERT_SIGNED_MESSAGE_VALID, + KEY_CERT_SIGNED_MESSAGE_WRONG_KEY, + SAMPLE_INVALID_MULTIPLE_ATTRIBUTE, + SAMPLE_LINE_NEITHER_CONTINUATION_NOR_ATTR, + SAMPLE_MALFORMED_ATTRIBUTE_NAME, + SAMPLE_MALFORMED_EMPTY_LINE, + SAMPLE_MALFORMED_PK, + SAMPLE_MALFORMED_SOURCE, + SAMPLE_MISSING_MANDATORY_ATTRIBUTE, + SAMPLE_MISSING_SOURCE, + SAMPLE_ROUTE, + SAMPLE_UNKNOWN_ATTRIBUTE, + SAMPLE_UNKNOWN_CLASS, + TEMPLATE_PERSON_OBJECT, + TEMPLATE_ROUTE_OBJECT, + object_sample_mapping, +) from ..parser import UnknownRPSLObjectClassException -from ..rpsl_objects import (RPSLAsBlock, RPSLAsSet, RPSLAutNum, RPSLDomain, RPSLFilterSet, RPSLInetRtr, - RPSLInet6Num, RPSLInetnum, RPSLKeyCert, RPSLMntner, RPSLPeeringSet, - RPSLPerson, RPSLRole, RPSLRoute, RPSLRouteSet, RPSLRoute6, RPSLRtrSet, - OBJECT_CLASS_MAPPING, rpsl_object_from_text) +from ..rpsl_objects import ( + OBJECT_CLASS_MAPPING, + RPSLAsBlock, + RPSLAsSet, + RPSLAutNum, + RPSLDomain, + RPSLFilterSet, + RPSLInet6Num, + RPSLInetnum, + RPSLInetRtr, + RPSLKeyCert, + RPSLMntner, + RPSLPeeringSet, + RPSLPerson, + RPSLRole, + RPSLRoute, + RPSLRoute6, + RPSLRouteSet, + RPSLRtrSet, + rpsl_object_from_text, +) class TestRPSLParsingGeneric: @@ -32,77 +56,77 @@ class TestRPSLParsingGeneric: def test_unknown_class(self): with raises(UnknownRPSLObjectClassException) as ve: rpsl_object_from_text(SAMPLE_UNKNOWN_CLASS) - assert 'unknown object class' in str(ve.value) + assert "unknown object class" in str(ve.value) def test_malformed_empty_line(self): obj = rpsl_object_from_text(SAMPLE_MALFORMED_EMPTY_LINE, strict_validation=False) - assert len(obj.messages.errors()) == 1, f'Unexpected extra errors: {obj.messages.errors()}' - assert 'encountered empty line' in obj.messages.errors()[0] + assert len(obj.messages.errors()) == 1, f"Unexpected extra errors: {obj.messages.errors()}" + assert "encountered empty line" in obj.messages.errors()[0] with raises(ValueError): obj.source() def test_malformed_attribute_name(self): obj = rpsl_object_from_text(SAMPLE_MALFORMED_ATTRIBUTE_NAME, strict_validation=False) - assert len(obj.messages.errors()) == 1, f'Unexpected extra errors: {obj.messages.errors()}' - assert 'malformed attribute name' in obj.messages.errors()[0] + assert len(obj.messages.errors()) == 1, f"Unexpected extra errors: {obj.messages.errors()}" + assert 
"malformed attribute name" in obj.messages.errors()[0] def test_missing_mandatory_attribute(self): obj = rpsl_object_from_text(SAMPLE_MISSING_MANDATORY_ATTRIBUTE, strict_validation=True) - assert len(obj.messages.errors()) == 1, f'Unexpected extra errors: {obj.messages.errors()}' + assert len(obj.messages.errors()) == 1, f"Unexpected extra errors: {obj.messages.errors()}" assert 'Mandatory attribute "mnt-by" on object route is missing' in obj.messages.errors()[0] obj = rpsl_object_from_text(SAMPLE_MISSING_MANDATORY_ATTRIBUTE, strict_validation=False) - assert len(obj.messages.errors()) == 0, f'Unexpected extra errors: {obj.messages.errors()}' + assert len(obj.messages.errors()) == 0, f"Unexpected extra errors: {obj.messages.errors()}" def test_unknown_atribute(self): obj = rpsl_object_from_text(SAMPLE_UNKNOWN_ATTRIBUTE, strict_validation=True) - assert len(obj.messages.errors()) == 1, f'Unexpected extra errors: {obj.messages.errors()}' - assert 'Unrecognised attribute' in obj.messages.errors()[0] + assert len(obj.messages.errors()) == 1, f"Unexpected extra errors: {obj.messages.errors()}" + assert "Unrecognised attribute" in obj.messages.errors()[0] obj = rpsl_object_from_text(SAMPLE_UNKNOWN_ATTRIBUTE, strict_validation=False) - assert len(obj.messages.errors()) == 0, f'Unexpected extra errors: {obj.messages.errors()}' + assert len(obj.messages.errors()) == 0, f"Unexpected extra errors: {obj.messages.errors()}" def test_invalid_multiple_attribute(self): obj = rpsl_object_from_text(SAMPLE_INVALID_MULTIPLE_ATTRIBUTE, strict_validation=True) - assert len(obj.messages.errors()) == 1, f'Unexpected extra errors: {obj.messages.errors()}' - assert 'occurs multiple times' in obj.messages.errors()[0] + assert len(obj.messages.errors()) == 1, f"Unexpected extra errors: {obj.messages.errors()}" + assert "occurs multiple times" in obj.messages.errors()[0] obj = rpsl_object_from_text(SAMPLE_INVALID_MULTIPLE_ATTRIBUTE, strict_validation=False) - assert len(obj.messages.errors()) == 0, f'Unexpected extra errors: {obj.messages.errors()}' + assert len(obj.messages.errors()) == 0, f"Unexpected extra errors: {obj.messages.errors()}" def test_malformed_pk(self): obj = rpsl_object_from_text(SAMPLE_MALFORMED_PK, strict_validation=True) - assert len(obj.messages.errors()) == 1, f'Unexpected extra errors: {obj.messages.errors()}' - assert 'Invalid address prefix: not-a-prefix' in obj.messages.errors()[0] + assert len(obj.messages.errors()) == 1, f"Unexpected extra errors: {obj.messages.errors()}" + assert "Invalid address prefix: not-a-prefix" in obj.messages.errors()[0] # A primary key field should also be tested in non-strict mode obj = rpsl_object_from_text(SAMPLE_MALFORMED_PK, strict_validation=False) - assert len(obj.messages.errors()) == 1, f'Unexpected extra errors: {obj.messages.errors()}' - assert 'Invalid address prefix: not-a-prefix' in obj.messages.errors()[0] + assert len(obj.messages.errors()) == 1, f"Unexpected extra errors: {obj.messages.errors()}" + assert "Invalid address prefix: not-a-prefix" in obj.messages.errors()[0] def test_malformed_source(self): obj = rpsl_object_from_text(SAMPLE_MALFORMED_SOURCE, strict_validation=False) - assert len(obj.messages.errors()) == 1, f'Unexpected extra errors: {obj.messages.errors()}' - assert 'contains invalid characters' in obj.messages.errors()[0] + assert len(obj.messages.errors()) == 1, f"Unexpected extra errors: {obj.messages.errors()}" + assert "contains invalid characters" in obj.messages.errors()[0] def 
test_missing_source_optional_default_source(self): - obj = rpsl_object_from_text(SAMPLE_MISSING_SOURCE, strict_validation=False, default_source='TEST') - assert len(obj.messages.errors()) == 0, f'Unexpected extra errors: {obj.messages.errors()}' - assert obj.source() == 'TEST' + obj = rpsl_object_from_text(SAMPLE_MISSING_SOURCE, strict_validation=False, default_source="TEST") + assert len(obj.messages.errors()) == 0, f"Unexpected extra errors: {obj.messages.errors()}" + assert obj.source() == "TEST" obj = rpsl_object_from_text(SAMPLE_MISSING_SOURCE, strict_validation=False) - assert len(obj.messages.errors()) == 1, f'Unexpected extra errors: {obj.messages.errors()}' + assert len(obj.messages.errors()) == 1, f"Unexpected extra errors: {obj.messages.errors()}" assert 'attribute "source" on object route is missing' in obj.messages.errors()[0] def test_line_neither_continuation_nor_attribute(self): obj = rpsl_object_from_text(SAMPLE_LINE_NEITHER_CONTINUATION_NOR_ATTR, strict_validation=False) - assert len(obj.messages.errors()) == 1, f'Unexpected extra errors: {obj.messages.errors()}' - assert 'line is neither continuation nor valid attribute' in obj.messages.errors()[0] + assert len(obj.messages.errors()) == 1, f"Unexpected extra errors: {obj.messages.errors()}" + assert "line is neither continuation nor valid attribute" in obj.messages.errors()[0] def test_double_object_297(self): - obj = rpsl_object_from_text(SAMPLE_ROUTE + ' \n' + SAMPLE_ROUTE) - assert len(obj.messages.errors()) == 3, f'Unexpected extra errors: {obj.messages.errors()}' + obj = rpsl_object_from_text(SAMPLE_ROUTE + " \n" + SAMPLE_ROUTE) + assert len(obj.messages.errors()) == 3, f"Unexpected extra errors: {obj.messages.errors()}" assert 'Attribute "route" on object route occurs multiple times' in obj.messages.errors()[0] @@ -116,11 +140,11 @@ def test_parse(self): obj = rpsl_object_from_text(rpsl_text) assert obj.__class__ == RPSLAsBlock assert not obj.messages.errors() - assert obj.pk() == 'AS65536 - AS65538' + assert obj.pk() == "AS65536 - AS65538" assert obj.asn_first == 65536 assert obj.asn_last == 65538 # Field parsing will cause our object to look slightly different than the original, hence the replace() - assert obj.render_rpsl_text() == rpsl_text.replace('as065538', 'AS65538') + assert obj.render_rpsl_text() == rpsl_text.replace("as065538", "AS65538") class TestRPSLAsSet: @@ -134,38 +158,38 @@ def test_parse(self): assert obj.__class__ == RPSLAsSet assert obj.clean_for_create() assert not obj.messages.errors() - assert obj.pk() == 'AS65537:AS-SETTEST' + assert obj.pk() == "AS65537:AS-SETTEST" assert obj.referred_strong_objects() == [ - ('admin-c', ['role', 'person'], ['PERSON-TEST']), - ('tech-c', ['role', 'person'], ['PERSON-TEST']), - ('mnt-by', ['mntner'], ['TEST-MNT']) + ("admin-c", ["role", "person"], ["PERSON-TEST"]), + ("tech-c", ["role", "person"], ["PERSON-TEST"]), + ("mnt-by", ["mntner"], ["TEST-MNT"]), ] assert obj.references_strong_inbound() == set() - assert obj.source() == 'TEST' - assert obj.pk_asn_segment == 'AS65537' + assert obj.source() == "TEST" + assert obj.pk_asn_segment == "AS65537" - assert obj.parsed_data['members'] == ['AS65538', 'AS65539', 'AS65537', 'AS-OTHERSET'] + assert obj.parsed_data["members"] == ["AS65538", "AS65539", "AS65537", "AS-OTHERSET"] # Field parsing will cause our object to look slightly different than the original, hence the replace() - assert obj.render_rpsl_text() == rpsl_text.replace('AS65538, AS65539', 'AS65538,AS65539') + assert obj.render_rpsl_text() == 
rpsl_text.replace("AS65538, AS65539", "AS65538,AS65539") def test_clean_for_create(self, config_override): rpsl_text = object_sample_mapping[RPSLAsSet().rpsl_object_class] - rpsl_text = rpsl_text.replace('AS65537:AS-SETTEST', 'AS-SETTEST') + rpsl_text = rpsl_text.replace("AS65537:AS-SETTEST", "AS-SETTEST") obj = rpsl_object_from_text(rpsl_text) assert obj.__class__ == RPSLAsSet assert not obj.messages.errors() assert not obj.clean_for_create() assert not obj.pk_asn_segment - assert 'as-set names must be hierarchical and the first ' in obj.messages.errors()[0] + assert "as-set names must be hierarchical and the first " in obj.messages.errors()[0] - config_override({'auth': {'set_creation': {'as-set': {'prefix_required': False}}}}) + config_override({"auth": {"set_creation": {"as-set": {"prefix_required": False}}}}) obj = rpsl_object_from_text(rpsl_text) assert obj.clean_for_create() assert not obj.pk_asn_segment - config_override({'auth': {'set_creation': { - AUTH_SET_CREATION_COMMON_KEY: {'prefix_required': False} - }}}) + config_override( + {"auth": {"set_creation": {AUTH_SET_CREATION_COMMON_KEY: {"prefix_required": False}}}} + ) obj = rpsl_object_from_text(rpsl_text) assert obj.clean_for_create() assert not obj.pk_asn_segment @@ -181,13 +205,13 @@ def test_parse(self): obj = rpsl_object_from_text(rpsl_text) assert obj.__class__ == RPSLAutNum assert not obj.messages.errors() - assert obj.pk() == 'AS65537' + assert obj.pk() == "AS65537" assert obj.asn_first == 65537 assert obj.asn_last == 65537 assert obj.ip_version() is None assert obj.references_strong_inbound() == set() # Field parsing will cause our object to look slightly different than the original, hence the replace() - assert obj.render_rpsl_text() == rpsl_text.replace('as065537', 'AS65537') + assert obj.render_rpsl_text() == rpsl_text.replace("as065537", "AS65537") class TestRPSLDomain: @@ -200,8 +224,8 @@ def test_parse(self): obj = rpsl_object_from_text(rpsl_text) assert obj.__class__ == RPSLDomain assert not obj.messages.errors() - assert obj.pk() == '2.0.192.IN-ADDR.ARPA' - assert obj.parsed_data['source'] == 'TEST' + assert obj.pk() == "2.0.192.IN-ADDR.ARPA" + assert obj.parsed_data["source"] == "TEST" assert obj.render_rpsl_text() == rpsl_text @@ -215,8 +239,8 @@ def test_parse(self): obj = rpsl_object_from_text(rpsl_text) assert obj.__class__ == RPSLFilterSet assert not obj.messages.errors() - assert obj.pk() == 'FLTR-SETTEST' - assert obj.render_rpsl_text() == rpsl_text.replace('\t', '+') # #298 + assert obj.pk() == "FLTR-SETTEST" + assert obj.render_rpsl_text() == rpsl_text.replace("\t", "+") # #298 class TestRPSLInetRtr: @@ -229,8 +253,8 @@ def test_parse(self): obj = rpsl_object_from_text(rpsl_text) assert obj.__class__ == RPSLInetRtr assert not obj.messages.errors() - assert obj.pk() == 'RTR.EXAMPLE.COM' - assert obj.parsed_data['inet-rtr'] == 'RTR.EXAMPLE.COM' + assert obj.pk() == "RTR.EXAMPLE.COM" + assert obj.parsed_data["inet-rtr"] == "RTR.EXAMPLE.COM" assert obj.render_rpsl_text() == rpsl_text @@ -244,9 +268,9 @@ def test_parse(self): obj = rpsl_object_from_text(rpsl_text) assert obj.__class__ == RPSLInet6Num assert not obj.messages.errors() - assert obj.pk() == '2001:DB8::/48' - assert obj.ip_first == IP('2001:db8::') - assert obj.ip_last == IP('2001:db8::ffff:ffff:ffff:ffff:ffff') + assert obj.pk() == "2001:DB8::/48" + assert obj.ip_first == IP("2001:db8::") + assert obj.ip_last == IP("2001:db8::ffff:ffff:ffff:ffff:ffff") assert obj.ip_version() == 6 assert obj.render_rpsl_text() == rpsl_text @@ -261,13 +285,13 
@@ def test_parse(self): obj = rpsl_object_from_text(rpsl_text) assert obj.__class__ == RPSLInetnum assert not obj.messages.errors() - assert obj.pk() == '192.0.2.0 - 192.0.2.255' - assert obj.ip_first == IP('192.0.2.0') - assert obj.ip_last == IP('192.0.2.255') + assert obj.pk() == "192.0.2.0 - 192.0.2.255" + assert obj.ip_first == IP("192.0.2.0") + assert obj.ip_last == IP("192.0.2.255") assert obj.ip_version() == 4 assert obj.references_strong_inbound() == set() # Field parsing will cause our object to look slightly different than the original, hence the replace() - assert obj.render_rpsl_text() == rpsl_text.replace('192.0.02.255', '192.0.2.255') + assert obj.render_rpsl_text() == rpsl_text.replace("192.0.02.255", "192.0.2.255") class TestRPSLKeyCert: @@ -276,54 +300,61 @@ class TestRPSLKeyCert: tests call the actual gpg binary, as the test has little value when gpg is mocked out. """ + def test_has_mapping(self): obj = RPSLKeyCert() assert OBJECT_CLASS_MAPPING[obj.rpsl_object_class] == obj.__class__ - @pytest.mark.usefixtures('tmp_gpg_dir') # noqa: F811 + @pytest.mark.usefixtures("tmp_gpg_dir") # noqa: F811 def test_parse_parse(self): rpsl_text = object_sample_mapping[RPSLKeyCert().rpsl_object_class] # Mangle the fingerprint/owner/method lines to ensure the parser correctly re-generates them - mangled_rpsl_text = rpsl_text.replace('8626 1D8D BEBD A4F5 4692 D64D A838 3BA7 80F2 38C6', 'fingerprint') - mangled_rpsl_text = mangled_rpsl_text.replace('sasha', 'foo').replace('method: PGP', 'method: test') - - expected_text = rpsl_text.replace(' \n', '+ \n') # #298 + mangled_rpsl_text = rpsl_text.replace( + "8626 1D8D BEBD A4F5 4692 D64D A838 3BA7 80F2 38C6", "fingerprint" + ) + mangled_rpsl_text = mangled_rpsl_text.replace("sasha", "foo").replace( + "method: PGP", "method: test" + ) + + expected_text = rpsl_text.replace(" \n", "+ \n") # #298 obj = rpsl_object_from_text(mangled_rpsl_text) assert obj.__class__ == RPSLKeyCert assert not obj.messages.errors() - assert obj.pk() == 'PGPKEY-80F238C6' + assert obj.pk() == "PGPKEY-80F238C6" assert obj.render_rpsl_text() == expected_text - assert obj.parsed_data['fingerpr'] == '8626 1D8D BEBD A4F5 4692 D64D A838 3BA7 80F2 38C6' + assert obj.parsed_data["fingerpr"] == "8626 1D8D BEBD A4F5 4692 D64D A838 3BA7 80F2 38C6" - @pytest.mark.usefixtures('tmp_gpg_dir') # noqa: F811 + @pytest.mark.usefixtures("tmp_gpg_dir") # noqa: F811 def test_parse_incorrect_object_name(self, tmp_gpg_dir): rpsl_text = object_sample_mapping[RPSLKeyCert().rpsl_object_class] - obj = rpsl_object_from_text(rpsl_text.replace('PGPKEY-80F238C6', 'PGPKEY-80F23816')) + obj = rpsl_object_from_text(rpsl_text.replace("PGPKEY-80F238C6", "PGPKEY-80F23816")) errors = obj.messages.errors() - assert len(errors) == 1, f'Unexpected multiple errors: {errors}' - assert 'does not match key fingerprint' in errors[0] + assert len(errors) == 1, f"Unexpected multiple errors: {errors}" + assert "does not match key fingerprint" in errors[0] - @pytest.mark.usefixtures('tmp_gpg_dir') # noqa: F811 + @pytest.mark.usefixtures("tmp_gpg_dir") # noqa: F811 def test_parse_missing_key(self, tmp_gpg_dir): rpsl_text = object_sample_mapping[RPSLKeyCert().rpsl_object_class] - obj = rpsl_object_from_text(rpsl_text.replace('certif:', 'remarks:'), strict_validation=True) + obj = rpsl_object_from_text(rpsl_text.replace("certif:", "remarks:"), strict_validation=True) errors = obj.messages.errors() - assert len(errors) == 1, f'Unexpected multiple errors: {errors}' + assert len(errors) == 1, f"Unexpected multiple errors: 
{errors}" assert 'Mandatory attribute "certif" on object key-cert is missing' in errors[0] - @pytest.mark.usefixtures('tmp_gpg_dir') # noqa: F811 + @pytest.mark.usefixtures("tmp_gpg_dir") # noqa: F811 def test_parse_invalid_key(self, tmp_gpg_dir): rpsl_text = object_sample_mapping[RPSLKeyCert().rpsl_object_class] - obj = rpsl_object_from_text(rpsl_text.replace('mQINBFnY7YoBEADH5ooPsoR9G', 'foo'), strict_validation=True) + obj = rpsl_object_from_text( + rpsl_text.replace("mQINBFnY7YoBEADH5ooPsoR9G", "foo"), strict_validation=True + ) errors = obj.messages.errors() - assert len(errors) == 1, f'Unexpected multiple errors: {errors}' - assert 'Unable to read public PGP key: key corrupt or multiple keys provided' in errors[0] + assert len(errors) == 1, f"Unexpected multiple errors: {errors}" + assert "Unable to read public PGP key: key corrupt or multiple keys provided" in errors[0] - @pytest.mark.usefixtures('tmp_gpg_dir') # noqa: F811 + @pytest.mark.usefixtures("tmp_gpg_dir") # noqa: F811 def test_verify(self, tmp_gpg_dir): rpsl_text = object_sample_mapping[RPSLKeyCert().rpsl_object_class] obj = rpsl_object_from_text(rpsl_text) @@ -340,39 +371,41 @@ def test_has_mapping(self): assert OBJECT_CLASS_MAPPING[obj.rpsl_object_class] == obj.__class__ def test_parse(self, config_override): - config_override({'auth': {'password_hashers': {'crypt-pw': 'enabled'}}}) + config_override({"auth": {"password_hashers": {"crypt-pw": "enabled"}}}) rpsl_text = object_sample_mapping[RPSLMntner().rpsl_object_class] obj = rpsl_object_from_text(rpsl_text) assert obj.__class__ == RPSLMntner assert not obj.messages.errors() - assert obj.pk() == 'TEST-MNT' - assert obj.parsed_data['mnt-by'] == ['TEST-MNT', 'OTHER1-MNT', 'OTHER2-MNT'] + assert obj.pk() == "TEST-MNT" + assert obj.parsed_data["mnt-by"] == ["TEST-MNT", "OTHER1-MNT", "OTHER2-MNT"] assert obj.render_rpsl_text() == rpsl_text - assert obj.references_strong_inbound() == {'mnt-by'} + assert obj.references_strong_inbound() == {"mnt-by"} def test_parse_invalid_partial_dummy_hash(self, config_override): - config_override({'auth': {'password_hashers': {'crypt-pw': 'enabled'}}}) + config_override({"auth": {"password_hashers": {"crypt-pw": "enabled"}}}) rpsl_text = object_sample_mapping[RPSLMntner().rpsl_object_class] - rpsl_text = rpsl_text.replace('LEuuhsBJNFV0Q', PASSWORD_HASH_DUMMY_VALUE) + rpsl_text = rpsl_text.replace("LEuuhsBJNFV0Q", PASSWORD_HASH_DUMMY_VALUE) obj = rpsl_object_from_text(rpsl_text) assert obj.__class__ == RPSLMntner assert obj.messages.errors() == [ - 'Either all password auth hashes in a submitted mntner must be dummy objects, or none.' + "Either all password auth hashes in a submitted mntner must be dummy objects, or none." ] - @pytest.mark.usefixtures('tmp_gpg_dir') # noqa: F811 + @pytest.mark.usefixtures("tmp_gpg_dir") # noqa: F811 def test_verify(self, tmp_gpg_dir): rpsl_text = object_sample_mapping[RPSLMntner().rpsl_object_class] # Unknown hashes and invalid hashes should simply be ignored. 
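        # verify_auth below must still succeed using the remaining valid hashes.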
# Strict validation set to False to allow legacy mode for CRYPT-PW - obj = rpsl_object_from_text(rpsl_text + 'auth: UNKNOWN_HASH foo\nauth: MD5-PW 💩', strict_validation=False) - - assert obj.verify_auth(['crypt-password']) - assert obj.verify_auth(['md5-password']) - assert obj.verify_auth(['bcrypt-password']) - assert obj.verify_auth(['md5-password'], 'PGPKey-80F238C6') - assert not obj.verify_auth(['other-password']) + obj = rpsl_object_from_text( + rpsl_text + "auth: UNKNOWN_HASH foo\nauth: MD5-PW 💩", strict_validation=False + ) + + assert obj.verify_auth(["crypt-password"]) + assert obj.verify_auth(["md5-password"]) + assert obj.verify_auth(["bcrypt-password"]) + assert obj.verify_auth(["md5-password"], "PGPKey-80F238C6") + assert not obj.verify_auth(["other-password"]) assert not obj.verify_auth([KEY_CERT_SIGNED_MESSAGE_CORRUPT]) assert not obj.verify_auth([KEY_CERT_SIGNED_MESSAGE_WRONG_KEY]) @@ -387,8 +420,8 @@ def test_parse(self): obj = rpsl_object_from_text(rpsl_text) assert obj.__class__ == RPSLPeeringSet assert not obj.messages.errors() - assert obj.pk() == 'PRNG-SETTEST' - assert obj.parsed_data['tech-c'] == ['PERSON-TEST', 'DUMY2-TEST'] + assert obj.pk() == "PRNG-SETTEST" + assert obj.parsed_data["tech-c"] == ["PERSON-TEST", "DUMY2-TEST"] assert obj.render_rpsl_text() == rpsl_text @@ -402,10 +435,10 @@ def test_parse(self): obj = rpsl_object_from_text(rpsl_text) assert obj.__class__ == RPSLPerson assert not obj.messages.errors() - assert obj.pk() == 'PERSON-TEST' - assert obj.parsed_data['nic-hdl'] == 'PERSON-TEST' + assert obj.pk() == "PERSON-TEST" + assert obj.parsed_data["nic-hdl"] == "PERSON-TEST" assert obj.render_rpsl_text() == rpsl_text - assert obj.references_strong_inbound() == {'admin-c', 'tech-c', 'zone-c'} + assert obj.references_strong_inbound() == {"admin-c", "tech-c", "zone-c"} def test_generate_template(self): template = RPSLPerson().generate_template() @@ -422,9 +455,9 @@ def test_parse(self): obj = rpsl_object_from_text(rpsl_text) assert obj.__class__ == RPSLRole assert not obj.messages.errors() - assert obj.pk() == 'ROLE-TEST' + assert obj.pk() == "ROLE-TEST" assert obj.render_rpsl_text() == rpsl_text - assert obj.references_strong_inbound() == {'admin-c', 'tech-c', 'zone-c'} + assert obj.references_strong_inbound() == {"admin-c", "tech-c", "zone-c"} class TestRPSLRoute: @@ -437,29 +470,29 @@ def test_parse(self): obj = rpsl_object_from_text(rpsl_text) assert obj.__class__ == RPSLRoute assert not obj.messages.errors() - assert obj.pk() == '192.0.2.0/24AS65537' - assert obj.ip_first == IP('192.0.2.0') - assert obj.ip_last == IP('192.0.2.255') - assert obj.prefix == IP('192.0.2.0/24') + assert obj.pk() == "192.0.2.0/24AS65537" + assert obj.ip_first == IP("192.0.2.0") + assert obj.ip_last == IP("192.0.2.255") + assert obj.prefix == IP("192.0.2.0/24") assert obj.prefix_length == 24 assert obj.asn_first == 65537 assert obj.asn_last == 65537 assert obj.ip_version() == 4 assert obj.references_strong_inbound() == set() - expected_text = rpsl_text.replace(' 192.0.02.0/24', ' 192.0.2.0/24') - expected_text = expected_text.replace('rpki-ov-state: valid # should be discarded\n', '') + expected_text = rpsl_text.replace(" 192.0.02.0/24", " 192.0.2.0/24") + expected_text = expected_text.replace("rpki-ov-state: valid # should be discarded\n", "") assert obj.render_rpsl_text() == expected_text def test_missing_pk_nonstrict(self): # In non-strict mode, the parser should not fail validation for missing # attributes, except for those part of the PK. 
Route is one of the few # objects that has two PK attributes. - missing_pk_route = 'route: 192.0.2.0/24' + missing_pk_route = "route: 192.0.2.0/24" obj = rpsl_object_from_text(missing_pk_route, strict_validation=False) assert obj.__class__ == RPSLRoute errors = obj.messages.errors() - assert len(errors) == 2, f'Unexpected extra errors: {errors}' + assert len(errors) == 2, f"Unexpected extra errors: {errors}" assert 'Primary key attribute "origin" on object route is missing' in errors[0] assert 'Primary key attribute "source" on object route is missing' in errors[1] @@ -478,9 +511,9 @@ def test_parse(self): obj = rpsl_object_from_text(rpsl_text) assert obj.__class__ == RPSLRouteSet assert not obj.messages.errors() - assert obj.pk() == 'RS-TEST' - assert obj.parsed_data['mp-members'] == ['2001:db8::/48'] - assert obj.render_rpsl_text() == rpsl_text.replace('2001:0dB8::/48', '2001:db8::/48') + assert obj.pk() == "RS-TEST" + assert obj.parsed_data["mp-members"] == ["2001:db8::/48"] + assert obj.render_rpsl_text() == rpsl_text.replace("2001:0dB8::/48", "2001:db8::/48") assert obj.references_strong_inbound() == set() @@ -494,15 +527,15 @@ def test_parse(self): obj = rpsl_object_from_text(rpsl_text) assert obj.__class__ == RPSLRoute6 assert not obj.messages.errors() - assert obj.pk() == '2001:DB8::/48AS65537' - assert obj.ip_first == IP('2001:db8::') - assert obj.ip_last == IP('2001:db8::ffff:ffff:ffff:ffff:ffff') - assert obj.prefix == IP('2001:db8::/48') + assert obj.pk() == "2001:DB8::/48AS65537" + assert obj.ip_first == IP("2001:db8::") + assert obj.ip_last == IP("2001:db8::ffff:ffff:ffff:ffff:ffff") + assert obj.prefix == IP("2001:db8::/48") assert obj.prefix_length == 48 assert obj.asn_first == 65537 assert obj.asn_last == 65537 assert obj.ip_version() == 6 - assert obj.parsed_data['mnt-by'] == ['TEST-MNT'] + assert obj.parsed_data["mnt-by"] == ["TEST-MNT"] assert obj.render_rpsl_text() == rpsl_text assert obj.references_strong_inbound() == set() @@ -517,12 +550,12 @@ def test_parse(self): obj = rpsl_object_from_text(rpsl_text) assert obj.__class__ == RPSLRtrSet assert not obj.messages.errors() - assert obj.pk() == 'RTRS-SETTEST' - assert obj.parsed_data['rtr-set'] == 'RTRS-SETTEST' + assert obj.pk() == "RTRS-SETTEST" + assert obj.parsed_data["rtr-set"] == "RTRS-SETTEST" assert obj.referred_strong_objects() == [ - ('admin-c', ['role', 'person'], ['PERSON-TEST']), - ('tech-c', ['role', 'person'], ['PERSON-TEST']), - ('mnt-by', ['mntner'], ['TEST-MNT']) + ("admin-c", ["role", "person"], ["PERSON-TEST"]), + ("tech-c", ["role", "person"], ["PERSON-TEST"]), + ("mnt-by", ["mntner"], ["TEST-MNT"]), ] assert obj.references_strong_inbound() == set() assert obj.render_rpsl_text() == rpsl_text @@ -530,20 +563,18 @@ def test_parse(self): class TestLastModified: def test_authoritative(self, config_override): - config_override({ - 'sources': {'TEST': {'authoritative': True}} - }) + config_override({"sources": {"TEST": {"authoritative": True}}}) rpsl_text = object_sample_mapping[RPSLRtrSet().rpsl_object_class] - obj = rpsl_object_from_text(rpsl_text + 'last-modified: old-value\n') + obj = rpsl_object_from_text(rpsl_text + "last-modified: old-value\n") assert not obj.messages.errors() - last_modified = datetime.datetime(2020, 1, 1, tzinfo=timezone('UTC')) - expected_text = rpsl_text + 'last-modified: 2020-01-01T00:00:00Z\n' + last_modified = datetime.datetime(2020, 1, 1, tzinfo=timezone("UTC")) + expected_text = rpsl_text + "last-modified: 2020-01-01T00:00:00Z\n" assert 
obj.render_rpsl_text(last_modified=last_modified) == expected_text def test_not_authoritative(self): rpsl_text = object_sample_mapping[RPSLRtrSet().rpsl_object_class] - obj = rpsl_object_from_text(rpsl_text + 'last-modified: old-value\n') + obj = rpsl_object_from_text(rpsl_text + "last-modified: old-value\n") assert not obj.messages.errors() - last_modified = datetime.datetime(2020, 1, 1, tzinfo=timezone('UTC')) - expected_text = rpsl_text + 'last-modified: old-value\n' + last_modified = datetime.datetime(2020, 1, 1, tzinfo=timezone("UTC")) + expected_text = rpsl_text + "last-modified: old-value\n" assert obj.render_rpsl_text(last_modified=last_modified) == expected_text diff --git a/irrd/scopefilter/status.py b/irrd/scopefilter/status.py index de5f6310d..ff1196269 100644 --- a/irrd/scopefilter/status.py +++ b/irrd/scopefilter/status.py @@ -3,9 +3,9 @@ @enum.unique class ScopeFilterStatus(enum.Enum): - in_scope = 'IN_SCOPE' - out_scope_as = 'OUT_SCOPE_AS' - out_scope_prefix = 'OUT_SCOPE_PREFIX' + in_scope = "IN_SCOPE" + out_scope_as = "OUT_SCOPE_AS" + out_scope_prefix = "OUT_SCOPE_PREFIX" @classmethod def is_visible(cls, status: "ScopeFilterStatus"): diff --git a/irrd/scopefilter/tests/test_scopefilter.py b/irrd/scopefilter/tests/test_scopefilter.py index 95181e763..e41820102 100644 --- a/irrd/scopefilter/tests/test_scopefilter.py +++ b/irrd/scopefilter/tests/test_scopefilter.py @@ -6,183 +6,195 @@ from irrd.rpsl.rpsl_objects import rpsl_object_from_text from irrd.storage.database_handler import DatabaseHandler from irrd.storage.queries import RPSLDatabaseQuery -from irrd.utils.rpsl_samples import SAMPLE_AUT_NUM, SAMPLE_ROUTE, SAMPLE_INETNUM +from irrd.utils.rpsl_samples import SAMPLE_AUT_NUM, SAMPLE_INETNUM, SAMPLE_ROUTE from irrd.utils.test_utils import flatten_mock_calls + from ..status import ScopeFilterStatus from ..validators import ScopeFilterValidator class TestScopeFilterValidator: def test_validate(self, config_override): - config_override({ - 'scopefilter': { - 'asns': [ - '23456', - '10-20', - ], - 'prefixes': [ - '10/8', - '192.168.0.0/24' - ], - }, - 'sources': {'TEST-EXCLUDED': {'scopefilter_excluded': True}} - }) + config_override( + { + "scopefilter": { + "asns": [ + "23456", + "10-20", + ], + "prefixes": ["10/8", "192.168.0.0/24"], + }, + "sources": {"TEST-EXCLUDED": {"scopefilter_excluded": True}}, + } + ) validator = ScopeFilterValidator() - assert validator.validate('TEST', IP('192.0.2/24')) == ScopeFilterStatus.in_scope - assert validator.validate('TEST', IP('192.168/24')) == ScopeFilterStatus.out_scope_prefix - assert validator.validate('TEST', IP('10.2.1/24')) == ScopeFilterStatus.out_scope_prefix - assert validator.validate('TEST', IP('192/8')) == ScopeFilterStatus.out_scope_prefix + assert validator.validate("TEST", IP("192.0.2/24")) == ScopeFilterStatus.in_scope + assert validator.validate("TEST", IP("192.168/24")) == ScopeFilterStatus.out_scope_prefix + assert validator.validate("TEST", IP("10.2.1/24")) == ScopeFilterStatus.out_scope_prefix + assert validator.validate("TEST", IP("192/8")) == ScopeFilterStatus.out_scope_prefix - assert validator.validate('TEST', asn=9) == ScopeFilterStatus.in_scope - assert validator.validate('TEST', asn=21) == ScopeFilterStatus.in_scope - assert validator.validate('TEST', asn=20) == ScopeFilterStatus.out_scope_as - assert validator.validate('TEST', asn=10) == ScopeFilterStatus.out_scope_as - assert validator.validate('TEST', asn=15) == ScopeFilterStatus.out_scope_as - assert validator.validate('TEST', asn=23456) == 
ScopeFilterStatus.out_scope_as + assert validator.validate("TEST", asn=9) == ScopeFilterStatus.in_scope + assert validator.validate("TEST", asn=21) == ScopeFilterStatus.in_scope + assert validator.validate("TEST", asn=20) == ScopeFilterStatus.out_scope_as + assert validator.validate("TEST", asn=10) == ScopeFilterStatus.out_scope_as + assert validator.validate("TEST", asn=15) == ScopeFilterStatus.out_scope_as + assert validator.validate("TEST", asn=23456) == ScopeFilterStatus.out_scope_as - assert validator.validate('TEST-EXCLUDED', IP('192/8')) == ScopeFilterStatus.in_scope - assert validator.validate('TEST-EXCLUDED', asn=20) == ScopeFilterStatus.in_scope + assert validator.validate("TEST-EXCLUDED", IP("192/8")) == ScopeFilterStatus.in_scope + assert validator.validate("TEST-EXCLUDED", asn=20) == ScopeFilterStatus.in_scope # Override to no filter config_override({}) validator.load_filters() - assert validator.validate('TEST', IP('192.168/24')) == ScopeFilterStatus.in_scope - assert validator.validate('TEST', asn=20) == ScopeFilterStatus.in_scope + assert validator.validate("TEST", IP("192.168/24")) == ScopeFilterStatus.in_scope + assert validator.validate("TEST", asn=20) == ScopeFilterStatus.in_scope def test_invalid_input(self): validator = ScopeFilterValidator() with pytest.raises(ValueError) as ve: - validator.validate('TEST') - assert 'must be provided asn or prefix' in str(ve.value) + validator.validate("TEST") + assert "must be provided asn or prefix" in str(ve.value) def test_validate_rpsl_object(self, config_override): validator = ScopeFilterValidator() route_obj = rpsl_object_from_text(SAMPLE_ROUTE) - assert validator.validate_rpsl_object(route_obj) == (ScopeFilterStatus.in_scope, '') + assert validator.validate_rpsl_object(route_obj) == (ScopeFilterStatus.in_scope, "") autnum_obj = rpsl_object_from_text(SAMPLE_AUT_NUM) - assert validator.validate_rpsl_object(autnum_obj) == (ScopeFilterStatus.in_scope, '') + assert validator.validate_rpsl_object(autnum_obj) == (ScopeFilterStatus.in_scope, "") - config_override({ - 'scopefilter': { - 'asns': ['65537'], - }, - }) + config_override( + { + "scopefilter": { + "asns": ["65537"], + }, + } + ) validator.load_filters() result = validator.validate_rpsl_object(route_obj) - assert result == (ScopeFilterStatus.out_scope_as, 'ASN 65537 is out of scope') + assert result == (ScopeFilterStatus.out_scope_as, "ASN 65537 is out of scope") result = validator.validate_rpsl_object(autnum_obj) - assert result == (ScopeFilterStatus.out_scope_as, 'ASN 65537 is out of scope') + assert result == (ScopeFilterStatus.out_scope_as, "ASN 65537 is out of scope") - config_override({ - 'scopefilter': { - 'prefixes': ['192.0.2.0/32'], - }, - }) + config_override( + { + "scopefilter": { + "prefixes": ["192.0.2.0/32"], + }, + } + ) validator.load_filters() result = validator.validate_rpsl_object(route_obj) - assert result == (ScopeFilterStatus.out_scope_prefix, 'prefix 192.0.2.0/24 is out of scope') + assert result == (ScopeFilterStatus.out_scope_prefix, "prefix 192.0.2.0/24 is out of scope") - config_override({ - 'scopefilter': { - 'prefix': ['0/0'], - }, - }) + config_override( + { + "scopefilter": { + "prefix": ["0/0"], + }, + } + ) validator.load_filters() # Ignored object class result = validator.validate_rpsl_object(rpsl_object_from_text(SAMPLE_INETNUM)) - assert result == (ScopeFilterStatus.in_scope, '') + assert result == (ScopeFilterStatus.in_scope, "") def test_validate_all_rpsl_objects(self, config_override, monkeypatch): mock_dh = Mock(spec=DatabaseHandler) 
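        # mock_dh stands in for the real DatabaseHandler; its execute_query
        # is pointed at canned result sets further down, so the validator is
        # exercised against fixed data instead of a live database.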
mock_dq = Mock(spec=RPSLDatabaseQuery) - monkeypatch.setattr('irrd.scopefilter.validators.RPSLDatabaseQuery', - lambda column_names=None, enable_ordering=True: mock_dq) + monkeypatch.setattr( + "irrd.scopefilter.validators.RPSLDatabaseQuery", + lambda column_names=None, enable_ordering=True: mock_dq, + ) + + config_override( + { + "scopefilter": { + "asns": [ + "23456", + ], + "prefixes": [ + "192.0.2.0/25", + ], + }, + } + ) - config_override({ - 'scopefilter': { - 'asns': [ - '23456', + mock_query_result = iter( + [ + [ + { + # Should become in_scope + "pk": "192.0.2.128/25,AS65547", + "rpsl_pk": "192.0.2.128/25,AS65547", + "prefix": "192.0.2.128/25", + "asn_first": 65547, + "source": "TEST", + "object_class": "route", + "scopefilter_status": ScopeFilterStatus.out_scope_prefix, + }, + { + # Should become out_scope_prefix + "pk": "192.0.2.0/25,AS65547", + "rpsl_pk": "192.0.2.0/25,AS65547", + "prefix": "192.0.2.0/25", + "asn_first": 65547, + "source": "TEST", + "object_class": "route", + "scopefilter_status": ScopeFilterStatus.in_scope, + }, + { + # Should become out_scope_as + "pk": "192.0.2.128/25,AS65547", + "rpsl_pk": "192.0.2.128/25,AS65547", + "prefix": "192.0.2.128/25", + "asn_first": 23456, + "source": "TEST", + "object_class": "route", + "scopefilter_status": ScopeFilterStatus.out_scope_prefix, + }, + { + # Should become out_scope_as + "pk": "AS65547", + "rpsl_pk": "AS65547", + "asn_first": 23456, + "source": "TEST", + "object_class": "aut-num", + "object_text": "text", + "scopefilter_status": ScopeFilterStatus.in_scope, + }, + { + # Should not change + "pk": "192.0.2.128/25,AS65548", + "rpsl_pk": "192.0.2.128/25,AS65548", + "prefix": "192.0.2.128/25", + "asn_first": 65548, + "source": "TEST", + "object_class": "route", + "scopefilter_status": ScopeFilterStatus.in_scope, + }, ], - 'prefixes': [ - '192.0.2.0/25', + [ + { + "pk": "192.0.2.128/25,AS65547", + "object_text": "text-192.0.2.128/25,AS65547", + }, + { + "pk": "192.0.2.0/25,AS65547", + "object_text": "text-192.0.2.0/25,AS65547", + }, + { + "pk": "192.0.2.128/25,AS65547", + "object_text": "text-192.0.2.128/25,AS65547", + }, + { + "pk": "AS65547", + "object_text": "text-AS65547", + }, ], - }, - }) - - mock_query_result = iter([ - [ - { - # Should become in_scope - 'pk': '192.0.2.128/25,AS65547', - 'rpsl_pk': '192.0.2.128/25,AS65547', - 'prefix': '192.0.2.128/25', - 'asn_first': 65547, - 'source': 'TEST', - 'object_class': 'route', - 'scopefilter_status': ScopeFilterStatus.out_scope_prefix, - }, - { - # Should become out_scope_prefix - 'pk': '192.0.2.0/25,AS65547', - 'rpsl_pk': '192.0.2.0/25,AS65547', - 'prefix': '192.0.2.0/25', - 'asn_first': 65547, - 'source': 'TEST', - 'object_class': 'route', - 'scopefilter_status': ScopeFilterStatus.in_scope, - }, - { - # Should become out_scope_as - 'pk': '192.0.2.128/25,AS65547', - 'rpsl_pk': '192.0.2.128/25,AS65547', - 'prefix': '192.0.2.128/25', - 'asn_first': 23456, - 'source': 'TEST', - 'object_class': 'route', - 'scopefilter_status': ScopeFilterStatus.out_scope_prefix, - }, - { - # Should become out_scope_as - 'pk': 'AS65547', - 'rpsl_pk': 'AS65547', - 'asn_first': 23456, - 'source': 'TEST', - 'object_class': 'aut-num', - 'object_text': 'text', - 'scopefilter_status': ScopeFilterStatus.in_scope, - }, - { - # Should not change - 'pk': '192.0.2.128/25,AS65548', - 'rpsl_pk': '192.0.2.128/25,AS65548', - 'prefix': '192.0.2.128/25', - 'asn_first': 65548, - 'source': 'TEST', - 'object_class': 'route', - 'scopefilter_status': ScopeFilterStatus.in_scope, - }, - ], - [ - { - 'pk': 
'192.0.2.128/25,AS65547', - 'object_text': 'text-192.0.2.128/25,AS65547', - }, - { - 'pk': '192.0.2.0/25,AS65547', - 'object_text': 'text-192.0.2.0/25,AS65547', - }, - { - 'pk': '192.0.2.128/25,AS65547', - 'object_text': 'text-192.0.2.128/25,AS65547', - }, - { - 'pk': 'AS65547', - 'object_text': 'text-AS65547', - }, ] - ]) + ) mock_dh.execute_query = lambda query: next(mock_query_result) validator = ScopeFilterValidator() @@ -193,22 +205,26 @@ def test_validate_all_rpsl_objects(self, config_override, monkeypatch): assert len(now_out_scope_as) == 2 assert len(now_out_scope_prefix) == 1 - assert now_in_scope[0]['rpsl_pk'] == '192.0.2.128/25,AS65547' - assert now_in_scope[0]['old_status'] == ScopeFilterStatus.out_scope_prefix - assert now_in_scope[0]['object_text'] == 'text-192.0.2.128/25,AS65547' + assert now_in_scope[0]["rpsl_pk"] == "192.0.2.128/25,AS65547" + assert now_in_scope[0]["old_status"] == ScopeFilterStatus.out_scope_prefix + assert now_in_scope[0]["object_text"] == "text-192.0.2.128/25,AS65547" - assert now_out_scope_as[0]['rpsl_pk'] == '192.0.2.128/25,AS65547' - assert now_out_scope_as[0]['old_status'] == ScopeFilterStatus.out_scope_prefix - assert now_out_scope_as[0]['object_text'] == 'text-192.0.2.128/25,AS65547' - assert now_out_scope_as[1]['rpsl_pk'] == 'AS65547' - assert now_out_scope_as[1]['old_status'] == ScopeFilterStatus.in_scope - assert now_out_scope_as[1]['object_text'] == 'text-AS65547' + assert now_out_scope_as[0]["rpsl_pk"] == "192.0.2.128/25,AS65547" + assert now_out_scope_as[0]["old_status"] == ScopeFilterStatus.out_scope_prefix + assert now_out_scope_as[0]["object_text"] == "text-192.0.2.128/25,AS65547" + assert now_out_scope_as[1]["rpsl_pk"] == "AS65547" + assert now_out_scope_as[1]["old_status"] == ScopeFilterStatus.in_scope + assert now_out_scope_as[1]["object_text"] == "text-AS65547" - assert now_out_scope_prefix[0]['rpsl_pk'] == '192.0.2.0/25,AS65547' - assert now_out_scope_prefix[0]['old_status'] == ScopeFilterStatus.in_scope - assert now_out_scope_prefix[0]['object_text'] == 'text-192.0.2.0/25,AS65547' + assert now_out_scope_prefix[0]["rpsl_pk"] == "192.0.2.0/25,AS65547" + assert now_out_scope_prefix[0]["old_status"] == ScopeFilterStatus.in_scope + assert now_out_scope_prefix[0]["object_text"] == "text-192.0.2.0/25,AS65547" assert flatten_mock_calls(mock_dq) == [ - ['object_classes', (['route', 'route6', 'aut-num'],), {}], - ['pks', (['192.0.2.128/25,AS65547', '192.0.2.0/25,AS65547', '192.0.2.128/25,AS65547', 'AS65547'],), {}], + ["object_classes", (["route", "route6", "aut-num"],), {}], + [ + "pks", + (["192.0.2.128/25,AS65547", "192.0.2.0/25,AS65547", "192.0.2.128/25,AS65547", "AS65547"],), + {}, + ], ] diff --git a/irrd/scopefilter/validators.py b/irrd/scopefilter/validators.py index cc11a1aeb..dc81eef58 100644 --- a/irrd/scopefilter/validators.py +++ b/irrd/scopefilter/validators.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Optional, Tuple, List, Dict +from typing import Dict, List, Optional, Tuple from IPy import IP @@ -7,6 +7,7 @@ from irrd.rpsl.parser import RPSLObject from irrd.storage.database_handler import DatabaseHandler from irrd.storage.queries import RPSLDatabaseQuery + from .status import ScopeFilterStatus @@ -24,28 +25,30 @@ def load_filters(self): (Re)load the local cache of the configured filters. 
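        ASN filters may be plain ASNs or ranges such as "10-20"; a range
        covers AS10 through AS20 inclusive and is stored as a (start, end)
        tuple.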
Also called by __init__
         """
-        prefixes = get_setting('scopefilter.prefixes', [])
+        prefixes = get_setting("scopefilter.prefixes", [])
         self.filtered_prefixes = [IP(prefix) for prefix in prefixes]

         self.filtered_asns = set()
         self.filtered_asn_ranges = set()
-        asn_filters = get_setting('scopefilter.asns', [])
+        asn_filters = get_setting("scopefilter.asns", [])
         for asn_filter in asn_filters:
-            if '-' in str(asn_filter):
-                start, end = asn_filter.split('-')
+            if "-" in str(asn_filter):
+                start, end = asn_filter.split("-")
                 self.filtered_asn_ranges.add((int(start), int(end)))
             else:
                 self.filtered_asns.add(int(asn_filter))

-    def validate(self, source: str, prefix: Optional[IP]=None, asn: Optional[int]=None) -> ScopeFilterStatus:
+    def validate(
+        self, source: str, prefix: Optional[IP] = None, asn: Optional[int] = None
+    ) -> ScopeFilterStatus:
         """
         Validate a prefix and/or ASN, for a particular source.
         Returns a ScopeFilterStatus.
         """
         if not prefix and asn is None:
-            raise ValueError('Scope Filter validator must be provided asn or prefix')
+            raise ValueError("Scope Filter validator must be provided asn or prefix")

-        if get_setting(f'sources.{source}.scopefilter_excluded'):
+        if get_setting(f"sources.{source}.scopefilter_excluded"):
             return ScopeFilterStatus.in_scope

         if prefix:
@@ -62,27 +65,28 @@ def validate(self, source: str, prefix: Optional[IP]=None, asn: Optional[int]=No

         return ScopeFilterStatus.in_scope

-    def _validate_rpsl_data(self, source: str, object_class: str, prefix: Optional[IP],
-                            asn_first: Optional[int]) -> Tuple[ScopeFilterStatus, str]:
+    def _validate_rpsl_data(
+        self, source: str, object_class: str, prefix: Optional[IP], asn_first: Optional[int]
+    ) -> Tuple[ScopeFilterStatus, str]:
         """
         Validate whether a particular set of RPSL data is in scope.
         Returns a tuple of a ScopeFilterStatus and an explanation string.
         """
         out_of_scope = [ScopeFilterStatus.out_scope_prefix, ScopeFilterStatus.out_scope_as]
-        if object_class not in ['route', 'route6', 'aut-num']:
-            return ScopeFilterStatus.in_scope, ''
+        if object_class not in ["route", "route6", "aut-num"]:
+            return ScopeFilterStatus.in_scope, ""

         if prefix:
             prefix_state = self.validate(source, prefix)
             if prefix_state in out_of_scope:
-                return prefix_state, f'prefix {prefix} is out of scope'
+                return prefix_state, f"prefix {prefix} is out of scope"

         if asn_first is not None:
             asn_state = self.validate(source, asn=asn_first)
             if asn_state in out_of_scope:
-                return asn_state, f'ASN {asn_first} is out of scope'
+                return asn_state, f"ASN {asn_first} is out of scope"

-        return ScopeFilterStatus.in_scope, ''
+        return ScopeFilterStatus.in_scope, ""

     def validate_rpsl_object(self, rpsl_object: RPSLObject) -> Tuple[ScopeFilterStatus, str]:
         """
@@ -96,8 +100,9 @@ def validate_rpsl_object(self, rpsl_object: RPSLObject) -> Tuple[ScopeFilterStat
             rpsl_object.asn_first,
         )

-    def validate_all_rpsl_objects(self, database_handler: DatabaseHandler) -> \
-        Tuple[List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]]:
+    def validate_all_rpsl_objects(
+        self, database_handler: DatabaseHandler
+    ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]]:
         """
         Apply the scope filter to all relevant objects.
@@ -112,40 +117,43 @@ def validate_all_rpsl_objects(self, database_handler: DatabaseHandler) -> \
         Objects where their current status in the DB matches the new
         validation result, are not included in the return value.
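        Object texts and related status columns are fetched afterwards,
        only for the objects whose status changed.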
""" - columns = ['pk', 'rpsl_pk', 'prefix', 'asn_first', 'source', 'object_class', - 'scopefilter_status'] + columns = ["pk", "rpsl_pk", "prefix", "asn_first", "source", "object_class", "scopefilter_status"] objs_changed: Dict[ScopeFilterStatus, List[Dict[str, str]]] = defaultdict(list) q = RPSLDatabaseQuery(column_names=columns, enable_ordering=False) - q = q.object_classes(['route', 'route6', 'aut-num']) + q = q.object_classes(["route", "route6", "aut-num"]) results = database_handler.execute_query(q) for result in results: - current_status = result['scopefilter_status'] - result['old_status'] = current_status + current_status = result["scopefilter_status"] + result["old_status"] = current_status prefix = None - if result.get('prefix'): - prefix = IP(result['prefix']) + if result.get("prefix"): + prefix = IP(result["prefix"]) new_status, _ = self._validate_rpsl_data( - result['source'], - result['object_class'], + result["source"], + result["object_class"], prefix, - result['asn_first'], + result["asn_first"], ) if new_status != current_status: - result['scopefilter_status'] = new_status + result["scopefilter_status"] = new_status objs_changed[new_status].append(result) # Object text is only retrieved for objects with state changes - pks_to_enrich = [obj['pk'] for objs in objs_changed.values() for obj in objs] - query = RPSLDatabaseQuery(['pk', 'object_text', 'rpki_status', 'route_preference_status'], enable_ordering=False).pks(pks_to_enrich) - rows_per_pk = {row['pk']: row for row in database_handler.execute_query(query)} + pks_to_enrich = [obj["pk"] for objs in objs_changed.values() for obj in objs] + query = RPSLDatabaseQuery( + ["pk", "object_text", "rpki_status", "route_preference_status"], enable_ordering=False + ).pks(pks_to_enrich) + rows_per_pk = {row["pk"]: row for row in database_handler.execute_query(query)} for rpsl_objs in objs_changed.values(): for rpsl_obj in rpsl_objs: - rpsl_obj.update(rows_per_pk[rpsl_obj['pk']]) + rpsl_obj.update(rows_per_pk[rpsl_obj["pk"]]) - return (objs_changed[ScopeFilterStatus.in_scope], - objs_changed[ScopeFilterStatus.out_scope_as], - objs_changed[ScopeFilterStatus.out_scope_prefix]) + return ( + objs_changed[ScopeFilterStatus.in_scope], + objs_changed[ScopeFilterStatus.out_scope_as], + objs_changed[ScopeFilterStatus.out_scope_prefix], + ) diff --git a/irrd/scripts/database_downgrade.py b/irrd/scripts/database_downgrade.py index 3bf11839a..ff9c22136 100755 --- a/irrd/scripts/database_downgrade.py +++ b/irrd/scripts/database_downgrade.py @@ -1,37 +1,46 @@ #!/usr/bin/env python # flake8: noqa: E402 +import argparse import sys +from pathlib import Path -import argparse from alembic import command from alembic.config import Config -from pathlib import Path irrd_root = str(Path(__file__).resolve().parents[2]) sys.path.append(irrd_root) -from irrd.conf import config_init, CONFIG_PATH_DEFAULT +from irrd.conf import CONFIG_PATH_DEFAULT, config_init def run(version): alembic_cfg = Config() - alembic_cfg.set_main_option('script_location', f'{irrd_root}/irrd/storage/alembic') + alembic_cfg.set_main_option("script_location", f"{irrd_root}/irrd/storage/alembic") command.downgrade(alembic_cfg, version) - print(f'Downgrade successful, or already on this version.') + print(f"Downgrade successful, or already on this version.") def main(): # pragma: no cover description = """Downgrade the IRRd SQL database to a particular version by running database migrations. 
See release notes.""" parser = argparse.ArgumentParser(description=description) - parser.add_argument('--config', dest='config_file_path', type=str, - help=f'use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})') - parser.add_argument('--version', dest='version', type=str, required=True, - help=f'version to downgrade to (see release notes)') + parser.add_argument( + "--config", + dest="config_file_path", + type=str, + help=f"use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})", + ) + parser.add_argument( + "--version", + dest="version", + type=str, + required=True, + help=f"version to downgrade to (see release notes)", + ) args = parser.parse_args() config_init(args.config_file_path) run(args.version) -if __name__ == '__main__': # pragma: no cover +if __name__ == "__main__": # pragma: no cover main() diff --git a/irrd/scripts/database_upgrade.py b/irrd/scripts/database_upgrade.py index 50d140429..b22a0f909 100755 --- a/irrd/scripts/database_upgrade.py +++ b/irrd/scripts/database_upgrade.py @@ -1,37 +1,46 @@ #!/usr/bin/env python # flake8: noqa: E402 +import argparse import sys +from pathlib import Path -import argparse from alembic import command from alembic.config import Config -from pathlib import Path irrd_root = str(Path(__file__).resolve().parents[2]) sys.path.append(irrd_root) -from irrd.conf import config_init, CONFIG_PATH_DEFAULT +from irrd.conf import CONFIG_PATH_DEFAULT, config_init def run(version): alembic_cfg = Config() - alembic_cfg.set_main_option('script_location', f'{irrd_root}/irrd/storage/alembic') + alembic_cfg.set_main_option("script_location", f"{irrd_root}/irrd/storage/alembic") command.upgrade(alembic_cfg, version) - print(f'Upgrade successful, or already on latest version.') + print(f"Upgrade successful, or already on latest version.") def main(): # pragma: no cover description = """Upgrade the IRRd SQL database to a particular version by running database migrations.""" parser = argparse.ArgumentParser(description=description) - parser.add_argument('--config', dest='config_file_path', type=str, - help=f'use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})') - parser.add_argument('--version', dest='version', type=str, default='head', - help=f'version to upgrade to (default: head, i.e. latest)') + parser.add_argument( + "--config", + dest="config_file_path", + type=str, + help=f"use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})", + ) + parser.add_argument( + "--version", + dest="version", + type=str, + default="head", + help=f"version to upgrade to (default: head, i.e. 
latest)", + ) args = parser.parse_args() config_init(args.config_file_path) run(args.version) -if __name__ == '__main__': # pragma: no cover +if __name__ == "__main__": # pragma: no cover main() diff --git a/irrd/scripts/expire_journal.py b/irrd/scripts/expire_journal.py index da21ecc0e..591c92a84 100755 --- a/irrd/scripts/expire_journal.py +++ b/irrd/scripts/expire_journal.py @@ -16,15 +16,19 @@ logger = logging.getLogger(__name__) sys.path.append(str(Path(__file__).resolve().parents[2])) +from irrd.conf import CONFIG_PATH_DEFAULT, config_init, get_setting from irrd.storage.database_handler import DatabaseHandler from irrd.storage.queries import RPSLDatabaseJournalQuery -from irrd.conf import config_init, CONFIG_PATH_DEFAULT, get_setting def expire_journal(skip_confirmation: bool, expire_before: datetime, source: str): dh = DatabaseHandler() - q = RPSLDatabaseJournalQuery(column_names=["timestamp"]).sources([source]).entries_before_date(expire_before) + q = ( + RPSLDatabaseJournalQuery(column_names=["timestamp"]) + .sources([source]) + .entries_before_date(expire_before) + ) affected_object_count = len(list(dh.execute_query(q))) if not affected_object_count: @@ -73,7 +77,10 @@ def main(): # pragma: no cover help=f"use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})", ) parser.add_argument( - "--expire-before", type=str, required=True, help="expire all entries from before this date (YYYY-MM-DD)" + "--expire-before", + type=str, + required=True, + help="expire all entries from before this date (YYYY-MM-DD)", ) parser.add_argument("--source", type=str, required=True, help="the name of the source to reload") diff --git a/irrd/scripts/irr_rpsl_submit.py b/irrd/scripts/irr_rpsl_submit.py index 9a259659c..c38785658 100755 --- a/irrd/scripts/irr_rpsl_submit.py +++ b/irrd/scripts/irr_rpsl_submit.py @@ -845,7 +845,10 @@ def metadata(metadata_values): "-u", dest="url", type=str, - help="IRRd submission API URL, e.g. https://rr.example.net/v1/submit/ (also set by IRR_RPSL_SUBMIT_URL)", # pylint: disable=C0301 + help=( # pylint: disable=C0301 + "IRRd submission API URL, e.g. https://rr.example.net/v1/submit/ (also set by" + " IRR_RPSL_SUBMIT_URL)" + ), ) add_irrdv3_options(parser) diff --git a/irrd/scripts/load_database.py b/irrd/scripts/load_database.py index d7aa84386..f4e07dcb8 100755 --- a/irrd/scripts/load_database.py +++ b/irrd/scripts/load_database.py @@ -3,10 +3,8 @@ import argparse import logging import sys - from pathlib import Path - """ Load an RPSL file into the database. 
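Existing objects for the source are deleted first, and journaling is
disabled during the import.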
""" @@ -14,19 +12,23 @@ logger = logging.getLogger(__name__) sys.path.append(str(Path(__file__).resolve().parents[2])) +from irrd.conf import CONFIG_PATH_DEFAULT, config_init, get_setting +from irrd.mirroring.parsers import MirrorFileImportParser from irrd.rpki.validators import BulkRouteROAValidator from irrd.storage.database_handler import DatabaseHandler -from irrd.mirroring.parsers import MirrorFileImportParser -from irrd.conf import config_init, CONFIG_PATH_DEFAULT, get_setting def load(source, filename, serial) -> int: - if any([ - get_setting(f'sources.{source}.import_source'), - get_setting(f'sources.{source}.import_serial_source') - ]): - print(f'Error: to use this command, import_source and import_serial_source ' - f'for source {source} must not be set.') + if any( + [ + get_setting(f"sources.{source}.import_source"), + get_setting(f"sources.{source}.import_serial_source"), + ] + ): + print( + "Error: to use this command, import_source and import_serial_source " + f"for source {source} must not be set." + ) return 2 dh = DatabaseHandler() @@ -34,8 +36,13 @@ def load(source, filename, serial) -> int: dh.delete_all_rpsl_objects_with_journal(source) dh.disable_journaling() parser = MirrorFileImportParser( - source=source, filename=filename, serial=serial, database_handler=dh, - direct_error_return=True, roa_validator=roa_validator) + source=source, + filename=filename, + serial=serial, + database_handler=dh, + direct_error_return=True, + roa_validator=roa_validator, + ) error = parser.run_import() if error: dh.rollback() @@ -43,7 +50,7 @@ def load(source, filename, serial) -> int: dh.commit() dh.close() if error: - print(f'Error occurred while processing object:\n{error}') + print(f"Error occurred while processing object:\n{error}") return 1 return 0 @@ -51,23 +58,26 @@ def load(source, filename, serial) -> int: def main(): # pragma: no cover description = """Load an RPSL file into the database.""" parser = argparse.ArgumentParser(description=description) - parser.add_argument('--config', dest='config_file_path', type=str, - help=f'use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})') - parser.add_argument('--serial', dest='serial', type=int, - help=f'serial number (optional)') - parser.add_argument('--source', dest='source', type=str, required=True, - help=f'name of the source, e.g. NTTCOM') - parser.add_argument('input_file', type=str, - help='the name of a file to read') + parser.add_argument( + "--config", + dest="config_file_path", + type=str, + help=f"use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})", + ) + parser.add_argument("--serial", dest="serial", type=int, help=f"serial number (optional)") + parser.add_argument( + "--source", dest="source", type=str, required=True, help=f"name of the source, e.g. 
NTTCOM" + ) + parser.add_argument("input_file", type=str, help="the name of a file to read") args = parser.parse_args() config_init(args.config_file_path) - if get_setting('database_readonly'): - print('Unable to run, because database_readonly is set') + if get_setting("database_readonly"): + print("Unable to run, because database_readonly is set") sys.exit(-1) sys.exit(load(args.source, args.input_file, args.serial)) -if __name__ == '__main__': # pragma: no cover +if __name__ == "__main__": # pragma: no cover main() diff --git a/irrd/scripts/load_pgp_keys.py b/irrd/scripts/load_pgp_keys.py index e34a77b09..006c50826 100755 --- a/irrd/scripts/load_pgp_keys.py +++ b/irrd/scripts/load_pgp_keys.py @@ -3,7 +3,6 @@ import argparse import logging import sys - from pathlib import Path """ @@ -13,36 +12,40 @@ logger = logging.getLogger(__name__) sys.path.append(str(Path(__file__).resolve().parents[2])) -from irrd.conf import config_init, CONFIG_PATH_DEFAULT -from irrd.storage.queries import RPSLDatabaseQuery -from irrd.storage.database_handler import DatabaseHandler +from irrd.conf import CONFIG_PATH_DEFAULT, config_init from irrd.rpsl.rpsl_objects import rpsl_object_from_text +from irrd.storage.database_handler import DatabaseHandler +from irrd.storage.queries import RPSLDatabaseQuery + def load_pgp_keys(source: str) -> None: dh = DatabaseHandler() - query = RPSLDatabaseQuery(column_names=['rpsl_pk', 'object_text']) - query = query.sources([source]).object_classes(['key-cert']) + query = RPSLDatabaseQuery(column_names=["rpsl_pk", "object_text"]) + query = query.sources([source]).object_classes(["key-cert"]) keycerts = dh.execute_query(query) for keycert in keycerts: rpsl_pk = keycert["rpsl_pk"] - print(f'Loading key-cert {rpsl_pk}') + print(f"Loading key-cert {rpsl_pk}") # Parsing the keycert in strict mode will load it into the GPG keychain - result = rpsl_object_from_text(keycert['object_text'], strict_validation=True) + result = rpsl_object_from_text(keycert["object_text"], strict_validation=True) if result.messages.errors(): - print(f'Errors in PGP key {rpsl_pk}: {result.messages.errors()}') + print(f"Errors in PGP key {rpsl_pk}: {result.messages.errors()}") - print('All valid key-certs loaded into the GnuPG keychain.') + print("All valid key-certs loaded into the GnuPG keychain.") dh.close() def main(): # pragma: no cover description = """Load all PGP keys from key-cert objects for a specific source into the GnuPG keychain.""" parser = argparse.ArgumentParser(description=description) - parser.add_argument('--config', dest='config_file_path', type=str, - help=f'use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})') - parser.add_argument('source', type=str, - help='the name of the source for which to load PGP keys') + parser.add_argument( + "--config", + dest="config_file_path", + type=str, + help=f"use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})", + ) + parser.add_argument("source", type=str, help="the name of the source for which to load PGP keys") args = parser.parse_args() config_init(args.config_file_path) @@ -50,5 +53,5 @@ def main(): # pragma: no cover load_pgp_keys(args.source) -if __name__ == '__main__': # pragma: no cover +if __name__ == "__main__": # pragma: no cover main() diff --git a/irrd/scripts/load_test.py b/irrd/scripts/load_test.py index c60e271a7..52ff15526 100755 --- a/irrd/scripts/load_test.py +++ b/irrd/scripts/load_test.py @@ -5,49 +5,47 @@ A simple load tester for IRRd. Sends random !g queries. 
""" -import time - import argparse import random import socket +import time def main(host, port, count): - queries = [b'!!\n'] + queries = [b"!!\n"] for i in range(count): asn = random.randrange(1, 50000) - query = f'!gAS{asn}\n'.encode('ascii') + query = f"!gAS{asn}\n".encode("ascii") queries.append(query) - queries.append(b'!q\n') + queries.append(b"!q\n") s = socket.socket() s.settimeout(600) s.connect((host, port)) - queries_str = b''.join(queries) + queries_str = b"".join(queries) s.sendall(queries_str) start_time = time.perf_counter() while 1: - data = s.recv(1024*1024) + data = s.recv(1024 * 1024) if not data: break elapsed = time.perf_counter() - start_time time_per_query = elapsed / count * 1000 qps = int(count / elapsed) - print(f'Ran {count} queries in {elapsed}s, time per query {time_per_query} ms, {qps} qps') + print(f"Ran {count} queries in {elapsed}s, time per query {time_per_query} ms, {qps} qps") -if __name__ == '__main__': # pragma: no cover +if __name__ == "__main__": # pragma: no cover description = """A simple load tester for IRRd. Sends random !g queries.""" parser = argparse.ArgumentParser(description=description) - parser.add_argument('--count', dest='count', type=int, default=5000, - help=f'number of queries to run (default: 5000)') - parser.add_argument('host', type=str, - help='hostname of instance') - parser.add_argument('port', type=int, - help='port of instance') + parser.add_argument( + "--count", dest="count", type=int, default=5000, help=f"number of queries to run (default: 5000)" + ) + parser.add_argument("host", type=str, help="hostname of instance") + parser.add_argument("port", type=int, help="port of instance") args = parser.parse_args() main(args.host, args.port, args.count) diff --git a/irrd/scripts/mirror_force_reload.py b/irrd/scripts/mirror_force_reload.py index 33b8e3ce1..de8ad899e 100755 --- a/irrd/scripts/mirror_force_reload.py +++ b/irrd/scripts/mirror_force_reload.py @@ -3,7 +3,6 @@ import argparse import logging import sys - from pathlib import Path """ @@ -13,10 +12,10 @@ logger = logging.getLogger(__name__) sys.path.append(str(Path(__file__).resolve().parents[2])) -from irrd.conf import config_init, CONFIG_PATH_DEFAULT, get_setting - +from irrd.conf import CONFIG_PATH_DEFAULT, config_init, get_setting from irrd.storage.database_handler import DatabaseHandler + def set_force_reload(source) -> None: dh = DatabaseHandler() dh.set_force_reload(source) @@ -27,19 +26,22 @@ def set_force_reload(source) -> None: def main(): # pragma: no cover description = """Force a full reload for a mirror.""" parser = argparse.ArgumentParser(description=description) - parser.add_argument('--config', dest='config_file_path', type=str, - help=f'use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})') - parser.add_argument('source', type=str, - help='the name of the source to reload') + parser.add_argument( + "--config", + dest="config_file_path", + type=str, + help=f"use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})", + ) + parser.add_argument("source", type=str, help="the name of the source to reload") args = parser.parse_args() config_init(args.config_file_path) - if get_setting('database_readonly'): - print('Unable to run, because database_readonly is set') + if get_setting("database_readonly"): + print("Unable to run, because database_readonly is set") sys.exit(-1) set_force_reload(args.source) -if __name__ == '__main__': # pragma: no cover +if __name__ == "__main__": # pragma: no cover main() diff --git 
a/irrd/scripts/query_qa_comparison.py b/irrd/scripts/query_qa_comparison.py index a29eff798..a1f9c4399 100755 --- a/irrd/scripts/query_qa_comparison.py +++ b/irrd/scripts/query_qa_comparison.py @@ -7,24 +7,22 @@ import argparse import difflib -import sys - import re -from IPy import IP -from ordered_set import OrderedSet +import sys from pathlib import Path from typing import Optional +from IPy import IP +from ordered_set import OrderedSet sys.path.append(str(Path(__file__).resolve().parents[2])) from irrd.rpsl.rpsl_objects import rpsl_object_from_text -from irrd.utils.text import splitline_unicodesafe, split_paragraphs_rpsl -from irrd.utils.whois_client import whois_query_irrd, whois_query, WhoisQueryError - +from irrd.utils.text import split_paragraphs_rpsl, splitline_unicodesafe +from irrd.utils.whois_client import WhoisQueryError, whois_query, whois_query_irrd -SSP_QUERIES = ['!6', '!g', '!i'] -ASDOT_RE = re.compile(r'as\d+\.\d*', flags=re.IGNORECASE) +SSP_QUERIES = ["!6", "!g", "!i"] +ASDOT_RE = re.compile(r"as\d+\.\d*", flags=re.IGNORECASE) class QueryComparison: @@ -40,14 +38,14 @@ def __init__(self, input_file, host_reference, port_reference, host_tested, port self.host_tested = host_tested self.port_tested = port_tested - if input_file == '-': + if input_file == "-": f = sys.stdin else: - f = open(input_file, encoding='utf-8', errors='backslashreplace') + f = open(input_file, encoding="utf-8", errors="backslashreplace") for query in f.readlines(): - query = query.strip() + '\n' - if query == '!!\n': + query = query.strip() + "\n" + if query == "!!\n": continue self.queries_run += 1 error_reference = None @@ -56,19 +54,19 @@ def __init__(self, input_file, host_reference, port_reference, host_tested, port response_tested = None # ignore version or singular source queries - if query.lower().startswith('!v') or query.lower().startswith('!s'): + if query.lower().startswith("!v") or query.lower().startswith("!s"): continue - if (query.startswith('-x') and not query.startswith('-x ')) or re.search(ASDOT_RE, query): + if (query.startswith("-x") and not query.startswith("-x ")) or re.search(ASDOT_RE, query): self.queries_invalid += 1 continue # ignore queries asking for NRTM data or mirror serial status - if query.lower().startswith('-g ') or query.lower().startswith('!j'): + if query.lower().startswith("-g ") or query.lower().startswith("!j"): self.queries_mirror += 1 continue - if query.startswith('!'): # IRRD style query + if query.startswith("!"): # IRRD style query try: response_reference = whois_query_irrd(self.host_reference, self.port_reference, query) except ConnectionError as ce: @@ -76,14 +74,14 @@ def __init__(self, input_file, host_reference, port_reference, host_tested, port except WhoisQueryError as wqe: error_reference = str(wqe) except ValueError: - print(f'Query response to {query} invalid') + print(f"Query response to {query} invalid") continue try: response_tested = whois_query_irrd(self.host_tested, self.port_tested, query) except WhoisQueryError as wqe: error_tested = str(wqe) except ValueError: - print(f'Query response to {query} invalid') + print(f"Query response to {query} invalid") continue else: # RIPE style query @@ -95,9 +93,14 @@ def __init__(self, input_file, host_reference, port_reference, host_tested, port # If both produce error messages, don't compare them both_error = error_reference and error_tested - both_comment = (response_reference and response_tested and - response_reference.strip() and response_tested.strip() and - response_reference.strip()[0] == 
'%' and response_tested.strip()[0] == '%') + both_comment = ( + response_reference + and response_tested + and response_reference.strip() + and response_tested.strip() + and response_reference.strip()[0] == "%" + and response_tested.strip()[0] == "%" + ) if both_error or both_comment: self.queries_both_error += 1 continue @@ -105,23 +108,25 @@ def __init__(self, input_file, host_reference, port_reference, host_tested, port try: cleaned_reference = self.clean(query, response_reference) except ValueError as ve: - print(f'Invalid reference response to query {query.strip()}: {response_reference}: {ve}') + print(f"Invalid reference response to query {query.strip()}: {response_reference}: {ve}") continue try: cleaned_tested = self.clean(query, response_tested) except ValueError as ve: - print(f'Invalid tested response to query {query.strip()}: {response_tested}: {ve}') + print(f"Invalid tested response to query {query.strip()}: {response_tested}: {ve}") continue if cleaned_reference != cleaned_tested: self.queries_different += 1 self.write_inconsistency_report(query, cleaned_reference, cleaned_tested) - print(f'Ran {self.queries_run} objects, {self.queries_different} had different results, ' - f'{self.queries_both_error} produced errors on both instances, ' - f'{self.queries_invalid} invalid queries were skipped, ' - f'{self.queries_mirror} NRTM queries were skipped') + print( + f"Ran {self.queries_run} objects, {self.queries_different} had different results, " + f"{self.queries_both_error} produced errors on both instances, " + f"{self.queries_invalid} invalid queries were skipped, " + f"{self.queries_mirror} NRTM queries were skipped" + ) def clean(self, query: str, response: Optional[str]) -> Optional[str]: """Clean the query response, so that the text can be compared.""" @@ -131,36 +136,36 @@ def clean(self, query: str, response: Optional[str]) -> Optional[str]: response = response.strip().lower() cleaned_result_list = None - if irr_query in SSP_QUERIES or (irr_query == '!r' and query.lower().strip().endswith(',o')): - cleaned_result_list = response.split(' ') - if irr_query in ['!6', '!g'] and cleaned_result_list: + if irr_query in SSP_QUERIES or (irr_query == "!r" and query.lower().strip().endswith(",o")): + cleaned_result_list = response.split(" ") + if irr_query in ["!6", "!g"] and cleaned_result_list: cleaned_result_list = [str(IP(ip)) for ip in cleaned_result_list] if cleaned_result_list: - return ' '.join(sorted(list(set(cleaned_result_list)))) + return " ".join(sorted(list(set(cleaned_result_list)))) else: new_responses = [] for paragraph in split_paragraphs_rpsl(response): rpsl_obj = rpsl_object_from_text(paragraph.strip(), strict_validation=False) new_responses.append(rpsl_obj) - new_responses.sort(key=lambda i: i.parsed_data.get('source', '') + i.rpsl_object_class + i.pk()) + new_responses.sort(key=lambda i: i.parsed_data.get("source", "") + i.rpsl_object_class + i.pk()) texts = [r.render_rpsl_text() for r in new_responses] - return '\n'.join(OrderedSet(texts)) + return "\n".join(OrderedSet(texts)) def write_inconsistency_report(self, query, cleaned_reference, cleaned_tested): """Write a report to disk with details of the query response inconsistency.""" - report = open(f'qout/QR {query.strip().replace("/", "S")[:30]}', 'w') + report = open(f'qout/QR {query.strip().replace("/", "S")[:30]}', "w") diff_str = self.render_diff(query, cleaned_reference, cleaned_tested) - report.write(query.strip() + '\n') - 
report.write('\n=================================================================\n') + report.write(query.strip() + "\n") + report.write("\n=================================================================\n") if diff_str: - report.write(f'~~~~~~~~~[ diff clean ref->tst ]~~~~~~~~~\n') - report.write(diff_str + '\n') - report.write(f'~~~~~~~~~[ clean ref {self.host_reference}:{self.port_reference} ]~~~~~~~~~\n') - report.write(str(cleaned_reference) + '\n') - report.write(f'~~~~~~~~~[ clean tst {self.host_tested}:{self.port_tested} ]~~~~~~~~~\n') - report.write(str(cleaned_tested) + '\n') - report.write('\n=================================================================\n') + report.write(f"~~~~~~~~~[ diff clean ref->tst ]~~~~~~~~~\n") + report.write(diff_str + "\n") + report.write(f"~~~~~~~~~[ clean ref {self.host_reference}:{self.port_reference} ]~~~~~~~~~\n") + report.write(str(cleaned_reference) + "\n") + report.write(f"~~~~~~~~~[ clean tst {self.host_tested}:{self.port_tested} ]~~~~~~~~~\n") + report.write(str(cleaned_tested) + "\n") + report.write("\n=================================================================\n") report.close() def render_diff(self, query: str, cleaned_reference: str, cleaned_tested: str) -> Optional[str]: @@ -169,34 +174,33 @@ def render_diff(self, query: str, cleaned_reference: str, cleaned_tested: str) - return None irr_query = query[:2].lower() - if irr_query in SSP_QUERIES or (irr_query == '!r' and query.lower().strip().endswith(',o')): - diff_input_reference = list(cleaned_reference.split(' ')) - diff_input_tested = list(cleaned_tested.split(' ')) + if irr_query in SSP_QUERIES or (irr_query == "!r" and query.lower().strip().endswith(",o")): + diff_input_reference = list(cleaned_reference.split(" ")) + diff_input_tested = list(cleaned_tested.split(" ")) else: diff_input_reference = list(splitline_unicodesafe(cleaned_reference)) diff_input_tested = list(splitline_unicodesafe(cleaned_tested)) - diff = list(difflib.unified_diff(diff_input_reference, diff_input_tested, lineterm='')) - diff_str = '\n'.join(diff[2:]) # skip the lines from the diff which would have filenames + diff = list(difflib.unified_diff(diff_input_reference, diff_input_tested, lineterm="")) + diff_str = "\n".join(diff[2:]) # skip the lines from the diff which would have filenames return diff_str def main(): # pragma: no cover description = """Run a list of queries against two IRRD instances, and report significant results.""" parser = argparse.ArgumentParser(description=description) - parser.add_argument('input_file', type=str, - help='the name of a file to read containing queries, or - for stdin') - parser.add_argument('host_reference', type=str, - help='host/IP of the reference IRRD server') - parser.add_argument('port_reference', type=int, - help='port for the reference IRRD server') - parser.add_argument('host_tested', type=str, - help='host/IP of the tested IRRD server') - parser.add_argument('port_tested', type=int, - help='port for the tested IRRD server') + parser.add_argument( + "input_file", type=str, help="the name of a file to read containing queries, or - for stdin" + ) + parser.add_argument("host_reference", type=str, help="host/IP of the reference IRRD server") + parser.add_argument("port_reference", type=int, help="port for the reference IRRD server") + parser.add_argument("host_tested", type=str, help="host/IP of the tested IRRD server") + parser.add_argument("port_tested", type=int, help="port for the tested IRRD server") args = parser.parse_args() - 
QueryComparison(args.input_file, args.host_reference, args.port_reference, args.host_tested, args.port_tested) + QueryComparison( + args.input_file, args.host_reference, args.port_reference, args.host_tested, args.port_tested + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/irrd/scripts/rpsl_read.py b/irrd/scripts/rpsl_read.py index 94cdeb434..7a31a74d1 100755 --- a/irrd/scripts/rpsl_read.py +++ b/irrd/scripts/rpsl_read.py @@ -6,7 +6,6 @@ """ import argparse import sys - from pathlib import Path from typing import Optional, Set @@ -15,9 +14,9 @@ sys.path.append(str(Path(__file__).resolve().parents[2])) from irrd.conf import CONFIG_PATH_DEFAULT, config_init, get_setting -from irrd.storage.database_handler import DatabaseHandler from irrd.rpsl.parser import UnknownRPSLObjectClassException from irrd.rpsl.rpsl_objects import rpsl_object_from_text +from irrd.storage.database_handler import DatabaseHandler from irrd.utils.text import split_paragraphs_rpsl @@ -34,18 +33,18 @@ def main(self, filename, strict_validation, database, show_info=True): self.database_handler = DatabaseHandler() self.database_handler.disable_journaling() - if filename == '-': # pragma: no cover + if filename == "-": # pragma: no cover f = sys.stdin else: - f = open(filename, encoding='utf-8', errors='backslashreplace') + f = open(filename, encoding="utf-8", errors="backslashreplace") for paragraph in split_paragraphs_rpsl(f): self.parse_object(paragraph, strict_validation) - print(f'Processed {self.obj_parsed} objects, {self.obj_errors} with errors') + print(f"Processed {self.obj_parsed} objects, {self.obj_errors} with errors") if self.obj_unknown: - unknown_formatted = ', '.join(self.unknown_object_classes) - print(f'Ignored {self.obj_unknown} objects due to unknown object classes: {unknown_formatted}') + unknown_formatted = ", ".join(self.unknown_object_classes) + print(f"Ignored {self.obj_unknown} objects due to unknown object classes: {unknown_formatted}") if self.database_handler: self.database_handler.commit() @@ -60,20 +59,20 @@ def parse_object(self, rpsl_text, strict_validation): self.obj_errors += 1 print(rpsl_text.strip()) - print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~') + print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") print(obj.messages) - print('\n=======================================\n') + print("\n=======================================\n") if self.database_handler and obj and not obj.messages.errors(): self.database_handler.upsert_rpsl_object(obj, JournalEntryOrigin.mirror) except UnknownRPSLObjectClassException as e: self.obj_unknown += 1 - self.unknown_object_classes.add(str(e).split(':')[1].strip()) + self.unknown_object_classes.add(str(e).split(":")[1].strip()) except Exception as e: # pragma: no cover - print('=======================================') + print("=======================================") print(rpsl_text) - print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~') + print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") raise e @@ -82,26 +81,38 @@ def main(): # pragma: no cover the parser, the object is printed followed by the messages. 
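
The clean() method above makes set-style responses comparable by treating them as unordered sets: split, de-duplicate, sort, re-join. A self-contained illustration of that normalization:

# Normalization as in QueryComparison.clean() for space-separated responses:
# ordering and duplicates should not count as a difference between servers.
def normalize_ssp(response: str) -> str:
    return " ".join(sorted(set(response.strip().lower().split(" "))))


assert normalize_ssp("192.0.2.0/24 10.0.0.0/8 10.0.0.0/8") == normalize_ssp("10.0.0.0/8 192.0.2.0/24")
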
diff --git a/irrd/scripts/rpsl_read.py b/irrd/scripts/rpsl_read.py
index 94cdeb434..7a31a74d1 100755
--- a/irrd/scripts/rpsl_read.py
+++ b/irrd/scripts/rpsl_read.py
@@ -6,7 +6,6 @@
 """
 import argparse
 import sys
-
 from pathlib import Path
 from typing import Optional, Set
@@ -15,9 +14,9 @@
 sys.path.append(str(Path(__file__).resolve().parents[2]))
 from irrd.conf import CONFIG_PATH_DEFAULT, config_init, get_setting
-from irrd.storage.database_handler import DatabaseHandler
 from irrd.rpsl.parser import UnknownRPSLObjectClassException
 from irrd.rpsl.rpsl_objects import rpsl_object_from_text
+from irrd.storage.database_handler import DatabaseHandler
 from irrd.utils.text import split_paragraphs_rpsl
@@ -34,18 +33,18 @@ def main(self, filename, strict_validation, database, show_info=True):
             self.database_handler = DatabaseHandler()
             self.database_handler.disable_journaling()
 
-        if filename == '-':  # pragma: no cover
+        if filename == "-":  # pragma: no cover
             f = sys.stdin
         else:
-            f = open(filename, encoding='utf-8', errors='backslashreplace')
+            f = open(filename, encoding="utf-8", errors="backslashreplace")
 
         for paragraph in split_paragraphs_rpsl(f):
             self.parse_object(paragraph, strict_validation)
 
-        print(f'Processed {self.obj_parsed} objects, {self.obj_errors} with errors')
+        print(f"Processed {self.obj_parsed} objects, {self.obj_errors} with errors")
         if self.obj_unknown:
-            unknown_formatted = ', '.join(self.unknown_object_classes)
-            print(f'Ignored {self.obj_unknown} objects due to unknown object classes: {unknown_formatted}')
+            unknown_formatted = ", ".join(self.unknown_object_classes)
+            print(f"Ignored {self.obj_unknown} objects due to unknown object classes: {unknown_formatted}")
 
         if self.database_handler:
             self.database_handler.commit()
@@ -60,20 +59,20 @@ def parse_object(self, rpsl_text, strict_validation):
                 self.obj_errors += 1
                 print(rpsl_text.strip())
-                print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
+                print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
                 print(obj.messages)
-                print('\n=======================================\n')
+                print("\n=======================================\n")
 
             if self.database_handler and obj and not obj.messages.errors():
                 self.database_handler.upsert_rpsl_object(obj, JournalEntryOrigin.mirror)
 
         except UnknownRPSLObjectClassException as e:
             self.obj_unknown += 1
-            self.unknown_object_classes.add(str(e).split(':')[1].strip())
+            self.unknown_object_classes.add(str(e).split(":")[1].strip())
         except Exception as e:  # pragma: no cover
-            print('=======================================')
-            print(rpsl_text)
-            print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
+            print("=======================================")
+            print(rpsl_text)
+            print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
             raise e
@@ -82,26 +81,38 @@ def main():  # pragma: no cover
     the parser, the object is printed followed by the messages.
     Optionally, insert objects into the database."""
     parser = argparse.ArgumentParser(description=description)
-    parser.add_argument('--config', dest='config_file_path', type=str,
-                        help=f'use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})')
-    parser.add_argument('--hide-info', dest='hide_info', action='store_true',
-                        help='hide INFO messages')
-    parser.add_argument('--strict', dest='strict_validation', action='store_true',
-                        help='use strict validation (errors on e.g. unknown or missing attributes)')
-    parser.add_argument('--database-destructive-overwrite', dest='database', action='store_true',
-                        help='insert all valid objects into the IRRD database - OVERWRITING ANY EXISTING ENTRIES, if '
-                             'they have the same RPSL primary key and source')
-    parser.add_argument('input_file', type=str,
-                        help='the name of a file to read, or - for stdin')
+    parser.add_argument(
+        "--config",
+        dest="config_file_path",
+        type=str,
+        help=f"use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})",
+    )
+    parser.add_argument("--hide-info", dest="hide_info", action="store_true", help="hide INFO messages")
+    parser.add_argument(
+        "--strict",
+        dest="strict_validation",
+        action="store_true",
+        help="use strict validation (errors on e.g. unknown or missing attributes)",
+    )
+    parser.add_argument(
+        "--database-destructive-overwrite",
+        dest="database",
+        action="store_true",
+        help=(
+            "insert all valid objects into the IRRD database - OVERWRITING ANY EXISTING ENTRIES, if "
+            "they have the same RPSL primary key and source"
+        ),
+    )
+    parser.add_argument("input_file", type=str, help="the name of a file to read, or - for stdin")
     args = parser.parse_args()
 
     config_init(args.config_file_path)
-    if get_setting('database_readonly'):
-        print('Unable to run, because database_readonly is set')
+    if get_setting("database_readonly"):
+        print("Unable to run, because database_readonly is set")
         sys.exit(-1)
-
+
     RPSLParse().main(args.input_file, args.strict_validation, args.database, not args.hide_info)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":  # pragma: no cover
     main()
diff --git a/irrd/scripts/set_last_modified_auth.py b/irrd/scripts/set_last_modified_auth.py
index c90d647ab..8052d2de9 100755
--- a/irrd/scripts/set_last_modified_auth.py
+++ b/irrd/scripts/set_last_modified_auth.py
@@ -5,7 +5,6 @@
 import sys
 from pathlib import Path
 
-
 """
 Set last-modified attribute on all authoritative objects.
 """
@@ -13,29 +12,33 @@
 logger = logging.getLogger(__name__)
 sys.path.append(str(Path(__file__).resolve().parents[2]))
-from irrd.storage.database_handler import DatabaseHandler
-from irrd.conf import config_init, CONFIG_PATH_DEFAULT, get_setting
+from irrd.conf import CONFIG_PATH_DEFAULT, config_init, get_setting
 from irrd.rpsl.rpsl_objects import rpsl_object_from_text
+from irrd.storage.database_handler import DatabaseHandler
 from irrd.storage.models import RPSLDatabaseObject
 from irrd.storage.queries import RPSLDatabaseQuery
 
+
 def set_last_modified():
     dh = DatabaseHandler()
-    auth_sources = [k for k, v in get_setting('sources').items() if v.get('authoritative')]
-    q = RPSLDatabaseQuery(column_names=['pk', 'object_text', 'updated'], enable_ordering=False)
+    auth_sources = [k for k, v in get_setting("sources").items() if v.get("authoritative")]
+    q = RPSLDatabaseQuery(column_names=["pk", "object_text", "updated"], enable_ordering=False)
     q = q.sources(auth_sources)
     results = list(dh.execute_query(q))
-    print(f'Updating {len(results)} objects in sources {auth_sources}')
+    print(f"Updating {len(results)} objects in sources {auth_sources}")
     for result in results:
-        rpsl_obj = rpsl_object_from_text(result['object_text'], strict_validation=False)
+        rpsl_obj = rpsl_object_from_text(result["object_text"], strict_validation=False)
         if rpsl_obj.messages.errors():  # pragma: no cover
-            print(f'Failed to process {rpsl_obj}: {rpsl_obj.messages.errors()}')
+            print(f"Failed to process {rpsl_obj}: {rpsl_obj.messages.errors()}")
             continue
-        new_text = rpsl_obj.render_rpsl_text(result['updated'])
-        stmt = RPSLDatabaseObject.__table__.update().where(
-            RPSLDatabaseObject.__table__.c.pk == result['pk']).values(
-            object_text=new_text,
+        new_text = rpsl_obj.render_rpsl_text(result["updated"])
+        stmt = (
+            RPSLDatabaseObject.__table__.update()
+            .where(RPSLDatabaseObject.__table__.c.pk == result["pk"])
+            .values(
+                object_text=new_text,
+            )
         )
         dh.execute_statement(stmt)
     dh.commit()
@@ -45,17 +48,21 @@ def set_last_modified():
 def main():  # pragma: no cover
     description = """Set last-modified attribute on all authoritative objects."""
     parser = argparse.ArgumentParser(description=description)
-    parser.add_argument('--config', dest='config_file_path', type=str,
-                        help=f'use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})')
+    parser.add_argument(
+        "--config",
+        dest="config_file_path",
+        type=str,
+        help=f"use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})",
+    )
     args = parser.parse_args()
 
     config_init(args.config_file_path)
-    if get_setting('database_readonly'):
-        print('Unable to run, because database_readonly is set')
+    if get_setting("database_readonly"):
+        print("Unable to run, because database_readonly is set")
         sys.exit(-1)
     sys.exit(set_last_modified())
 
 
-if __name__ == '__main__':  # pragma: no cover
+if __name__ == "__main__":  # pragma: no cover
     main()
diff --git a/irrd/scripts/submit_changes.py b/irrd/scripts/submit_changes.py
index 96a35d32e..2a0d47e93 100755
--- a/irrd/scripts/submit_changes.py
+++ b/irrd/scripts/submit_changes.py
@@ -13,12 +13,11 @@
 """
 import argparse
 import sys
-
 from pathlib import Path
 
 sys.path.append(str(Path(__file__).resolve().parents[2]))
 
-from irrd.conf import config_init, CONFIG_PATH_DEFAULT
+from irrd.conf import CONFIG_PATH_DEFAULT, config_init
 from irrd.updates.handler import ChangeSubmissionHandler
@@ -27,12 +26,16 @@ def main(data):
     print(handler.submitter_report_human())
 
 
-if __name__ == '__main__':  # pragma: no cover
+if __name__ == "__main__":  # pragma: no cover
     description = """Process a raw update message, i.e. without email headers. Authentication
     is still checked, but PGP is not supported. Message is always read from stdin,
     and a report is printed to stdout."""
     parser = argparse.ArgumentParser(description=description)
-    parser.add_argument('--config', dest='config_file_path', type=str,
-                        help=f'use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})')
+    parser.add_argument(
+        "--config",
+        dest="config_file_path",
+        type=str,
+        help=f"use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})",
+    )
     args = parser.parse_args()
 
     config_init(args.config_file_path)
diff --git a/irrd/scripts/submit_email.py b/irrd/scripts/submit_email.py
index 92f4d5b00..2d99d5dd2 100755
--- a/irrd/scripts/submit_email.py
+++ b/irrd/scripts/submit_email.py
@@ -1,10 +1,9 @@
 #!/usr/bin/env python
 # flake8: noqa: E402
-import sys
-
 import argparse
 import logging
+import sys
 from pathlib import Path
 
 """
@@ -17,7 +16,7 @@
 logger = logging.getLogger(__name__)
 sys.path.append(str(Path(__file__).resolve().parents[2]))
 
-from irrd.conf import config_init, CONFIG_PATH_DEFAULT
+from irrd.conf import CONFIG_PATH_DEFAULT, config_init
 from irrd.updates.email import handle_email_submission
@@ -25,8 +24,10 @@ def run(data):
     try:
         handle_email_submission(data)
     except Exception as exc:
-        logger.critical(f'An exception occurred while attempting to process the following email: {data}', exc_info=exc)
-        print('An internal error occurred while processing this email.')
+        logger.critical(
+            f"An exception occurred while attempting to process the following email: {data}", exc_info=exc
+        )
+        print("An internal error occurred while processing this email.")
 
 
 def main():  # pragma: no cover
@@ -34,8 +35,12 @@ def main():  # pragma: no cover
     is always read from stdin. A report is sent to the user by email, along with any
     notifications to mntners and others."""
     parser = argparse.ArgumentParser(description=description)
-    parser.add_argument('--config', dest='config_file_path', type=str,
-                        help=f'use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})')
+    parser.add_argument(
+        "--config",
+        dest="config_file_path",
+        type=str,
+        help=f"use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})",
+    )
     args = parser.parse_args()
 
     config_init(args.config_file_path)
@@ -43,5 +48,5 @@ def main():  # pragma: no cover
     run(sys.stdin.read())
 
 
-if __name__ == '__main__':  # pragma: no cover
+if __name__ == "__main__":  # pragma: no cover
     main()
diff --git a/irrd/scripts/tests/test_expire_journal.py b/irrd/scripts/tests/test_expire_journal.py
index 2a22345cb..46782fa65 100644
--- a/irrd/scripts/tests/test_expire_journal.py
+++ b/irrd/scripts/tests/test_expire_journal.py
@@ -5,6 +5,7 @@
 from irrd.storage.queries import RPSLDatabaseJournalQuery
 from irrd.utils.test_utils import MockDatabaseHandler
+
 from ..expire_journal import expire_journal
 
 EXPIRY_DATE = datetime(2022, 1, 1, tzinfo=pytz.utc)
@@ -12,7 +13,9 @@
 
 class TestExpireJournal:
     expected_query = (
-        RPSLDatabaseJournalQuery(column_names=["timestamp"]).sources(["TEST"]).entries_before_date(EXPIRY_DATE)
+        RPSLDatabaseJournalQuery(column_names=["timestamp"])
+        .sources(["TEST"])
+        .entries_before_date(EXPIRY_DATE)
     )
 
     def test_expire_confirmed(self, capsys, monkeypatch):
@@ -87,7 +90,7 @@ def test_expire_rejected(self, capsys, monkeypatch):
         assert mock_dh.closed
 
         [query] = mock_dh.queries
-        assert query == RPSLDatabaseJournalQuery(column_names=["timestamp"]).sources(["TEST"]).entries_before_date(
-            EXPIRY_DATE
-        )
+        assert query == RPSLDatabaseJournalQuery(column_names=["timestamp"]).sources(
+            ["TEST"]
+        ).entries_before_date(EXPIRY_DATE)
         assert not mock_dh.other_calls
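
The script tests that follow all use the same seam: monkeypatch replaces a module-level factory (DatabaseHandler, RPSLDatabaseQuery, a parser class) with a lambda returning a Mock, so the script logic runs without a database. A minimal self-contained version of the pattern, using mirror_force_reload as the example:

# Minimal sketch of the monkeypatch-a-factory pattern used by these tests.
from unittest.mock import Mock

from ..mirror_force_reload import set_force_reload  # as in test_mirror_force_reload.py


def test_set_force_reload_sketch(monkeypatch):
    mock_dh = Mock()
    # The script does "dh = DatabaseHandler()", so replacing the name in its
    # module namespace hands it the mock instead.
    monkeypatch.setattr("irrd.scripts.mirror_force_reload.DatabaseHandler", lambda: mock_dh)
    set_force_reload("TEST")
    mock_dh.set_force_reload.assert_called_once_with("TEST")
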
diff --git a/irrd/scripts/tests/test_irr_rpsl_submit.py b/irrd/scripts/tests/test_irr_rpsl_submit.py
index 5c5aeccd7..64b1e755e 100755
--- a/irrd/scripts/tests/test_irr_rpsl_submit.py
+++ b/irrd/scripts/tests/test_irr_rpsl_submit.py
@@ -1,14 +1,15 @@
-import json
 import io
+import json
 import os
-import pytest
 import re
-from urllib import request
 import subprocess
 import sys
 import unittest
+from urllib import request
 from urllib.error import HTTPError
 
+import pytest
+
 from .. import irr_rpsl_submit
 
 IRRD_HOST = "fake.example.com"
@@ -59,7 +60,9 @@
 RPSL_EMPTY = ""
 RPSL_WHITESPACE = "\n\n\n \t\t\n"
 RPSL_MINIMAL = "route: 1.2.3.4\norigin: AS65414\n"
-RPSL_EXTRA_WHITESPACE = "\n\nroute: 1.2.3.4\norigin: AS65414\n\n\n\nroute: 5.6.8.9\norigin: AS65414\n\n\n\n\n\n"
+RPSL_EXTRA_WHITESPACE = (
+    "\n\nroute: 1.2.3.4\norigin: AS65414\n\n\n\nroute: 5.6.8.9\norigin: AS65414\n\n\n\n\n\n"
+)
 RPSL_DELETE = f"role: Badgers\ndelete: {DELETE_REASON}"
 RPSL_DELETE_WITH_TWO_OBJECTS = f"person: Biff Badger\n\nrole: Badgers\ndelete: {DELETE_REASON}"
 RPSL_WITH_OVERRIDE = f"mnter: Biff\noverride: {OVERRIDE}"
@@ -283,7 +286,10 @@ def test_choose_url(self):
             {"expected": "http://localhost/v1/submit/", "args": ["-h", "localhost"]},
             {"expected": "http://localhost:8080/v1/submit/", "args": ["-h", "localhost", "-p", "8080"]},
             {"expected": "http://example.com:137/v1/submit/", "args": ["-h", "example.com", "-p", "137"]},
-            {"expected": "http://example.com:137/v1/submit/", "args": ["-u", "http://example.com:137/v1/submit/"]},
+            {
+                "expected": "http://example.com:137/v1/submit/",
+                "args": ["-u", "http://example.com:137/v1/submit/"],
+            },
             {"expected": "http://example.com/v1/submit/", "args": ["-u", "http://example.com/v1/submit/"]},
         ]
@@ -452,7 +458,9 @@ def test_good_response(self):
         args = irr_rpsl_submit.get_arguments(options)
         self.assertEqual(args.url, UNREACHABLE_URL)
 
-        irr_rpsl_submit.send_request = lambda rpsl, args: APIResult([APIResultObject().create().succeed()]).to_dict()
+        irr_rpsl_submit.send_request = lambda rpsl, args: APIResult(
+            [APIResultObject().create().succeed()]
+        ).to_dict()
         result = irr_rpsl_submit.make_request(RPSL_MINIMAL, args)
 
         self.assertTrue(result["objects"][0]["successful"])
@@ -683,13 +691,17 @@ def test_010_nonense_options(self):
         for s in ["-Z", "-X", "-9", "--not-there"]:
             result = Runner.run([s], ENV_EMPTY, RPSL_EMPTY)
             self.assertEqual(
-                result.returncode, EXIT_ARGUMENT_ERROR, f"nonsense switch {s} exits with {EXIT_ARGUMENT_ERROR}"
+                result.returncode,
+                EXIT_ARGUMENT_ERROR,
+                f"nonsense switch {s} exits with {EXIT_ARGUMENT_ERROR}",
             )
             self.assertRegex(result.stderr, REGEX_ONE_OF)
 
     def test_010_no_args(self):
         result = Runner.run([], ENV_EMPTY, RPSL_EMPTY)
-        self.assertEqual(result.returncode, EXIT_ARGUMENT_ERROR, f"no arguments exits with {EXIT_ARGUMENT_ERROR}")
+        self.assertEqual(
+            result.returncode, EXIT_ARGUMENT_ERROR, f"no arguments exits with {EXIT_ARGUMENT_ERROR}"
+        )
         self.assertRegex(result.stderr, REGEX_ONE_OF)
 
     def test_020_help(self):
@@ -721,19 +733,25 @@ def test_020_dash_o_noop(self):
         # If we get an error, it should be from the -h, not the -O
         result = Runner.run(["-h", UNREACHABLE_HOST, "-O", BAD_RESPONSE_HOST], ENV_EMPTY, RPSL_MINIMAL)
         self.assertEqual(
-            result.returncode, EXIT_NETWORK_ERROR, "using both -h and -O exits with value appropriate to -h value"
+            result.returncode,
+            EXIT_NETWORK_ERROR,
+            "using both -h and -O exits with value appropriate to -h value",
         )
         self.assertRegex(result.stderr, REGEX_UNREACHABLE)
 
         result = Runner.run(["-h", BAD_RESPONSE_HOST, "-O", UNREACHABLE_HOST], ENV_EMPTY, RPSL_MINIMAL)
         self.assertEqual(
-            result.returncode, EXIT_NETWORK_ERROR, "using both -h and -O exits with value appropriate to -h value"
+            result.returncode,
+            EXIT_NETWORK_ERROR,
+            "using both -h and -O exits with value appropriate to -h value",
         )
         self.assertRegex(result.stderr, REGEX_NOT_FOUND)
 
     def test_030_empty_input_option(self):
         result = Runner.run(["-u", IRRD_URL], ENV_EMPTY, RPSL_EMPTY)
-        self.assertEqual(result.returncode, EXIT_INPUT_ERROR, f"empty input with -u exits with {EXIT_INPUT_ERROR}")
+        self.assertEqual(
+            result.returncode, EXIT_INPUT_ERROR, f"empty input with -u exits with {EXIT_INPUT_ERROR}"
+        )
         self.assertRegex(result.stderr, REGEX_NO_OBJECTS)
 
     def test_030_empty_input_env(self):
@@ -749,13 +767,17 @@ def test_030_empty_input_env(self):
 
     def test_030_only_whitespace_input(self):
         result = Runner.run(["-u", IRRD_URL], ENV_EMPTY, RPSL_WHITESPACE)
-        self.assertEqual(result.returncode, EXIT_INPUT_ERROR, f"whitespace only input exits with {EXIT_INPUT_ERROR}")
+        self.assertEqual(
+            result.returncode, EXIT_INPUT_ERROR, f"whitespace only input exits with {EXIT_INPUT_ERROR}"
+        )
         self.assertRegex(result.stderr, REGEX_NO_OBJECTS)
 
     def test_030_multiple_object_delete(self):
         result = Runner.run(["-u", IRRD_URL], ENV_EMPTY, RPSL_DELETE_WITH_TWO_OBJECTS)
         self.assertEqual(
-            result.returncode, EXIT_INPUT_ERROR, f"RPSL delete with multiple objects exits with {EXIT_INPUT_ERROR}"
+            result.returncode,
+            EXIT_INPUT_ERROR,
+            f"RPSL delete with multiple objects exits with {EXIT_INPUT_ERROR}",
        )
         self.assertRegex(result.stderr, REGEX_TOO_MANY)
 
@@ -768,7 +790,9 @@ def test_040_unresovlable_host(self):
         for row in table:
             result = Runner.run(row, ENV_EMPTY, RPSL_MINIMAL)
             self.assertEqual(
-                result.returncode, EXIT_NETWORK_ERROR, f"Unresolvable host in {row[1]} exits with {EXIT_NETWORK_ERROR}"
+                result.returncode,
+                EXIT_NETWORK_ERROR,
+                f"Unresolvable host in {row[1]} exits with {EXIT_NETWORK_ERROR}",
             )
             self.assertRegex(result.stderr, REGEX_UNRESOLVABLE)
 
@@ -781,7 +805,9 @@ def test_040_unreachable_host(self):
         for row in table:
             result = Runner.run(row, ENV_EMPTY, RPSL_MINIMAL)
             self.assertEqual(
-                result.returncode, EXIT_NETWORK_ERROR, f"Unreachable host in {row[1]} with {EXIT_NETWORK_ERROR}"
+                result.returncode,
+                EXIT_NETWORK_ERROR,
+                f"Unreachable host in {row[1]} with {EXIT_NETWORK_ERROR}",
            )
             self.assertRegex(result.stderr, REGEX_UNREACHABLE)
 
@@ -792,7 +818,9 @@ def test_050_non_json_response(self):
         for row in table:
             result = Runner.run(row, ENV_EMPTY, RPSL_MINIMAL)
             self.assertEqual(
-                result.returncode, EXIT_RESPONSE_ERROR, f"Bad response URL {row[1]} exits with {EXIT_NETWORK_ERROR}"
+                result.returncode,
+                EXIT_RESPONSE_ERROR,
+                f"Bad response URL {row[1]} exits with {EXIT_NETWORK_ERROR}",
             )
             self.assertRegex(result.stderr, REGEX_BAD_RESPONSE)
diff --git a/irrd/scripts/tests/test_load_database.py b/irrd/scripts/tests/test_load_database.py
index 3d6ebe479..4b71e3c99 100644
--- a/irrd/scripts/tests/test_load_database.py
+++ b/irrd/scripts/tests/test_load_database.py
@@ -1,25 +1,28 @@
 from unittest.mock import Mock
 
 from irrd.utils.test_utils import flatten_mock_calls
+
 from ..load_database import load
 
 
 def test_load_database_success(capsys, monkeypatch):
     mock_dh = Mock()
-    monkeypatch.setattr('irrd.scripts.load_database.DatabaseHandler', lambda: mock_dh)
+    monkeypatch.setattr("irrd.scripts.load_database.DatabaseHandler", lambda: mock_dh)
     mock_roa_validator = Mock()
-    monkeypatch.setattr('irrd.scripts.load_database.BulkRouteROAValidator', lambda dh: mock_roa_validator)
+    monkeypatch.setattr("irrd.scripts.load_database.BulkRouteROAValidator", lambda dh: mock_roa_validator)
     mock_parser = Mock()
-    monkeypatch.setattr('irrd.scripts.load_database.MirrorFileImportParser', lambda *args, **kwargs: mock_parser)
+    monkeypatch.setattr(
+        "irrd.scripts.load_database.MirrorFileImportParser", lambda *args, **kwargs: mock_parser
+    )
 
     mock_parser.run_import = lambda: None
-    assert load('TEST', 'test.db', 42) == 0
+    assert load("TEST", "test.db", 42) == 0
 
     assert flatten_mock_calls(mock_dh) == [
-        ['delete_all_rpsl_objects_with_journal', ('TEST',), {}],
-        ['disable_journaling', (), {}],
-        ['commit', (), {}],
-        ['close', (), {}]
+        ["delete_all_rpsl_objects_with_journal", ("TEST",), {}],
+        ["disable_journaling", (), {}],
+        ["commit", (), {}],
+        ["close", (), {}],
     ]
     # run_import() call is not included here
@@ -29,37 +32,41 @@ def test_load_database_success(capsys, monkeypatch):
 
 def test_load_database_import_error(capsys, monkeypatch, caplog):
     mock_dh = Mock()
-    monkeypatch.setattr('irrd.scripts.load_database.DatabaseHandler', lambda: mock_dh)
+    monkeypatch.setattr("irrd.scripts.load_database.DatabaseHandler", lambda: mock_dh)
     mock_roa_validator = Mock()
-    monkeypatch.setattr('irrd.scripts.load_database.BulkRouteROAValidator', lambda dh: mock_roa_validator)
+    monkeypatch.setattr("irrd.scripts.load_database.BulkRouteROAValidator", lambda dh: mock_roa_validator)
     mock_parser = Mock()
-    monkeypatch.setattr('irrd.scripts.load_database.MirrorFileImportParser', lambda *args, **kwargs: mock_parser)
+    monkeypatch.setattr(
+        "irrd.scripts.load_database.MirrorFileImportParser", lambda *args, **kwargs: mock_parser
+    )
 
-    mock_parser.run_import = lambda: 'object-parsing-error'
+    mock_parser.run_import = lambda: "object-parsing-error"
 
-    assert load('TEST', 'test.db', 42) == 1
+    assert load("TEST", "test.db", 42) == 1
     assert flatten_mock_calls(mock_dh) == [
-        ['delete_all_rpsl_objects_with_journal', ('TEST',), {}],
-        ['disable_journaling', (), {}],
-        ['rollback', (), {}],
-        ['close', (), {}]
+        ["delete_all_rpsl_objects_with_journal", ("TEST",), {}],
+        ["disable_journaling", (), {}],
+        ["rollback", (), {}],
+        ["close", (), {}],
    ]
     # run_import() call is not included here
     assert flatten_mock_calls(mock_parser) == []
 
-    assert 'object-parsing-error' not in caplog.text
+    assert "object-parsing-error" not in caplog.text
     stdout = capsys.readouterr().out
-    assert 'Error occurred while processing object:\nobject-parsing-error' in stdout
+    assert "Error occurred while processing object:\nobject-parsing-error" in stdout
 
 
 def test_reject_import_source_set(capsys, config_override):
-    config_override({
-        'sources': {
-            'TEST': {'import_source': 'import-url'}
-        },
-    })
-    assert load('TEST', 'test.db', 42) == 2
+    config_override(
+        {
+            "sources": {"TEST": {"import_source": "import-url"}},
+        }
+    )
+    assert load("TEST", "test.db", 42) == 2
     stdout = capsys.readouterr().out
-    assert 'Error: to use this command, import_source and import_serial_' \
-           'source for source TEST must not be set.' in stdout
+    assert (
+        "Error: to use this command, import_source and import_serial_source for source TEST must not be set."
+        in stdout
+    )
diff --git a/irrd/scripts/tests/test_load_pgp_keys.py b/irrd/scripts/tests/test_load_pgp_keys.py
index a571a01e0..6924042a0 100644
--- a/irrd/scripts/tests/test_load_pgp_keys.py
+++ b/irrd/scripts/tests/test_load_pgp_keys.py
@@ -1,38 +1,40 @@
-import pytest
 from unittest.mock import Mock
 
-from irrd.utils.test_utils import flatten_mock_calls
+import pytest
+
 from irrd.utils.rpsl_samples import SAMPLE_KEY_CERT
+from irrd.utils.test_utils import flatten_mock_calls
+
 from ..load_pgp_keys import load_pgp_keys
 
 
-@pytest.mark.usefixtures('tmp_gpg_dir')
+@pytest.mark.usefixtures("tmp_gpg_dir")
 def test_load_pgp_keys(capsys, monkeypatch):
     mock_dh = Mock()
     mock_dq = Mock()
-    monkeypatch.setattr('irrd.scripts.load_pgp_keys.DatabaseHandler', lambda: mock_dh)
-    monkeypatch.setattr('irrd.scripts.load_pgp_keys.RPSLDatabaseQuery', lambda column_names: mock_dq)
+    monkeypatch.setattr("irrd.scripts.load_pgp_keys.DatabaseHandler", lambda: mock_dh)
+    monkeypatch.setattr("irrd.scripts.load_pgp_keys.RPSLDatabaseQuery", lambda column_names: mock_dq)
 
-    mock_dh.execute_query = lambda q, : [
+    mock_dh.execute_query = lambda q,: [
         {
-            'rpsl_pk': 'PGPKEY-80F238C6',
-            'object_text': SAMPLE_KEY_CERT,
+            "rpsl_pk": "PGPKEY-80F238C6",
+            "object_text": SAMPLE_KEY_CERT,
         },
         {
-            'rpsl_pk': 'PGPKEY-BAD',
-            'object_text': SAMPLE_KEY_CERT.replace('rpYI', 'a'),
-        }
+            "rpsl_pk": "PGPKEY-BAD",
+            "object_text": SAMPLE_KEY_CERT.replace("rpYI", "a"),
+        },
     ]
 
-    load_pgp_keys('TEST')
+    load_pgp_keys("TEST")
 
     assert flatten_mock_calls(mock_dh) == [
-        ['close', (), {}],
+        ["close", (), {}],
    ]
     assert flatten_mock_calls(mock_dq) == [
-        ['sources', (['TEST'],), {}],
-        ['object_classes', (['key-cert'],), {}],
+        ["sources", (["TEST"],), {}],
+        ["object_classes", (["key-cert"],), {}],
     ]
     output = capsys.readouterr().out
-    assert 'Loading key-cert PGPKEY-80F238C6' in output
-    assert 'Loading key-cert PGPKEY-BAD' in output
-    assert 'Unable to read public PGP key' in output
+    assert "Loading key-cert PGPKEY-80F238C6" in output
+    assert "Loading key-cert PGPKEY-BAD" in output
+    assert "Unable to read public PGP key" in output
diff --git a/irrd/scripts/tests/test_mirror_force_reload.py b/irrd/scripts/tests/test_mirror_force_reload.py
index 0b662b28f..5f2226e59 100644
--- a/irrd/scripts/tests/test_mirror_force_reload.py
+++ b/irrd/scripts/tests/test_mirror_force_reload.py
@@ -1,17 +1,18 @@
 from unittest.mock import Mock
 
 from irrd.utils.test_utils import flatten_mock_calls
+
 from ..mirror_force_reload import set_force_reload
 
 
 def test_set_force_reload(capsys, monkeypatch):
     mock_dh = Mock()
-    monkeypatch.setattr('irrd.scripts.mirror_force_reload.DatabaseHandler', lambda: mock_dh)
+    monkeypatch.setattr("irrd.scripts.mirror_force_reload.DatabaseHandler", lambda: mock_dh)
 
-    set_force_reload('TEST')
+    set_force_reload("TEST")
     assert flatten_mock_calls(mock_dh) == [
-        ['set_force_reload', ('TEST', ), {}],
-        ['commit', (), {}],
-        ['close', (), {}]
+        ["set_force_reload", ("TEST",), {}],
+        ["commit", (), {}],
+        ["close", (), {}],
     ]
     assert not capsys.readouterr().out
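
flatten_mock_calls, used throughout these tests, lives in irrd.utils.test_utils and is not part of this diff; judging from the assertions, it reduces each recorded call to a [name, args, kwargs] triple. A plausible minimal equivalent, as a sketch rather than the project's actual implementation:

from unittest.mock import Mock


def flatten_mock_calls_sketch(mock: Mock) -> list:
    # Each entry in mock.mock_calls unpacks to (method name, args, kwargs).
    return [[name, args, kwargs] for name, args, kwargs in mock.mock_calls]


mock = Mock()
mock.commit()
mock.close()
assert flatten_mock_calls_sketch(mock) == [["commit", (), {}], ["close", (), {}]]
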
diff --git a/irrd/scripts/tests/test_rpsl_read.py b/irrd/scripts/tests/test_rpsl_read.py
index 2d21bf2a5..e97d327b9 100644
--- a/irrd/scripts/tests/test_rpsl_read.py
+++ b/irrd/scripts/tests/test_rpsl_read.py
@@ -42,39 +42,39 @@
 def test_rpsl_read(capsys, tmpdir, monkeypatch):
     mock_database_handler = Mock()
-    monkeypatch.setattr('irrd.scripts.rpsl_read.DatabaseHandler', lambda: mock_database_handler)
+    monkeypatch.setattr("irrd.scripts.rpsl_read.DatabaseHandler", lambda: mock_database_handler)
 
-    tmp_file = tmpdir + '/rpsl_parse_test.rpsl'
-    fh = open(tmp_file, 'w')
+    tmp_file = tmpdir + "/rpsl_parse_test.rpsl"
+    fh = open(tmp_file, "w")
     fh.write(TEST_DATA)
     fh.close()
 
     RPSLParse().main(filename=tmp_file, strict_validation=True, database=True)
     captured = capsys.readouterr().out
-    assert 'ERROR: Unrecognised attribute unknown-obj on object as-block' in captured
-    assert 'INFO: AS range AS65536 - as065538 was reformatted as AS65536 - AS65538' in captured
-    assert 'Processed 3 objects, 1 with errors' in captured
-    assert 'Ignored 1 objects due to unknown object classes: foo-block' in captured
+    assert "ERROR: Unrecognised attribute unknown-obj on object as-block" in captured
+    assert "INFO: AS range AS65536 - as065538 was reformatted as AS65536 - AS65538" in captured
+    assert "Processed 3 objects, 1 with errors" in captured
+    assert "Ignored 1 objects due to unknown object classes: foo-block" in captured
 
-    assert mock_database_handler.mock_calls[0][0] == 'disable_journaling'
-    assert mock_database_handler.mock_calls[1][0] == 'upsert_rpsl_object'
-    assert mock_database_handler.mock_calls[1][1][0].pk() == 'AS65536 - AS65538'
-    assert mock_database_handler.mock_calls[2][0] == 'commit'
+    assert mock_database_handler.mock_calls[0][0] == "disable_journaling"
+    assert mock_database_handler.mock_calls[1][0] == "upsert_rpsl_object"
+    assert mock_database_handler.mock_calls[1][1][0].pk() == "AS65536 - AS65538"
+    assert mock_database_handler.mock_calls[2][0] == "commit"
     mock_database_handler.reset_mock()
 
     RPSLParse().main(filename=tmp_file, strict_validation=False, database=True)
     captured = capsys.readouterr().out
-    assert 'ERROR: Unrecognised attribute unknown-obj on object as-block' not in captured
-    assert 'INFO: AS range AS65536 - as065538 was reformatted as AS65536 - AS65538' in captured
-    assert 'Processed 3 objects, 0 with errors' in captured
-    assert 'Ignored 1 objects due to unknown object classes: foo-block' in captured
+    assert "ERROR: Unrecognised attribute unknown-obj on object as-block" not in captured
+    assert "INFO: AS range AS65536 - as065538 was reformatted as AS65536 - AS65538" in captured
+    assert "Processed 3 objects, 0 with errors" in captured
+    assert "Ignored 1 objects due to unknown object classes: foo-block" in captured
 
-    assert mock_database_handler.mock_calls[0][0] == 'disable_journaling'
-    assert mock_database_handler.mock_calls[1][0] == 'upsert_rpsl_object'
-    assert mock_database_handler.mock_calls[1][1][0].pk() == 'AS65536 - AS65538'
-    assert mock_database_handler.mock_calls[2][0] == 'upsert_rpsl_object'
-    assert mock_database_handler.mock_calls[2][1][0].pk() == 'AS65536 - AS65538'
-    assert mock_database_handler.mock_calls[3][0] == 'commit'
+    assert mock_database_handler.mock_calls[0][0] == "disable_journaling"
+    assert mock_database_handler.mock_calls[1][0] == "upsert_rpsl_object"
+    assert mock_database_handler.mock_calls[1][1][0].pk() == "AS65536 - AS65538"
+    assert mock_database_handler.mock_calls[2][0] == "upsert_rpsl_object"
+    assert mock_database_handler.mock_calls[2][1][0].pk() == "AS65536 - AS65538"
+    assert mock_database_handler.mock_calls[3][0] == "commit"
     mock_database_handler.reset_mock()
 
     RPSLParse().main(filename=tmp_file, strict_validation=False, database=False)
diff --git a/irrd/scripts/tests/test_set_last_modified_auth.py b/irrd/scripts/tests/test_set_last_modified_auth.py
index 53ebb9978..7220088f4 100644
--- a/irrd/scripts/tests/test_set_last_modified_auth.py
+++ b/irrd/scripts/tests/test_set_last_modified_auth.py
@@ -6,43 +6,43 @@
 from irrd.utils.rpsl_samples import SAMPLE_RTR_SET
 from irrd.utils.test_utils import flatten_mock_calls
+
 from ..set_last_modified_auth import set_last_modified
 
 
 def test_set_last_modified(capsys, monkeypatch, config_override):
-    config_override({
-        'sources': {
-            'TEST': {'authoritative': True},
-            'TEST2': {},
+    config_override(
+        {
+            "sources": {
+                "TEST": {"authoritative": True},
+                "TEST2": {},
+            }
         }
-    })
+    )
 
     mock_dh = Mock()
-    monkeypatch.setattr('irrd.scripts.set_last_modified_auth.DatabaseHandler', lambda: mock_dh)
+    monkeypatch.setattr("irrd.scripts.set_last_modified_auth.DatabaseHandler", lambda: mock_dh)
     mock_dq = Mock()
-    monkeypatch.setattr('irrd.scripts.set_last_modified_auth.RPSLDatabaseQuery', lambda column_names, enable_ordering: mock_dq)
+    monkeypatch.setattr(
+        "irrd.scripts.set_last_modified_auth.RPSLDatabaseQuery", lambda column_names, enable_ordering: mock_dq
+    )
 
     object_pk = uuid.uuid4()
     mock_query_result = [
         {
-            'pk': object_pk,
-            'object_text': SAMPLE_RTR_SET + 'last-modified: old\n',
-            'updated': datetime.datetime(2020, 1, 1, tzinfo=timezone('UTC')),
+            "pk": object_pk,
+            "object_text": SAMPLE_RTR_SET + "last-modified: old\n",
+            "updated": datetime.datetime(2020, 1, 1, tzinfo=timezone("UTC")),
         },
     ]
     mock_dh.execute_query = lambda query: mock_query_result
 
     set_last_modified()
 
-    assert flatten_mock_calls(mock_dq) == [
-        ['sources', (['TEST'],), {}]
-    ]
-    assert mock_dh.mock_calls[0][0] == 'execute_statement'
+    assert flatten_mock_calls(mock_dq) == [["sources", (["TEST"],), {}]]
+    assert mock_dh.mock_calls[0][0] == "execute_statement"
     statement = mock_dh.mock_calls[0][1][0]
-    new_text = statement.parameters['object_text']
-    assert new_text == SAMPLE_RTR_SET + 'last-modified: 2020-01-01T00:00:00Z\n'
+    new_text = statement.parameters["object_text"]
+    assert new_text == SAMPLE_RTR_SET + "last-modified: 2020-01-01T00:00:00Z\n"
 
-    assert flatten_mock_calls(mock_dh)[1:] == [
-        ['commit', (), {}],
-        ['close', (), {}]
-    ]
+    assert flatten_mock_calls(mock_dh)[1:] == [["commit", (), {}], ["close", (), {}]]
     assert capsys.readouterr().out == "Updating 1 objects in sources ['TEST']\n"
diff --git a/irrd/scripts/tests/test_submit_email.py b/irrd/scripts/tests/test_submit_email.py
index 38f325597..9002e75a9 100644
--- a/irrd/scripts/tests/test_submit_email.py
+++ b/irrd/scripts/tests/test_submit_email.py
@@ -5,19 +5,19 @@
 
 def test_submit_email_success(capsys, monkeypatch):
     mock_handle_email = Mock()
-    monkeypatch.setattr('irrd.scripts.submit_email.handle_email_submission', lambda data: mock_handle_email)
-    mock_handle_email.user_report = lambda: 'output'
+    monkeypatch.setattr("irrd.scripts.submit_email.handle_email_submission", lambda data: mock_handle_email)
+    mock_handle_email.user_report = lambda: "output"
 
-    run('test input')
+    run("test input")
 
 
 def test_submit_email_fail(capsys, monkeypatch, caplog):
-    mock_handle_email = Mock(side_effect=Exception('expected-test-error'))
-    monkeypatch.setattr('irrd.scripts.submit_email.handle_email_submission', mock_handle_email)
+    mock_handle_email = Mock(side_effect=Exception("expected-test-error"))
+    monkeypatch.setattr("irrd.scripts.submit_email.handle_email_submission", mock_handle_email)
 
-    run('test input')
+    run("test input")
 
-    assert 'expected-test-error' in caplog.text
+    assert "expected-test-error" in caplog.text
     stdout = capsys.readouterr().out
-    assert 'An internal error occurred' in stdout
-    assert 'expected-test-error' not in stdout
+    assert "An internal error occurred" in stdout
+    assert "expected-test-error" not in stdout
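
test_submit_email_fail checks a deliberate split: the exception detail goes to the log, while stdout only carries a generic message. The same contract in a self-contained pytest sketch:

import logging

logger = logging.getLogger(__name__)


def run_sketch():
    # Mirrors run() in submit_email.py: log the detail, print a generic message.
    try:
        raise Exception("expected-test-error")
    except Exception as exc:
        logger.critical("An exception occurred while processing an email", exc_info=exc)
        print("An internal error occurred while processing this email.")


def test_run_sketch(capsys, caplog):
    run_sketch()
    assert "expected-test-error" in caplog.text  # traceback lands in the log
    assert "expected-test-error" not in capsys.readouterr().out
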
diff --git a/irrd/scripts/tests/test_submit_update.py b/irrd/scripts/tests/test_submit_update.py
index 561001e4d..c7843d2b8 100644
--- a/irrd/scripts/tests/test_submit_update.py
+++ b/irrd/scripts/tests/test_submit_update.py
@@ -1,15 +1,15 @@
 from unittest.mock import Mock
 
-from ..submit_changes import main
 from ...updates.handler import ChangeSubmissionHandler
+from ..submit_changes import main
 
 
 def test_submit_changes(capsys, monkeypatch):
     mock_update_handler = Mock(spec=ChangeSubmissionHandler)
-    monkeypatch.setattr('irrd.scripts.submit_changes.ChangeSubmissionHandler', lambda: mock_update_handler)
+    monkeypatch.setattr("irrd.scripts.submit_changes.ChangeSubmissionHandler", lambda: mock_update_handler)
     mock_update_handler.load_text_blob = lambda data: mock_update_handler
-    mock_update_handler.submitter_report_human = lambda: 'output'
+    mock_update_handler.submitter_report_human = lambda: "output"
 
-    main('test input')
+    main("test input")
     captured = capsys.readouterr().out
-    assert captured == 'output\n'
+    assert captured == "output\n"
diff --git a/irrd/scripts/tests/test_update_database.py b/irrd/scripts/tests/test_update_database.py
index bff6b4330..d1ade8f91 100644
--- a/irrd/scripts/tests/test_update_database.py
+++ b/irrd/scripts/tests/test_update_database.py
@@ -1,24 +1,26 @@
 from unittest.mock import Mock
 
 from irrd.utils.test_utils import flatten_mock_calls
+
 from ..update_database import update
 
 
 def test_update_database_success(capsys, monkeypatch):
     mock_dh = Mock()
-    monkeypatch.setattr('irrd.scripts.update_database.DatabaseHandler', lambda enable_preload_update=False: mock_dh)
+    monkeypatch.setattr(
+        "irrd.scripts.update_database.DatabaseHandler", lambda enable_preload_update=False: mock_dh
+    )
     mock_roa_validator = Mock()
-    monkeypatch.setattr('irrd.scripts.update_database.BulkRouteROAValidator', lambda dh: mock_roa_validator)
+    monkeypatch.setattr("irrd.scripts.update_database.BulkRouteROAValidator", lambda dh: mock_roa_validator)
     mock_parser = Mock()
-    monkeypatch.setattr('irrd.scripts.update_database.MirrorUpdateFileImportParser', lambda *args, **kwargs: mock_parser)
+    monkeypatch.setattr(
+        "irrd.scripts.update_database.MirrorUpdateFileImportParser", lambda *args, **kwargs: mock_parser
+    )
 
     mock_parser.run_import = lambda: None
-    assert update('TEST', 'test.db') == 0
-    assert flatten_mock_calls(mock_dh) == [
-        ['commit', (), {}],
-        ['close', (), {}]
-    ]
+    assert update("TEST", "test.db") == 0
+    assert flatten_mock_calls(mock_dh) == [["commit", (), {}], ["close", (), {}]]
     # run_import() call is not included here
     assert flatten_mock_calls(mock_parser) == []
@@ -27,35 +29,38 @@ def test_update_database_success(capsys, monkeypatch):
 
 def test_update_database_import_error(capsys, monkeypatch, caplog):
     mock_dh = Mock()
-    monkeypatch.setattr('irrd.scripts.update_database.DatabaseHandler', lambda enable_preload_update=False: mock_dh)
+    monkeypatch.setattr(
+        "irrd.scripts.update_database.DatabaseHandler", lambda enable_preload_update=False: mock_dh
+    )
     mock_roa_validator = Mock()
-    monkeypatch.setattr('irrd.scripts.update_database.BulkRouteROAValidator', lambda dh: mock_roa_validator)
+    monkeypatch.setattr("irrd.scripts.update_database.BulkRouteROAValidator", lambda dh: mock_roa_validator)
     mock_parser = Mock()
-    monkeypatch.setattr('irrd.scripts.update_database.MirrorUpdateFileImportParser', lambda *args, **kwargs: mock_parser)
+    monkeypatch.setattr(
+        "irrd.scripts.update_database.MirrorUpdateFileImportParser", lambda *args, **kwargs: mock_parser
+    )
 
-    mock_parser.run_import = lambda: 'object-parsing-error'
+    mock_parser.run_import = lambda: "object-parsing-error"
 
-    assert update('TEST', 'test.db') == 1
-    assert flatten_mock_calls(mock_dh) == [
-        ['rollback', (), {}],
-        ['close', (), {}]
-    ]
+    assert update("TEST", "test.db") == 1
+    assert flatten_mock_calls(mock_dh) == [["rollback", (), {}], ["close", (), {}]]
     # run_import() call is not included here
     assert flatten_mock_calls(mock_parser) == []
 
-    assert 'object-parsing-error' not in caplog.text
+    assert "object-parsing-error" not in caplog.text
     stdout = capsys.readouterr().out
-    assert 'Error occurred while processing object:\nobject-parsing-error' in stdout
+    assert "Error occurred while processing object:\nobject-parsing-error" in stdout
 
 
 def test_reject_import_source_set(capsys, config_override):
-    config_override({
-        'sources': {
-            'TEST': {'import_source': 'import-url'}
-        },
-    })
-    assert update('TEST', 'test.db') == 2
+    config_override(
+        {
+            "sources": {"TEST": {"import_source": "import-url"}},
+        }
+    )
+    assert update("TEST", "test.db") == 2
     stdout = capsys.readouterr().out
-    assert 'Error: to use this command, import_source and import_serial_' \
-           'source for source TEST must not be set.' in stdout
+    assert (
+        "Error: to use this command, import_source and import_serial_source for source TEST must not be set."
+        in stdout
+    )
diff --git a/irrd/scripts/update_database.py b/irrd/scripts/update_database.py
index d3366ece9..d546a81b3 100644
--- a/irrd/scripts/update_database.py
+++ b/irrd/scripts/update_database.py
@@ -3,10 +3,8 @@
 import argparse
 import logging
 import sys
-
 from pathlib import Path
-
 """
 Update a database based on a RPSL file.
 """
@@ -14,25 +12,30 @@
 logger = logging.getLogger(__name__)
 sys.path.append(str(Path(__file__).resolve().parents[2]))
+from irrd.conf import CONFIG_PATH_DEFAULT, config_init, get_setting
+from irrd.mirroring.parsers import MirrorUpdateFileImportParser
 from irrd.rpki.validators import BulkRouteROAValidator
 from irrd.storage.database_handler import DatabaseHandler
-from irrd.mirroring.parsers import MirrorUpdateFileImportParser
-from irrd.conf import config_init, CONFIG_PATH_DEFAULT, get_setting
+
 
 def update(source, filename) -> int:
-    if any([
-        get_setting(f'sources.{source}.import_source'),
-        get_setting(f'sources.{source}.import_serial_source')
-    ]):
-        print(f'Error: to use this command, import_source and import_serial_source '
-              f'for source {source} must not be set.')
+    if any(
+        [
+            get_setting(f"sources.{source}.import_source"),
+            get_setting(f"sources.{source}.import_serial_source"),
+        ]
+    ):
+        print(
+            "Error: to use this command, import_source and import_serial_source "
+            f"for source {source} must not be set."
+        )
         return 2
 
     dh = DatabaseHandler()
     roa_validator = BulkRouteROAValidator(dh)
     parser = MirrorUpdateFileImportParser(
-        source, filename, database_handler=dh,
-        direct_error_return=True, roa_validator=roa_validator)
+        source, filename, database_handler=dh, direct_error_return=True, roa_validator=roa_validator
+    )
     error = parser.run_import()
     if error:
         dh.rollback()
@@ -40,7 +43,7 @@ def update(source, filename) -> int:
     dh.commit()
     dh.close()
     if error:
-        print(f'Error occurred while processing object:\n{error}')
+        print(f"Error occurred while processing object:\n{error}")
         return 1
     return 0
@@ -48,21 +51,25 @@ def update(source, filename) -> int:
 def main():  # pragma: no cover
     description = """Update a database based on a RPSL file."""
     parser = argparse.ArgumentParser(description=description)
-    parser.add_argument('--config', dest='config_file_path', type=str,
-                        help=f'use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})')
-    parser.add_argument('--source', dest='source', type=str, required=True,
-                        help=f'name of the source, e.g. NTTCOM')
-    parser.add_argument('input_file', type=str,
-                        help='the name of a file to read')
+    parser.add_argument(
+        "--config",
+        dest="config_file_path",
+        type=str,
+        help=f"use a different IRRd config file (default: {CONFIG_PATH_DEFAULT})",
+    )
+    parser.add_argument(
+        "--source", dest="source", type=str, required=True, help=f"name of the source, e.g. NTTCOM"
+    )
+    parser.add_argument("input_file", type=str, help="the name of a file to read")
     args = parser.parse_args()
 
     config_init(args.config_file_path)
-    if get_setting('database_readonly'):
-        print('Unable to run, because database_readonly is set')
+    if get_setting("database_readonly"):
+        print("Unable to run, because database_readonly is set")
         sys.exit(-1)
     sys.exit(update(args.source, args.input_file))
 
 
-if __name__ == '__main__':  # pragma: no cover
+if __name__ == "__main__":  # pragma: no cover
     main()
diff --git a/irrd/server/access_check.py b/irrd/server/access_check.py
index 2df905feb..c23492763 100644
--- a/irrd/server/access_check.py
+++ b/irrd/server/access_check.py
@@ -26,8 +26,7 @@ def is_client_permitted(ip: str, access_list_setting: str, default_deny=True, lo
         client_ip = IP(ip)
     except (ValueError, AttributeError) as e:
         if log:
-            logger.error(f'Rejecting request as client IP could not be read from '
-                         f'{ip}: {e}')
+            logger.error(f"Rejecting request as client IP could not be read from {ip}: {e}")
         return False
 
     if client_ip and client_ip.version() == 6:
@@ -37,17 +36,17 @@ def is_client_permitted(ip: str, access_list_setting: str, default_deny=True, lo
             pass
 
     access_list_name = get_setting(access_list_setting)
-    access_list = get_setting(f'access_lists.{access_list_name}')
+    access_list = get_setting(f"access_lists.{access_list_name}")
     if not access_list_name or not access_list:
         if default_deny:
             if log:
-                logger.info(f'Rejecting request, access list empty or undefined: {client_ip}')
+                logger.info(f"Rejecting request, access list empty or undefined: {client_ip}")
             return False
         else:
             return True
 
     allowed = bypass_auth or any([client_ip in IP(allowed) for allowed in access_list])
     if not allowed and log:
-        logger.info(f'Rejecting request, IP not in access list {access_list_name}: {client_ip}')
+        logger.info(f"Rejecting request, IP not in access list {access_list_name}: {client_ip}")
     return allowed
diff --git a/irrd/server/graphql/__init__.py b/irrd/server/graphql/__init__.py
index 25fccb063..397a72210 100644
--- a/irrd/server/graphql/__init__.py
+++ b/irrd/server/graphql/__init__.py
@@ -1 +1 @@
-ENV_UVICORN_WORKER_CONFIG_PATH = 'IRRD_UVICORN_WORKER_CONFIG_PATH' +ENV_UVICORN_WORKER_CONFIG_PATH = "IRRD_UVICORN_WORKER_CONFIG_PATH" diff --git a/irrd/server/graphql/extensions.py b/irrd/server/graphql/extensions.py index 197786fa3..65f3d70d1 100644 --- a/irrd/server/graphql/extensions.py +++ b/irrd/server/graphql/extensions.py @@ -15,6 +15,7 @@ class QueryMetadataExtension(Extension): - Returns SQL queries if SQL trace was enabled - Logs the query and execution time """ + def __init__(self): self.start_timestamp = None self.end_timestamp = None @@ -25,21 +26,21 @@ def request_started(self, context): def format(self, context): data = {} if self.start_timestamp: - data['execution'] = time.perf_counter() - self.start_timestamp - if 'sql_queries' in context: - data['sql_query_count'] = len(context['sql_queries']) - data['sql_queries'] = context['sql_queries'] + data["execution"] = time.perf_counter() - self.start_timestamp + if "sql_queries" in context: + data["sql_query_count"] = len(context["sql_queries"]) + data["sql_queries"] = context["sql_queries"] - query = context['request']._json - if context['request']._json.get('operationName') != 'IntrospectionQuery': + query = context["request"]._json + if context["request"]._json.get("operationName") != "IntrospectionQuery": # Reformat the query to make it fit neatly on a single log line - query['query'] = query['query'].replace(' ', '').replace('\n', ' ').replace('\t', '') - client = context['request'].client.host + query["query"] = query["query"].replace(" ", "").replace("\n", " ").replace("\t", "") + client = context["request"].client.host logger.info(f'{client} ran query in {data.get("execution")}s: {query}') return data -def error_formatter(error: GraphQLError, debug: bool=False): +def error_formatter(error: GraphQLError, debug: bool = False): """ Custom Ariadne error formatter. A generic text is used if the server is not in debug mode and the original error is a diff --git a/irrd/server/graphql/resolvers.py b/irrd/server/graphql/resolvers.py index 97b877795..4bf687fa6 100644 --- a/irrd/server/graphql/resolvers.py +++ b/irrd/server/graphql/resolvers.py @@ -1,21 +1,22 @@ from collections import OrderedDict -from typing import Set, Dict, Optional, List +from typing import Dict, List, Optional, Set import ariadne import graphql +from graphql import GraphQLError, GraphQLResolveInfo from IPy import IP -from graphql import GraphQLResolveInfo, GraphQLError -from irrd.conf import get_setting, RPKI_IRR_PSEUDO_SOURCE -from irrd.rpki.status import RPKIStatus +from irrd.conf import RPKI_IRR_PSEUDO_SOURCE, get_setting from irrd.routepref.status import RoutePreferenceStatus +from irrd.rpki.status import RPKIStatus from irrd.rpsl.rpsl_objects import OBJECT_CLASS_MAPPING, lookup_field_names from irrd.scopefilter.status import ScopeFilterStatus from irrd.server.access_check import is_client_permitted -from irrd.storage.queries import RPSLDatabaseQuery, RPSLDatabaseJournalQuery -from irrd.utils.text import snake_to_camel_case, remove_auth_hashes -from .schema_generator import SchemaGenerator +from irrd.storage.queries import RPSLDatabaseJournalQuery, RPSLDatabaseQuery +from irrd.utils.text import remove_auth_hashes, snake_to_camel_case + from ..query_resolver import QueryResolver +from .schema_generator import SchemaGenerator """ Resolvers resolve GraphQL queries, usually by translating them @@ -33,7 +34,7 @@ def resolve_rpsl_object_type(obj: Dict[str, str], *_) -> str: Find the GraphQL name for an object given its object class. 
(GraphQL names match RPSL class names.) """ - return OBJECT_CLASS_MAPPING[obj.get('objectClass', obj.get('object_class', ''))].__name__ + return OBJECT_CLASS_MAPPING[obj.get("objectClass", obj.get("object_class", ""))].__name__ @ariadne.convert_kwargs_to_snake_case @@ -44,70 +45,72 @@ def resolve_rpsl_objects(_, info: GraphQLResolveInfo, **kwargs): database query. """ low_specificity_kwargs = { - 'object_class', 'rpki_status', 'scope_filter_status', 'route_preference_status', - 'sources', 'sql_trace' + "object_class", + "rpki_status", + "scope_filter_status", + "route_preference_status", + "sources", + "sql_trace", } # A query is sufficiently specific if it has other fields than listed above, # except that rpki_status is sufficient if it is exclusively selecting on # valid or invalid. - low_specificity = all([ - not (set(kwargs.keys()) - low_specificity_kwargs), - kwargs.get('rpki_status', []) not in [[RPKIStatus.valid], [RPKIStatus.invalid]], - ]) + low_specificity = all( + [ + not (set(kwargs.keys()) - low_specificity_kwargs), + kwargs.get("rpki_status", []) not in [[RPKIStatus.valid], [RPKIStatus.invalid]], + ] + ) if low_specificity: - raise ValueError('Your query must be more specific.') + raise ValueError("Your query must be more specific.") - if kwargs.get('sql_trace'): - info.context['sql_trace'] = True + if kwargs.get("sql_trace"): + info.context["sql_trace"] = True query = RPSLDatabaseQuery( - column_names=_columns_for_graphql_selection(info), - ordered_by_sources=False, - enable_ordering=False + column_names=_columns_for_graphql_selection(info), ordered_by_sources=False, enable_ordering=False ) - if 'record_limit' in kwargs: - query.limit(kwargs['record_limit']) - if 'rpsl_pk' in kwargs: - query.rpsl_pks(kwargs['rpsl_pk']) - if 'object_class' in kwargs: - query.object_classes(kwargs['object_class']) - if 'asn' in kwargs: - query.asns_first(kwargs['asn']) - if 'text_search' in kwargs: - query.text_search(kwargs['text_search']) - if 'rpki_status' in kwargs: - query.rpki_status(kwargs['rpki_status']) + if "record_limit" in kwargs: + query.limit(kwargs["record_limit"]) + if "rpsl_pk" in kwargs: + query.rpsl_pks(kwargs["rpsl_pk"]) + if "object_class" in kwargs: + query.object_classes(kwargs["object_class"]) + if "asn" in kwargs: + query.asns_first(kwargs["asn"]) + if "text_search" in kwargs: + query.text_search(kwargs["text_search"]) + if "rpki_status" in kwargs: + query.rpki_status(kwargs["rpki_status"]) else: query.rpki_status([RPKIStatus.not_found, RPKIStatus.valid]) - if 'scope_filter_status' in kwargs: - query.scopefilter_status(kwargs['scope_filter_status']) + if "scope_filter_status" in kwargs: + query.scopefilter_status(kwargs["scope_filter_status"]) else: query.scopefilter_status([ScopeFilterStatus.in_scope]) - if 'route_preference_status' in kwargs: - query.route_preference_status(kwargs['route_preference_status']) + if "route_preference_status" in kwargs: + query.route_preference_status(kwargs["route_preference_status"]) else: query.route_preference_status([RoutePreferenceStatus.visible]) - all_valid_sources = set(get_setting('sources', {}).keys()) - if get_setting('rpki.roa_source'): + all_valid_sources = set(get_setting("sources", {}).keys()) + if get_setting("rpki.roa_source"): all_valid_sources.add(RPKI_IRR_PSEUDO_SOURCE) - sources_default = set(get_setting('sources_default', [])) + sources_default = set(get_setting("sources_default", [])) - if 'sources' in kwargs: - query.sources(kwargs['sources']) + if "sources" in kwargs: + query.sources(kwargs["sources"]) elif 
sources_default and sources_default != all_valid_sources: query.sources(list(sources_default)) # All other parameters are generic lookup fields, like `members` for attr, value in kwargs.items(): - attr = attr.replace('_', '-') + attr = attr.replace("_", "-") if attr in lookup_fields: query.lookup_attrs_in([attr], value) - ip_filters = [ - 'ip_exact', 'ip_less_specific', 'ip_more_specific', 'ip_less_specific_one_level', 'ip_any' - ] + ip_filters = ["ip_exact", "ip_less_specific", "ip_more_specific", "ip_less_specific_one_level", "ip_any"] for ip_filter in ip_filters: if ip_filter in kwargs: getattr(query, ip_filter)(IP(kwargs[ip_filter])) @@ -117,45 +120,47 @@ def resolve_rpsl_objects(_, info: GraphQLResolveInfo, **kwargs): def resolve_rpsl_object_mnt_by_objs(rpsl_object, info: GraphQLResolveInfo): """Resolve mntByObjs on RPSL objects""" - return _resolve_subquery(rpsl_object, info, ['mntner'], pk_field='mntBy') + return _resolve_subquery(rpsl_object, info, ["mntner"], pk_field="mntBy") def resolve_rpsl_object_adminc_objs(rpsl_object, info: GraphQLResolveInfo): """Resolve adminCObjs on RPSL objects""" - return _resolve_subquery(rpsl_object, info, ['role', 'person'], pk_field='adminC') + return _resolve_subquery(rpsl_object, info, ["role", "person"], pk_field="adminC") def resolve_rpsl_object_techc_objs(rpsl_object, info: GraphQLResolveInfo): """Resolve techCObjs on RPSL objects""" - return _resolve_subquery(rpsl_object, info, ['role', 'person'], pk_field='techC') + return _resolve_subquery(rpsl_object, info, ["role", "person"], pk_field="techC") def resolve_rpsl_object_members_by_ref_objs(rpsl_object, info: GraphQLResolveInfo): """Resolve mbrsByRefObjs on RPSL objects""" - return _resolve_subquery(rpsl_object, info, ['mntner'], pk_field='mbrsByRef') + return _resolve_subquery(rpsl_object, info, ["mntner"], pk_field="mbrsByRef") def resolve_rpsl_object_member_of_objs(rpsl_object, info: GraphQLResolveInfo): """Resolve memberOfObjs on RPSL objects""" - object_klass = OBJECT_CLASS_MAPPING[rpsl_object['objectClass']] - sub_object_classes = object_klass.fields['member-of'].referring # type: ignore - return _resolve_subquery(rpsl_object, info, sub_object_classes, pk_field='memberOf') + object_klass = OBJECT_CLASS_MAPPING[rpsl_object["objectClass"]] + sub_object_classes = object_klass.fields["member-of"].referring # type: ignore + return _resolve_subquery(rpsl_object, info, sub_object_classes, pk_field="memberOf") def resolve_rpsl_object_members_objs(rpsl_object, info: GraphQLResolveInfo): """Resolve membersObjs on RPSL objects""" - object_klass = OBJECT_CLASS_MAPPING[rpsl_object['objectClass']] - sub_object_classes = object_klass.fields['members'].referring # type: ignore + object_klass = OBJECT_CLASS_MAPPING[rpsl_object["objectClass"]] + sub_object_classes = object_klass.fields["members"].referring # type: ignore # The reference to an aut-num should not be fully resolved, as the # reference is very weak. 
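     # (Illustrative aside, not part of this change: an as-set's members
     # attribute may mix AS numbers such as AS65530 with other set names,
     # and a rtr-set's members may name inet-rtr objects. Dropping aut-num
     # and inet-rtr below means only real set objects are resolved.)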
- if 'aut-num' in sub_object_classes: - sub_object_classes.remove('aut-num') - if 'inet-rtr' in sub_object_classes: - sub_object_classes.remove('inet-rtr') - return _resolve_subquery(rpsl_object, info, sub_object_classes, 'members', sticky_source=False) + if "aut-num" in sub_object_classes: + sub_object_classes.remove("aut-num") + if "inet-rtr" in sub_object_classes: + sub_object_classes.remove("inet-rtr") + return _resolve_subquery(rpsl_object, info, sub_object_classes, "members", sticky_source=False) -def _resolve_subquery(rpsl_object, info: GraphQLResolveInfo, object_classes: List[str], pk_field: str, sticky_source=True): +def _resolve_subquery( + rpsl_object, info: GraphQLResolveInfo, object_classes: List[str], pk_field: str, sticky_source=True +): """ Resolve a subquery, like techCobjs, on an RPSL object, considering a number of object classes, extracting the PK from pk_field. @@ -167,13 +172,11 @@ def _resolve_subquery(rpsl_object, info: GraphQLResolveInfo, object_classes: Lis if not isinstance(pks, list): pks = [pks] query = RPSLDatabaseQuery( - column_names=_columns_for_graphql_selection(info), - ordered_by_sources=False, - enable_ordering=False + column_names=_columns_for_graphql_selection(info), ordered_by_sources=False, enable_ordering=False ) query.object_classes(object_classes).rpsl_pks(pks) if sticky_source: - query.sources([rpsl_object['source']]) + query.sources([rpsl_object["source"]]) return _rpsl_db_query_to_graphql_out(query, info) @@ -181,20 +184,20 @@ def resolve_rpsl_object_journal(rpsl_object, info: GraphQLResolveInfo): """ Resolve a journal subquery on an RPSL object. """ - database_handler = info.context['request'].app.state.database_handler + database_handler = info.context["request"].app.state.database_handler access_list = f"sources.{rpsl_object['source']}.nrtm_access_list" - if not is_client_permitted(info.context['request'].client.host, access_list): + if not is_client_permitted(info.context["request"].client.host, access_list): raise GraphQLError(f"Access to journal denied for source {rpsl_object['source']}") query = RPSLDatabaseJournalQuery() - query.sources([rpsl_object['source']]).rpsl_pk(rpsl_object['rpslPk']) + query.sources([rpsl_object["source"]]).rpsl_pk(rpsl_object["rpslPk"]) for row in database_handler.execute_query(query, refresh_on_error=True): response = {snake_to_camel_case(k): v for k, v in row.items()} - response['operation'] = response['operation'].name - if response['origin']: - response['origin'] = response['origin'].name - if response['objectText']: - response['objectText'] = remove_auth_hashes(response['objectText']) + response["operation"] = response["operation"].name + if response["origin"]: + response["origin"] = response["origin"].name + if response["objectText"]: + response["objectText"] = remove_auth_hashes(response["objectText"]) yield response @@ -208,69 +211,77 @@ def _rpsl_db_query_to_graphql_out(query: RPSLDatabaseQuery, info: GraphQLResolve - Adding the asn and prefix fields if applicable - Ensuring the right fields are returned as a list of strings or a string """ - database_handler = info.context['request'].app.state.database_handler - if info.context.get('sql_trace'): - if 'sql_queries' not in info.context: - info.context['sql_queries'] = [repr(query)] + database_handler = info.context["request"].app.state.database_handler + if info.context.get("sql_trace"): + if "sql_queries" not in info.context: + info.context["sql_queries"] = [repr(query)] else: - info.context['sql_queries'].append(repr(query)) + 
info.context["sql_queries"].append(repr(query)) for row in database_handler.execute_query(query, refresh_on_error=True): - graphql_result = {snake_to_camel_case(k): v for k, v in row.items() if k != 'parsed_data'} - if 'object_text' in row: - graphql_result['objectText'] = remove_auth_hashes(row['object_text']) - if row.get('ip_first') is not None and row.get('prefix_length'): - graphql_result['prefix'] = row['ip_first'] + '/' + str(row['prefix_length']) - if row.get('asn_first') is not None and row.get('asn_first') == row.get('asn_last'): - graphql_result['asn'] = row['asn_first'] + graphql_result = {snake_to_camel_case(k): v for k, v in row.items() if k != "parsed_data"} + if "object_text" in row: + graphql_result["objectText"] = remove_auth_hashes(row["object_text"]) + if row.get("ip_first") is not None and row.get("prefix_length"): + graphql_result["prefix"] = row["ip_first"] + "/" + str(row["prefix_length"]) + if row.get("asn_first") is not None and row.get("asn_first") == row.get("asn_last"): + graphql_result["asn"] = row["asn_first"] object_type = resolve_rpsl_object_type(row) - for key, value in row.get('parsed_data', dict()).items(): - if key == 'auth': + for key, value in row.get("parsed_data", dict()).items(): + if key == "auth": value = [remove_auth_hashes(v) for v in value] graphql_type = schema.graphql_types[object_type][key] - if graphql_type == 'String' and isinstance(value, list): - value = '\n'.join(value) + if graphql_type == "String" and isinstance(value, list): + value = "\n".join(value) graphql_result[snake_to_camel_case(key)] = value yield graphql_result @ariadne.convert_kwargs_to_snake_case -def resolve_database_status(_, info: GraphQLResolveInfo, sources: Optional[List[str]]=None): +def resolve_database_status(_, info: GraphQLResolveInfo, sources: Optional[List[str]] = None): """Resolve a databaseStatus query""" query_resolver = QueryResolver( - info.context['request'].app.state.preloader, - info.context['request'].app.state.database_handler + info.context["request"].app.state.preloader, info.context["request"].app.state.database_handler ) for name, data in query_resolver.database_status(sources=sources).items(): camel_case_data = OrderedDict(data) - camel_case_data['source'] = name + camel_case_data["source"] = name for key, value in data.items(): camel_case_data[snake_to_camel_case(key)] = value yield camel_case_data @ariadne.convert_kwargs_to_snake_case -def resolve_asn_prefixes(_, info: GraphQLResolveInfo, asns: List[int], ip_version: Optional[int]=None, sources: Optional[List[str]]=None): +def resolve_asn_prefixes( + _, + info: GraphQLResolveInfo, + asns: List[int], + ip_version: Optional[int] = None, + sources: Optional[List[str]] = None, +): """Resolve an asnPrefixes query""" query_resolver = QueryResolver( - info.context['request'].app.state.preloader, - info.context['request'].app.state.database_handler + info.context["request"].app.state.preloader, info.context["request"].app.state.database_handler ) query_resolver.set_query_sources(sources) for asn in asns: - yield dict( - asn=asn, - prefixes=list(query_resolver.routes_for_origin(f'AS{asn}', ip_version)) - ) + yield dict(asn=asn, prefixes=list(query_resolver.routes_for_origin(f"AS{asn}", ip_version))) @ariadne.convert_kwargs_to_snake_case -def resolve_as_set_prefixes(_, info: GraphQLResolveInfo, set_names: List[str], sources: Optional[List[str]]=None, ip_version: Optional[int]=None, exclude_sets: Optional[List[str]]=None, sql_trace: bool=False): +def resolve_as_set_prefixes( + _, + info: 
GraphQLResolveInfo,
+    set_names: List[str],
+    sources: Optional[List[str]] = None,
+    ip_version: Optional[int] = None,
+    exclude_sets: Optional[List[str]] = None,
+    sql_trace: bool = False,
+):
     """Resolve an asSetPrefixes query"""
     query_resolver = QueryResolver(
-        info.context['request'].app.state.preloader,
-        info.context['request'].app.state.database_handler
+        info.context["request"].app.state.preloader, info.context["request"].app.state.database_handler
     )
     if sql_trace:
         query_resolver.enable_sql_trace()
@@ -281,15 +292,22 @@ def resolve_as_set_prefixes(_, info: GraphQLResolveInfo, set_names: List[str], s
         prefixes = list(query_resolver.routes_for_as_set(set_name, ip_version, exclude_sets=exclude_sets_set))
         yield dict(rpslPk=set_name, prefixes=prefixes)
     if sql_trace:
-        info.context['sql_queries'] = query_resolver.retrieve_sql_trace()
+        info.context["sql_queries"] = query_resolver.retrieve_sql_trace()


 @ariadne.convert_kwargs_to_snake_case
-def resolve_recursive_set_members(_, info: GraphQLResolveInfo, set_names: List[str], depth: int=0, sources: Optional[List[str]]=None, exclude_sets: Optional[List[str]]=None, sql_trace: bool=False):
+def resolve_recursive_set_members(
+    _,
+    info: GraphQLResolveInfo,
+    set_names: List[str],
+    depth: int = 0,
+    sources: Optional[List[str]] = None,
+    exclude_sets: Optional[List[str]] = None,
+    sql_trace: bool = False,
+):
     """Resolve a recursiveSetMembers query"""
     query_resolver = QueryResolver(
-        info.context['request'].app.state.preloader,
-        info.context['request'].app.state.database_handler
+        info.context["request"].app.state.preloader, info.context["request"].app.state.database_handler
     )
     if sql_trace:
         query_resolver.enable_sql_trace()
@@ -297,11 +315,13 @@ def resolve_recursive_set_members(_, info: GraphQLResolveInfo, set_names: List[s
     exclude_sets_set = {i.upper() for i in exclude_sets} if exclude_sets else set()
     query_resolver.set_query_sources(sources)
     for set_name in set_names_set:
-        results = query_resolver.members_for_set_per_source(set_name, exclude_sets=exclude_sets_set, depth=depth, recursive=True)
+        results = query_resolver.members_for_set_per_source(
+            set_name, exclude_sets=exclude_sets_set, depth=depth, recursive=True
+        )
         for source, members in results.items():
             yield dict(rpslPk=set_name, rootSource=source, members=members)
     if sql_trace:
-        info.context['sql_queries'] = query_resolver.retrieve_sql_trace()
+        info.context["sql_queries"] = query_resolver.retrieve_sql_trace()


 def _columns_for_graphql_selection(info: GraphQLResolveInfo) -> Set[str]:
@@ -310,19 +330,19 @@ def _columns_for_graphql_selection(info: GraphQLResolveInfo) -> Set[str]:
     columns should be retrieved.
""" # Some columns are always retrieved - columns = {'object_class', 'source', 'parsed_data', 'rpsl_pk'} + columns = {"object_class", "source", "parsed_data", "rpsl_pk"} fields = _collect_predicate_names(info.field_nodes[0].selection_set.selections) # type: ignore requested_fields = {ariadne.convert_camel_case_to_snake(f) for f in fields} for field in requested_fields: if field in RPSLDatabaseQuery().columns: columns.add(field) - if field == 'asn': - columns.add('asn_first') - columns.add('asn_last') - if field == 'prefix': - columns.add('ip_first') - columns.add('prefix_length') + if field == "asn": + columns.add("asn_first") + columns.add("asn_last") + if field == "prefix": + columns.add("ip_first") + columns.add("prefix_length") return columns diff --git a/irrd/server/graphql/schema_builder.py b/irrd/server/graphql/schema_builder.py index 87b8f58c0..65ddf5d06 100644 --- a/irrd/server/graphql/schema_builder.py +++ b/irrd/server/graphql/schema_builder.py @@ -1,17 +1,25 @@ -from IPy import IP from ariadne import make_executable_schema from asgiref.sync import sync_to_async as sta from graphql import GraphQLError +from IPy import IP -from .resolvers import (resolve_rpsl_objects, resolve_rpsl_object_type, - resolve_database_status, resolve_rpsl_object_mnt_by_objs, - resolve_rpsl_object_member_of_objs, resolve_rpsl_object_members_by_ref_objs, - resolve_rpsl_object_members_objs, resolve_rpsl_object_adminc_objs, - resolve_asn_prefixes, resolve_as_set_prefixes, - resolve_recursive_set_members, resolve_rpsl_object_techc_objs, - resolve_rpsl_object_journal) -from .schema_generator import SchemaGenerator from ...utils.text import clean_ip_value_error +from .resolvers import ( + resolve_as_set_prefixes, + resolve_asn_prefixes, + resolve_database_status, + resolve_recursive_set_members, + resolve_rpsl_object_adminc_objs, + resolve_rpsl_object_journal, + resolve_rpsl_object_member_of_objs, + resolve_rpsl_object_members_by_ref_objs, + resolve_rpsl_object_members_objs, + resolve_rpsl_object_mnt_by_objs, + resolve_rpsl_object_techc_objs, + resolve_rpsl_object_type, + resolve_rpsl_objects, +) +from .schema_generator import SchemaGenerator def build_executable_schema(): @@ -35,19 +43,19 @@ def build_executable_schema(): schema.rpsl_object_type.set_field("mntByObjs", sta(resolve_rpsl_object_mnt_by_objs, False)) schema.rpsl_object_type.set_field("journal", sta(resolve_rpsl_object_journal, False)) for object_type in schema.object_types: - if 'adminCObjs' in schema.graphql_types[object_type.name]: + if "adminCObjs" in schema.graphql_types[object_type.name]: object_type.set_field("adminCObjs", sta(resolve_rpsl_object_adminc_objs, False)) for object_type in schema.object_types: - if 'techCObjs' in schema.graphql_types[object_type.name]: + if "techCObjs" in schema.graphql_types[object_type.name]: object_type.set_field("techCObjs", sta(resolve_rpsl_object_techc_objs, False)) for object_type in schema.object_types: - if 'mbrsByRefObjs' in schema.graphql_types[object_type.name]: + if "mbrsByRefObjs" in schema.graphql_types[object_type.name]: object_type.set_field("mbrsByRefObjs", sta(resolve_rpsl_object_members_by_ref_objs, False)) for object_type in schema.object_types: - if 'memberOfObjs' in schema.graphql_types[object_type.name]: + if "memberOfObjs" in schema.graphql_types[object_type.name]: object_type.set_field("memberOfObjs", sta(resolve_rpsl_object_member_of_objs, False)) for object_type in schema.object_types: - if 'membersObjs' in schema.graphql_types[object_type.name]: + if "membersObjs" in 
schema.graphql_types[object_type.name]: object_type.set_field("membersObjs", sta(resolve_rpsl_object_members_objs, False)) @schema.asn_scalar_type.value_parser @@ -55,13 +63,13 @@ def parse_asn_scalar(value): try: return int(value) except ValueError: - raise GraphQLError(f'Invalid ASN: {value}; must be numeric') + raise GraphQLError(f"Invalid ASN: {value}; must be numeric") @schema.ip_scalar_type.value_parser def parse_ip_scalar(value): try: return IP(value) except ValueError as ve: - raise GraphQLError(f'Invalid IP: {value}: {clean_ip_value_error(ve)}') + raise GraphQLError(f"Invalid IP: {value}: {clean_ip_value_error(ve)}") return make_executable_schema(schema.type_defs, *schema.object_types) diff --git a/irrd/server/graphql/schema_generator.py b/irrd/server/graphql/schema_generator.py index 97b221bce..41a74801c 100644 --- a/irrd/server/graphql/schema_generator.py +++ b/irrd/server/graphql/schema_generator.py @@ -1,13 +1,19 @@ from collections import OrderedDict, defaultdict -from typing import Optional, Dict, Tuple, List +from typing import Dict, List, Optional, Tuple import ariadne from irrd.routepref.status import RoutePreferenceStatus from irrd.rpki.status import RPKIStatus -from irrd.rpsl.fields import RPSLFieldListMixin, RPSLTextField, RPSLReferenceField -from irrd.rpsl.rpsl_objects import (lookup_field_names, OBJECT_CLASS_MAPPING, RPSLAutNum, - RPSLInetRtr, RPSLPerson, RPSLRole) +from irrd.rpsl.fields import RPSLFieldListMixin, RPSLReferenceField, RPSLTextField +from irrd.rpsl.rpsl_objects import ( + OBJECT_CLASS_MAPPING, + RPSLAutNum, + RPSLInetRtr, + RPSLPerson, + RPSLRole, + lookup_field_names, +) from irrd.scopefilter.status import ScopeFilterStatus from irrd.utils.text import snake_to_camel_case @@ -39,7 +45,8 @@ def __init__(self): self._set_enums() schema = self.enums - schema += """ + schema += ( + """ scalar ASN scalar IP @@ -48,7 +55,9 @@ def __init__(self): } type Query { - rpslObjects(""" + self.rpsl_query_fields + """): [RPSLObject!] + rpslObjects(""" + + self.rpsl_query_fields + + """): [RPSLObject!] databaseStatus(sources: [String!]): [DatabaseStatus] asnPrefixes(asns: [ASN!]!, ipVersion: Int, sources: [String!]): [ASNPrefixes!] asSetPrefixes(setNames: [String!]!, ipVersion: Int, sources: [String!], excludeSets: [String!], sqlTrace: Boolean): [AsSetPrefixes!] @@ -98,10 +107,11 @@ def __init__(self): members: [String!] } """ + ) schema += self.rpsl_object_interface_schema schema += self.rpsl_contact_schema - schema += ''.join(self.rpsl_object_schemas.values()) - schema += 'union RPSLContactUnion = RPSLPerson | RPSLRole' + schema += "".join(self.rpsl_object_schemas.values()) + schema += "union RPSLContactUnion = RPSLPerson | RPSLRole" self.type_defs = ariadne.gql(schema) @@ -110,8 +120,13 @@ def __init__(self): self.rpsl_contact_union_type = ariadne.UnionType("RPSLContactUnion") self.asn_scalar_type = ariadne.ScalarType("ASN") self.ip_scalar_type = ariadne.ScalarType("IP") - self.object_types = [self.query_type, self.rpsl_object_type, self.rpsl_contact_union_type, - self.asn_scalar_type, self.ip_scalar_type] + self.object_types = [ + self.query_type, + self.rpsl_object_type, + self.rpsl_contact_union_type, + self.asn_scalar_type, + self.ip_scalar_type, + ] for name in self.rpsl_object_schemas.keys(): self.object_types.append(ariadne.ObjectType(name)) @@ -129,34 +144,34 @@ def _set_rpsl_query_fields(self): This includes all fields from all objects, along with a few special fields. 
""" - string_list_fields = {'rpsl_pk', 'sources', 'object_class'}.union(lookup_field_names()) - params = [snake_to_camel_case(p) + ': [String!]' for p in sorted(string_list_fields)] + string_list_fields = {"rpsl_pk", "sources", "object_class"}.union(lookup_field_names()) + params = [snake_to_camel_case(p) + ": [String!]" for p in sorted(string_list_fields)] params += [ - 'ipExact: IP', - 'ipLessSpecific: IP', - 'ipLessSpecificOneLevel: IP', - 'ipMoreSpecific: IP', - 'ipAny: IP', - 'asn: [ASN!]', - 'rpkiStatus: [RPKIStatus!]', - 'scopeFilterStatus: [ScopeFilterStatus!]', - 'routePreferenceStatus: [RoutePreferenceStatus!]', - 'textSearch: String', - 'recordLimit: Int', - 'sqlTrace: Boolean', + "ipExact: IP", + "ipLessSpecific: IP", + "ipLessSpecificOneLevel: IP", + "ipMoreSpecific: IP", + "ipAny: IP", + "asn: [ASN!]", + "rpkiStatus: [RPKIStatus!]", + "scopeFilterStatus: [ScopeFilterStatus!]", + "routePreferenceStatus: [RoutePreferenceStatus!]", + "textSearch: String", + "recordLimit: Int", + "sqlTrace: Boolean", ] - self.rpsl_query_fields = ', '.join(params) + self.rpsl_query_fields = ", ".join(params) def _set_enums(self): """ Create the schema for enums of RPKI, scope filter and route preference.. """ - self.enums = '' + self.enums = "" for enum in [RPKIStatus, ScopeFilterStatus, RoutePreferenceStatus]: - self.enums += f'enum {enum.__name__} {{\n' + self.enums += f"enum {enum.__name__} {{\n" for value in enum: - self.enums += f' {value.name}\n' - self.enums += '}\n\n' + self.enums += f" {value.name}\n" + self.enums += "}\n\n" def _set_rpsl_object_interface_schema(self): """ @@ -170,10 +185,10 @@ def _set_rpsl_object_interface_schema(self): else: common_fields = common_fields.intersection(set(rpsl_object_class.fields.keys())) common_fields = list(common_fields) - common_fields = ['rpslPk', 'objectClass', 'objectText', 'updated'] + common_fields + common_fields = ["rpslPk", "objectClass", "objectText", "updated"] + common_fields common_field_dict = self._dict_for_common_fields(common_fields) - common_field_dict['journal'] = '[RPSLJournalEntry]' - schema = self._generate_schema_str('RPSLObject', 'interface', common_field_dict) + common_field_dict["journal"] = "[RPSLJournalEntry]" + schema = self._generate_schema_str("RPSLObject", "interface", common_field_dict) self.rpsl_object_interface_schema = schema def _set_rpsl_contact_schema(self): @@ -182,9 +197,9 @@ def _set_rpsl_contact_schema(self): RPSLPerson and RPSLRole, as they are so similar. 
""" common_fields = set(RPSLPerson.fields.keys()).intersection(set(RPSLRole.fields.keys())) - common_fields = common_fields.union({'rpslPk', 'objectClass', 'objectText', 'updated'}) + common_fields = common_fields.union({"rpslPk", "objectClass", "objectText", "updated"}) common_field_dict = self._dict_for_common_fields(list(common_fields)) - schema = self._generate_schema_str('RPSLContact', 'interface', common_field_dict) + schema = self._generate_schema_str("RPSLContact", "interface", common_field_dict) self.rpsl_contact_schema = schema def _dict_for_common_fields(self, common_fields: List[str]): @@ -195,12 +210,11 @@ def _dict_for_common_fields(self, common_fields: List[str]): rpsl_field = RPSLPerson.fields[field_name] graphql_type = self._graphql_type_for_rpsl_field(rpsl_field) - reference_name, reference_type = self._grapql_type_for_reference_field( - field_name, rpsl_field) + reference_name, reference_type = self._grapql_type_for_reference_field(field_name, rpsl_field) if reference_name and reference_type: common_field_dict[reference_name] = reference_type except KeyError: - graphql_type = 'String' + graphql_type = "String" common_field_dict[snake_to_camel_case(field_name)] = graphql_type return common_field_dict @@ -215,11 +229,11 @@ def _set_rpsl_object_schemas(self): for object_class, klass in OBJECT_CLASS_MAPPING.items(): object_name = klass.__name__ graphql_fields = OrderedDict() - graphql_fields['rpslPk'] = 'String' - graphql_fields['objectClass'] = 'String' - graphql_fields['objectText'] = 'String' - graphql_fields['updated'] = 'String' - graphql_fields['journal'] = '[RPSLJournalEntry]' + graphql_fields["rpslPk"] = "String" + graphql_fields["objectClass"] = "String" + graphql_fields["objectText"] = "String" + graphql_fields["updated"] = "String" + graphql_fields["journal"] = "[RPSLJournalEntry]" for field_name, field in klass.fields.items(): graphql_type = self._graphql_type_for_rpsl_field(field) graphql_fields[snake_to_camel_case(field_name)] = graphql_type @@ -231,22 +245,22 @@ def _set_rpsl_object_schemas(self): self.graphql_types[object_name][reference_name] = reference_type for field_name in klass.field_extracts: - if field_name.startswith('asn'): - graphql_type = 'ASN' - elif field_name == 'prefix': - graphql_type = 'IP' - elif field_name == 'prefix_length': - graphql_type = 'Int' + if field_name.startswith("asn"): + graphql_type = "ASN" + elif field_name == "prefix": + graphql_type = "IP" + elif field_name == "prefix_length": + graphql_type = "Int" else: - graphql_type = 'String' + graphql_type = "String" graphql_fields[snake_to_camel_case(field_name)] = graphql_type if klass.is_route: - graphql_fields['rpkiStatus'] = 'RPKIStatus' - graphql_fields['rpkiMaxLength'] = 'Int' - self.graphql_types[object_name]['rpki_max_length'] = 'Int' - graphql_fields['routePreferenceStatus'] = 'RoutePreferenceStatus' - implements = 'RPSLContact & RPSLObject' if klass in [RPSLPerson, RPSLRole] else 'RPSLObject' - schema = self._generate_schema_str(object_name, 'type', graphql_fields, implements) + graphql_fields["rpkiStatus"] = "RPKIStatus" + graphql_fields["rpkiMaxLength"] = "Int" + self.graphql_types[object_name]["rpki_max_length"] = "Int" + graphql_fields["routePreferenceStatus"] = "RoutePreferenceStatus" + implements = "RPSLContact & RPSLObject" if klass in [RPSLPerson, RPSLRole] else "RPSLObject" + schema = self._generate_schema_str(object_name, "type", graphql_fields, implements) schemas[object_name] = schema self.rpsl_object_schemas = schemas @@ -257,10 +271,12 @@ def 
_graphql_type_for_rpsl_field(self, field: RPSLTextField) -> str: can occur multiple times. """ if RPSLFieldListMixin in field.__class__.__bases__ or field.multiple: - return '[String!]' - return 'String' + return "[String!]" + return "String" - def _grapql_type_for_reference_field(self, field_name: str, rpsl_field: RPSLTextField) -> Tuple[Optional[str], Optional[str]]: + def _grapql_type_for_reference_field( + self, field_name: str, rpsl_field: RPSLTextField + ) -> Tuple[Optional[str], Optional[str]]: """ Return the GraphQL name and type for a reference field. For example, for a field "admin-c" that refers to person/role, @@ -268,31 +284,33 @@ def _grapql_type_for_reference_field(self, field_name: str, rpsl_field: RPSLText Some fields are excluded because they are syntactical references, not real references. """ - if isinstance(rpsl_field, RPSLReferenceField) and getattr(rpsl_field, 'referring', None): + if isinstance(rpsl_field, RPSLReferenceField) and getattr(rpsl_field, "referring", None): rpsl_field.resolve_references() - graphql_name = snake_to_camel_case(field_name) + 'Objs' + graphql_name = snake_to_camel_case(field_name) + "Objs" grapql_referring = set(rpsl_field.referring_object_classes) if RPSLAutNum in grapql_referring: grapql_referring.remove(RPSLAutNum) if RPSLInetRtr in grapql_referring: grapql_referring.remove(RPSLInetRtr) if grapql_referring == {RPSLPerson, RPSLRole}: - graphql_type = '[RPSLContactUnion!]' + graphql_type = "[RPSLContactUnion!]" else: - graphql_type = '[' + grapql_referring.pop().__name__ + '!]' + graphql_type = "[" + grapql_referring.pop().__name__ + "!]" return graphql_name, graphql_type return None, None - def _generate_schema_str(self, name: str, graphql_type: str, fields: Dict[str, str], implements: Optional[str]=None) -> str: + def _generate_schema_str( + self, name: str, graphql_type: str, fields: Dict[str, str], implements: Optional[str] = None + ) -> str: """ Generate a schema string for a given name, object type and dict of fields. 
""" - schema = f'{graphql_type} {name} ' + schema = f"{graphql_type} {name} " if implements: - schema += f'implements {implements} ' - schema += '{\n' + schema += f"implements {implements} " + schema += "{\n" for field, field_type in fields.items(): - schema += f' {field}: {field_type}\n' - schema += '}\n\n' + schema += f" {field}: {field_type}\n" + schema += "}\n\n" return schema diff --git a/irrd/server/graphql/tests/test_extensions.py b/irrd/server/graphql/tests/test_extensions.py index 8ec5bd1e7..659eba5d7 100644 --- a/irrd/server/graphql/tests/test_extensions.py +++ b/irrd/server/graphql/tests/test_extensions.py @@ -7,38 +7,40 @@ def test_query_metedata_extension(caplog): extension = QueryMetadataExtension() - mock_request = HTTPConnection({ - 'type': 'http', - 'client': ('127.0.0.1', '8000'), - }) + mock_request = HTTPConnection( + { + "type": "http", + "client": ("127.0.0.1", "8000"), + } + ) mock_request._json = { - 'operationName': 'operation', - 'query': 'graphql query', + "operationName": "operation", + "query": "graphql query", } context = { - 'sql_queries': ['sql query'], - 'request': mock_request, + "sql_queries": ["sql query"], + "request": mock_request, } extension.request_started(context) result = extension.format(context) - assert '127.0.0.1 ran query in ' in caplog.text + assert "127.0.0.1 ran query in " in caplog.text assert ": {'operationName': 'operation', 'query': 'graphqlquery'}" in caplog.text - assert result['execution'] < 3 - assert result['sql_query_count'] == 1 - assert result['sql_queries'] == ['sql query'] + assert result["execution"] < 3 + assert result["sql_query_count"] == 1 + assert result["sql_queries"] == ["sql query"] def test_error_formatter(): # Regular GraphQL error should always be passed - error = GraphQLError(message='error') + error = GraphQLError(message="error") result = error_formatter(error) - assert result['message'] == 'error' + assert result["message"] == "error" # If original_error is something else, hide except when in debug mode - error = GraphQLError(message='error', original_error=ValueError()) + error = GraphQLError(message="error", original_error=ValueError()) result = error_formatter(error) - assert result['message'] == 'Internal server error' + assert result["message"] == "Internal server error" result = error_formatter(error, debug=True) - assert result['message'] == 'error' - assert result['extensions'] == {'exception': None} + assert result["message"] == "error" + assert result["extensions"] == {"exception": None} diff --git a/irrd/server/graphql/tests/test_resolvers.py b/irrd/server/graphql/tests/test_resolvers.py index 9373a2a6d..c7453e574 100644 --- a/irrd/server/graphql/tests/test_resolvers.py +++ b/irrd/server/graphql/tests/test_resolvers.py @@ -1,8 +1,8 @@ from unittest.mock import Mock import pytest -from IPy import IP from graphql import GraphQLError +from IPy import IP from starlette.requests import HTTPConnection from irrd.routepref.status import RoutePreferenceStatus @@ -12,82 +12,93 @@ from irrd.storage.database_handler import DatabaseHandler from irrd.storage.models import DatabaseOperation, JournalEntryOrigin from irrd.storage.preload import Preloader -from irrd.storage.queries import RPSLDatabaseQuery, RPSLDatabaseJournalQuery +from irrd.storage.queries import RPSLDatabaseJournalQuery, RPSLDatabaseQuery from irrd.utils.test_utils import flatten_mock_calls + from .. 
import resolvers -EXPECTED_RPSL_GRAPHQL_OUTPUT = [{ - 'rpslPk': '192.0.2.0/25,AS65547', - 'objectClass': 'route', - 'objectText': 'object text\nauth: CRYPT-PW DummyValue # Filtered for security', - 'operation': DatabaseOperation.add_or_update, - 'rpkiStatus': RPKIStatus.not_found, - 'scopefilterStatus': ScopeFilterStatus.out_scope_as, - 'routePreferenceStatus': RoutePreferenceStatus.suppressed, - 'source': 'TEST1', - 'route': '192.0.2.0/25', - 'origin': 'AS65547', - 'mntBy': 'MNT-TEST', - 'asn': 65547, - 'asnFirst': 65547, - 'asnLast': 65547, - 'ipFirst': '192.0.2.0', - 'ipLast': '192.0.2.128', - 'prefix': '192.0.2.0/25', - 'prefixLength': 25, -}] - -MOCK_RPSL_DB_RESULT = [{ - 'rpsl_pk': '192.0.2.0/25,AS65547', - 'object_class': 'route', - 'parsed_data': { - 'route': '192.0.2.0/25', - 'origin': ['AS65547'], - 'mnt-by': 'MNT-TEST', - }, - 'ip_first': '192.0.2.0', - 'ip_last': '192.0.2.128', - 'prefix_length': 25, - 'asn_first': 65547, - 'asn_last': 65547, - 'object_text': 'object text\nauth: CRYPT-PW LEuuhsBJNFV0Q', - 'rpki_status': RPKIStatus.not_found, - 'scopefilter_status': ScopeFilterStatus.out_scope_as, - 'route_preference_status': RoutePreferenceStatus.suppressed, - 'source': 'TEST1', - # only used in journal test - 'operation': DatabaseOperation.add_or_update, - 'origin': JournalEntryOrigin.auth_change, -}] +EXPECTED_RPSL_GRAPHQL_OUTPUT = [ + { + "rpslPk": "192.0.2.0/25,AS65547", + "objectClass": "route", + "objectText": "object text\nauth: CRYPT-PW DummyValue # Filtered for security", + "operation": DatabaseOperation.add_or_update, + "rpkiStatus": RPKIStatus.not_found, + "scopefilterStatus": ScopeFilterStatus.out_scope_as, + "routePreferenceStatus": RoutePreferenceStatus.suppressed, + "source": "TEST1", + "route": "192.0.2.0/25", + "origin": "AS65547", + "mntBy": "MNT-TEST", + "asn": 65547, + "asnFirst": 65547, + "asnLast": 65547, + "ipFirst": "192.0.2.0", + "ipLast": "192.0.2.128", + "prefix": "192.0.2.0/25", + "prefixLength": 25, + } +] + +MOCK_RPSL_DB_RESULT = [ + { + "rpsl_pk": "192.0.2.0/25,AS65547", + "object_class": "route", + "parsed_data": { + "route": "192.0.2.0/25", + "origin": ["AS65547"], + "mnt-by": "MNT-TEST", + }, + "ip_first": "192.0.2.0", + "ip_last": "192.0.2.128", + "prefix_length": 25, + "asn_first": 65547, + "asn_last": 65547, + "object_text": "object text\nauth: CRYPT-PW LEuuhsBJNFV0Q", + "rpki_status": RPKIStatus.not_found, + "scopefilter_status": ScopeFilterStatus.out_scope_as, + "route_preference_status": RoutePreferenceStatus.suppressed, + "source": "TEST1", + # only used in journal test + "operation": DatabaseOperation.add_or_update, + "origin": JournalEntryOrigin.auth_change, + } +] @pytest.fixture() def prepare_resolver(monkeypatch): - resolvers._collect_predicate_names = lambda info: ['asn', 'prefix', 'ipLast'] + resolvers._collect_predicate_names = lambda info: ["asn", "prefix", "ipLast"] mock_database_query = Mock(spec=RPSLDatabaseQuery) - monkeypatch.setattr('irrd.server.graphql.resolvers.RPSLDatabaseQuery', - lambda **kwargs: mock_database_query) + monkeypatch.setattr( + "irrd.server.graphql.resolvers.RPSLDatabaseQuery", lambda **kwargs: mock_database_query + ) mock_database_query.columns = RPSLDatabaseQuery.columns mock_query_resolver = Mock(spec=QueryResolver) - monkeypatch.setattr('irrd.server.graphql.resolvers.QueryResolver', - lambda preloader, database_handler: mock_query_resolver) - - app = Mock(state=Mock( - database_handler=Mock(spec=DatabaseHandler), - preloader=Mock(spec=Preloader), - )) + monkeypatch.setattr( + 
"irrd.server.graphql.resolvers.QueryResolver", lambda preloader, database_handler: mock_query_resolver + ) + + app = Mock( + state=Mock( + database_handler=Mock(spec=DatabaseHandler), + preloader=Mock(spec=Preloader), + ) + ) app.state.database_handler.execute_query = lambda query, refresh_on_error: MOCK_RPSL_DB_RESULT info = Mock() info.context = {} info.field_nodes = [Mock(selection_set=Mock(selections=Mock()))] - info.context['request'] = HTTPConnection({ - 'type': 'http', - 'client': ('127.0.0.1', '8000'), - 'app': app, - }) + info.context["request"] = HTTPConnection( + { + "type": "http", + "client": ("127.0.0.1", "8000"), + "app": app, + } + ) yield info, mock_database_query, mock_query_resolver @@ -99,110 +110,124 @@ def test_resolve_rpsl_objects(self, prepare_resolver, config_override): with pytest.raises(ValueError): resolvers.resolve_rpsl_objects(None, info) with pytest.raises(ValueError): - resolvers.resolve_rpsl_objects(None, info, object_class='route', sql_trace=True) + resolvers.resolve_rpsl_objects(None, info, object_class="route", sql_trace=True) with pytest.raises(ValueError): - resolvers.resolve_rpsl_objects(None, info, object_class='route', - rpki_status=[RPKIStatus.not_found], sql_trace=True) + resolvers.resolve_rpsl_objects( + None, info, object_class="route", rpki_status=[RPKIStatus.not_found], sql_trace=True + ) # Should not raise ValueError - resolvers.resolve_rpsl_objects(None, info, object_class='route', - rpki_status=[RPKIStatus.invalid], sql_trace=True) + resolvers.resolve_rpsl_objects( + None, info, object_class="route", rpki_status=[RPKIStatus.invalid], sql_trace=True + ) mock_database_query.reset_mock() - result = list(resolvers.resolve_rpsl_objects( - None, - info, - sql_trace=True, - rpsl_pk='pk', - object_class='route', - asn=[65550], - text_search='text', - rpki_status=[RPKIStatus.invalid], - scope_filter_status=[ScopeFilterStatus.out_scope_as], - route_preference_status=[RoutePreferenceStatus.suppressed], - ip_exact='192.0.2.1', - sources=['TEST1'], - mntBy='mnt-by', - unknownKwarg='ignored', - record_limit=2, - )) + result = list( + resolvers.resolve_rpsl_objects( + None, + info, + sql_trace=True, + rpsl_pk="pk", + object_class="route", + asn=[65550], + text_search="text", + rpki_status=[RPKIStatus.invalid], + scope_filter_status=[ScopeFilterStatus.out_scope_as], + route_preference_status=[RoutePreferenceStatus.suppressed], + ip_exact="192.0.2.1", + sources=["TEST1"], + mntBy="mnt-by", + unknownKwarg="ignored", + record_limit=2, + ) + ) assert result == EXPECTED_RPSL_GRAPHQL_OUTPUT assert flatten_mock_calls(mock_database_query) == [ - ['limit', (2,), {}], - ['rpsl_pks', ('pk',), {}], - ['object_classes', ('route',), {}], - ['asns_first', ([65550],), {}], - ['text_search', ('text',), {}], - ['rpki_status', ([RPKIStatus.invalid],), {}], - ['scopefilter_status', ([ScopeFilterStatus.out_scope_as],), {}], - ['route_preference_status', ([RoutePreferenceStatus.suppressed],), {}], - ['sources', (['TEST1'],), {}], - ['lookup_attrs_in', (['mnt-by'], 'mnt-by'), {}], - ['ip_exact', (IP('192.0.2.1'),), {}], + ["limit", (2,), {}], + ["rpsl_pks", ("pk",), {}], + ["object_classes", ("route",), {}], + ["asns_first", ([65550],), {}], + ["text_search", ("text",), {}], + ["rpki_status", ([RPKIStatus.invalid],), {}], + ["scopefilter_status", ([ScopeFilterStatus.out_scope_as],), {}], + ["route_preference_status", ([RoutePreferenceStatus.suppressed],), {}], + ["sources", (["TEST1"],), {}], + ["lookup_attrs_in", (["mnt-by"], "mnt-by"), {}], + ["ip_exact", (IP("192.0.2.1"),), 
{}], ] - assert info.context['sql_trace'] + assert info.context["sql_trace"] mock_database_query.reset_mock() - config_override({'sources_default': ['TEST1']}) - result = list(resolvers.resolve_rpsl_objects( - None, - info, - sql_trace=True, - rpsl_pk='pk', - )) + config_override({"sources_default": ["TEST1"]}) + result = list( + resolvers.resolve_rpsl_objects( + None, + info, + sql_trace=True, + rpsl_pk="pk", + ) + ) assert result == EXPECTED_RPSL_GRAPHQL_OUTPUT assert flatten_mock_calls(mock_database_query) == [ - ['rpsl_pks', ('pk',), {}], - ['rpki_status', ([RPKIStatus.not_found, RPKIStatus.valid],), {}], - ['scopefilter_status', ([ScopeFilterStatus.in_scope],), {}], - ['route_preference_status', ([RoutePreferenceStatus.visible],), {}], - ['sources', (['TEST1'],), {}], + ["rpsl_pks", ("pk",), {}], + ["rpki_status", ([RPKIStatus.not_found, RPKIStatus.valid],), {}], + ["scopefilter_status", ([ScopeFilterStatus.in_scope],), {}], + ["route_preference_status", ([RoutePreferenceStatus.visible],), {}], + ["sources", (["TEST1"],), {}], ] def test_strips_auth_attribute_hashes(self, prepare_resolver): info, mock_database_query, mock_query_resolver = prepare_resolver - rpsl_db_mntner_result = [{ - 'object_class': 'mntner', - 'parsed_data': { - 'auth': ['CRYPT-Pw LEuuhsBJNFV0Q'], - }, - }] - - info.context['request'].app.state.database_handler.execute_query = lambda query, refresh_on_error: rpsl_db_mntner_result - result = list(resolvers.resolve_rpsl_objects( - None, - info, - sql_trace=True, - rpsl_pk='pk', - )) - assert result == [{ - 'objectClass': 'mntner', - 'auth': ['CRYPT-Pw DummyValue # Filtered for security'], - }] + rpsl_db_mntner_result = [ + { + "object_class": "mntner", + "parsed_data": { + "auth": ["CRYPT-Pw LEuuhsBJNFV0Q"], + }, + } + ] + + info.context["request"].app.state.database_handler.execute_query = ( + lambda query, refresh_on_error: rpsl_db_mntner_result + ) + result = list( + resolvers.resolve_rpsl_objects( + None, + info, + sql_trace=True, + rpsl_pk="pk", + ) + ) + assert result == [ + { + "objectClass": "mntner", + "auth": ["CRYPT-Pw DummyValue # Filtered for security"], + } + ] def test_resolve_rpsl_object_mnt_by_objs(self, prepare_resolver): info, mock_database_query, mock_query_resolver = prepare_resolver mock_rpsl_object = { - 'objectClass': 'route', - 'mntBy': 'mntBy', - 'source': 'source', + "objectClass": "route", + "mntBy": "mntBy", + "source": "source", } result = list(resolvers.resolve_rpsl_object_mnt_by_objs(mock_rpsl_object, info)) assert result == EXPECTED_RPSL_GRAPHQL_OUTPUT assert flatten_mock_calls(mock_database_query) == [ - ['object_classes', (['mntner'],), {}], - ['rpsl_pks', (['mntBy'],), {}], - ['sources', (['source'],), {}], + ["object_classes", (["mntner"],), {}], + ["rpsl_pks", (["mntBy"],), {}], + ["sources", (["source"],), {}], ] # Missing PK mock_rpsl_object = { - 'objectClass': 'route', - 'source': 'source', + "objectClass": "route", + "source": "source", } assert not list(resolvers.resolve_rpsl_object_mnt_by_objs(mock_rpsl_object, info)) @@ -210,188 +235,209 @@ def test_resolve_rpsl_object_adminc_objs(self, prepare_resolver): info, mock_database_query, mock_query_resolver = prepare_resolver mock_rpsl_object = { - 'objectClass': 'route', - 'adminC': 'adminC', - 'source': 'source', + "objectClass": "route", + "adminC": "adminC", + "source": "source", } result = list(resolvers.resolve_rpsl_object_adminc_objs(mock_rpsl_object, info)) assert result == EXPECTED_RPSL_GRAPHQL_OUTPUT assert flatten_mock_calls(mock_database_query) == [ - 
['object_classes', (['role', 'person'],), {}], - ['rpsl_pks', (['adminC'],), {}], - ['sources', (['source'],), {}], + ["object_classes", (["role", "person"],), {}], + ["rpsl_pks", (["adminC"],), {}], + ["sources", (["source"],), {}], ] def test_resolve_rpsl_object_techc_objs(self, prepare_resolver): info, mock_database_query, mock_query_resolver = prepare_resolver mock_rpsl_object = { - 'objectClass': 'route', - 'techC': 'techC', - 'source': 'source', + "objectClass": "route", + "techC": "techC", + "source": "source", } result = list(resolvers.resolve_rpsl_object_techc_objs(mock_rpsl_object, info)) assert result == EXPECTED_RPSL_GRAPHQL_OUTPUT assert flatten_mock_calls(mock_database_query) == [ - ['object_classes', (['role', 'person'],), {}], - ['rpsl_pks', (['techC'],), {}], - ['sources', (['source'],), {}], + ["object_classes", (["role", "person"],), {}], + ["rpsl_pks", (["techC"],), {}], + ["sources", (["source"],), {}], ] def test_resolve_rpsl_object_members_by_ref_objs(self, prepare_resolver): info, mock_database_query, mock_query_resolver = prepare_resolver mock_rpsl_object = { - 'objectClass': 'route', - 'mbrsByRef': 'mbrsByRef', - 'source': 'source', + "objectClass": "route", + "mbrsByRef": "mbrsByRef", + "source": "source", } result = list(resolvers.resolve_rpsl_object_members_by_ref_objs(mock_rpsl_object, info)) assert result == EXPECTED_RPSL_GRAPHQL_OUTPUT assert flatten_mock_calls(mock_database_query) == [ - ['object_classes', (['mntner'],), {}], - ['rpsl_pks', (['mbrsByRef'],), {}], - ['sources', (['source'],), {}], + ["object_classes", (["mntner"],), {}], + ["rpsl_pks", (["mbrsByRef"],), {}], + ["sources", (["source"],), {}], ] def test_resolve_rpsl_object_member_of_objs(self, prepare_resolver): info, mock_database_query, mock_query_resolver = prepare_resolver mock_rpsl_object = { - 'objectClass': 'route', - 'memberOf': 'memberOf', - 'source': 'source', + "objectClass": "route", + "memberOf": "memberOf", + "source": "source", } result = list(resolvers.resolve_rpsl_object_member_of_objs(mock_rpsl_object, info)) assert result == EXPECTED_RPSL_GRAPHQL_OUTPUT assert flatten_mock_calls(mock_database_query) == [ - ['object_classes', (['route-set'],), {}], - ['rpsl_pks', (['memberOf'],), {}], - ['sources', (['source'],), {}], + ["object_classes", (["route-set"],), {}], + ["rpsl_pks", (["memberOf"],), {}], + ["sources", (["source"],), {}], ] def test_resolve_rpsl_object_members_objs(self, prepare_resolver): info, mock_database_query, mock_query_resolver = prepare_resolver mock_rpsl_object = { - 'objectClass': 'as-set', - 'members': 'members', - 'source': 'source', + "objectClass": "as-set", + "members": "members", + "source": "source", } result = list(resolvers.resolve_rpsl_object_members_objs(mock_rpsl_object, info)) assert result == EXPECTED_RPSL_GRAPHQL_OUTPUT assert flatten_mock_calls(mock_database_query) == [ - ['object_classes', (['as-set'],), {}], - ['rpsl_pks', (['members'],), {}] + ["object_classes", (["as-set"],), {}], + ["rpsl_pks", (["members"],), {}], ] mock_database_query.reset_mock() mock_rpsl_object = { - 'objectClass': 'rtr-set', - 'members': 'members', - 'source': 'source', + "objectClass": "rtr-set", + "members": "members", + "source": "source", } result = list(resolvers.resolve_rpsl_object_members_objs(mock_rpsl_object, info)) assert result == EXPECTED_RPSL_GRAPHQL_OUTPUT assert flatten_mock_calls(mock_database_query) == [ - ['object_classes', (['rtr-set'],), {}], - ['rpsl_pks', (['members'],), {}] + ["object_classes", (["rtr-set"],), {}], + ["rpsl_pks", 
(["members"],), {}], ] def test_resolve_rpsl_object_journal(self, prepare_resolver, monkeypatch, config_override): info, mock_database_query, mock_query_resolver = prepare_resolver mock_journal_query = Mock(spec=RPSLDatabaseJournalQuery) - monkeypatch.setattr('irrd.server.graphql.resolvers.RPSLDatabaseJournalQuery', - lambda **kwargs: mock_journal_query) + monkeypatch.setattr( + "irrd.server.graphql.resolvers.RPSLDatabaseJournalQuery", lambda **kwargs: mock_journal_query + ) mock_rpsl_object = { - 'rpslPk': 'pk', - 'source': 'source', + "rpslPk": "pk", + "source": "source", } with pytest.raises(GraphQLError): next(resolvers.resolve_rpsl_object_journal(mock_rpsl_object, info)) - config_override({ - 'access_lists': {'localhost': ['127.0.0.1']}, - 'sources': {'source': {'nrtm_access_list': 'localhost'}}, - }) + config_override( + { + "access_lists": {"localhost": ["127.0.0.1"]}, + "sources": {"source": {"nrtm_access_list": "localhost"}}, + } + ) result = list(resolvers.resolve_rpsl_object_journal(mock_rpsl_object, info)) assert len(result) == 1 - assert result[0]['origin'] == 'auth_change' - assert 'CRYPT-PW DummyValue # Filtered for security' in result[0]['objectText'] + assert result[0]["origin"] == "auth_change" + assert "CRYPT-PW DummyValue # Filtered for security" in result[0]["objectText"] assert flatten_mock_calls(mock_journal_query) == [ - ['sources', (['source'],), {}], ['rpsl_pk', ('pk',), {}] + ["sources", (["source"],), {}], + ["rpsl_pk", ("pk",), {}], ] def test_resolve_database_status(self, prepare_resolver): info, mock_database_query, mock_query_resolver = prepare_resolver mock_status = { - 'SOURCE1': {'status_field': 1}, - 'SOURCE2': {'status_field': 2}, + "SOURCE1": {"status_field": 1}, + "SOURCE2": {"status_field": 2}, } mock_query_resolver.database_status = lambda sources: mock_status result = list(resolvers.resolve_database_status(None, info)) - assert result[0]['source'] == 'SOURCE1' - assert result[0]['statusField'] == 1 - assert result[1]['source'] == 'SOURCE2' - assert result[1]['statusField'] == 2 + assert result[0]["source"] == "SOURCE1" + assert result[0]["statusField"] == 1 + assert result[1]["source"] == "SOURCE2" + assert result[1]["statusField"] == 2 def test_resolve_asn_prefixes(self, prepare_resolver): info, mock_database_query, mock_query_resolver = prepare_resolver - mock_query_resolver.routes_for_origin = lambda asn, ip_version: [f'prefix-{asn}'] - - result = list(resolvers.resolve_asn_prefixes( - None, - info, - asns=[65550, 65551], - ip_version=4, - )) + mock_query_resolver.routes_for_origin = lambda asn, ip_version: [f"prefix-{asn}"] + + result = list( + resolvers.resolve_asn_prefixes( + None, + info, + asns=[65550, 65551], + ip_version=4, + ) + ) assert result == [ - {'asn': 65550, 'prefixes': ['prefix-AS65550']}, - {'asn': 65551, 'prefixes': ['prefix-AS65551']}, + {"asn": 65550, "prefixes": ["prefix-AS65550"]}, + {"asn": 65551, "prefixes": ["prefix-AS65551"]}, ] mock_query_resolver.set_query_sources.assert_called_once() def test_resolve_as_set_prefixes(self, prepare_resolver): info, mock_database_query, mock_query_resolver = prepare_resolver - mock_query_resolver.routes_for_as_set = lambda set_name, ip_version, exclude_sets: [f'prefix-{set_name}'] - - result = list(resolvers.resolve_as_set_prefixes( - None, - info, - set_names=['as-A', 'AS-B'], - ip_version=4, - sql_trace=True, - )) - assert sorted(result, key=str) == sorted([ - {'rpslPk': 'AS-A', 'prefixes': ['prefix-AS-A']}, - {'rpslPk': 'AS-B', 'prefixes': ['prefix-AS-B']}, - ], key=str) + 
mock_query_resolver.routes_for_as_set = lambda set_name, ip_version, exclude_sets: [ + f"prefix-{set_name}" + ] + + result = list( + resolvers.resolve_as_set_prefixes( + None, + info, + set_names=["as-A", "AS-B"], + ip_version=4, + sql_trace=True, + ) + ) + assert sorted(result, key=str) == sorted( + [ + {"rpslPk": "AS-A", "prefixes": ["prefix-AS-A"]}, + {"rpslPk": "AS-B", "prefixes": ["prefix-AS-B"]}, + ], + key=str, + ) mock_query_resolver.set_query_sources.assert_called_once() def test_resolve_recursive_set_members(self, prepare_resolver): info, mock_database_query, mock_query_resolver = prepare_resolver - mock_query_resolver.members_for_set_per_source = lambda set_name, exclude_sets, depth, recursive: {'TEST1': [f'member1-{set_name}'], 'TEST2': [f'member2-{set_name}']} - - result = list(resolvers.resolve_recursive_set_members( - None, - info, - set_names=['as-A', 'AS-B'], - depth=4, - sql_trace=True, - )) - assert sorted(result, key=str) == sorted([ - {'rpslPk': 'AS-A', 'rootSource': 'TEST1', 'members': ['member1-AS-A']}, - {'rpslPk': 'AS-A', 'rootSource': 'TEST2', 'members': ['member2-AS-A']}, - {'rpslPk': 'AS-B', 'rootSource': 'TEST1', 'members': ['member1-AS-B']}, - {'rpslPk': 'AS-B', 'rootSource': 'TEST2', 'members': ['member2-AS-B']}, - ], key=str) + mock_query_resolver.members_for_set_per_source = lambda set_name, exclude_sets, depth, recursive: { + "TEST1": [f"member1-{set_name}"], + "TEST2": [f"member2-{set_name}"], + } + + result = list( + resolvers.resolve_recursive_set_members( + None, + info, + set_names=["as-A", "AS-B"], + depth=4, + sql_trace=True, + ) + ) + assert sorted(result, key=str) == sorted( + [ + {"rpslPk": "AS-A", "rootSource": "TEST1", "members": ["member1-AS-A"]}, + {"rpslPk": "AS-A", "rootSource": "TEST2", "members": ["member2-AS-A"]}, + {"rpslPk": "AS-B", "rootSource": "TEST1", "members": ["member1-AS-B"]}, + {"rpslPk": "AS-B", "rootSource": "TEST2", "members": ["member2-AS-B"]}, + ], + key=str, + ) mock_query_resolver.set_query_sources.assert_called_once() diff --git a/irrd/server/graphql/tests/test_schema_generator.py b/irrd/server/graphql/tests/test_schema_generator.py index d0fa1ae62..d430e74aa 100644 --- a/irrd/server/graphql/tests/test_schema_generator.py +++ b/irrd/server/graphql/tests/test_schema_generator.py @@ -4,10 +4,12 @@ def test_schema_generator(): # This test will need updating if changes are made to RPSL types. 
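
The resolver tests above all assert on flatten_mock_calls(mock_database_query), which condenses a Mock's recorded call chain into [name, args, kwargs] triples that compare with plain list equality. A minimal sketch of such a helper, assuming only the behaviour these assertions rely on (the real helper ships with IRRd's test utilities and may do more):

    from unittest.mock import Mock

    def flatten_mock_calls_sketch(mock: Mock) -> list:
        # Each entry in Mock.mock_calls unpacks to (name, args, kwargs);
        # returning plain lists keeps the equality asserts readable.
        return [[name, args, kwargs] for name, args, kwargs in mock.mock_calls]

    # Usage, mirroring the assertions in these tests:
    query = Mock()
    query.sources(["TEST1"])
    assert flatten_mock_calls_sketch(query) == [["sources", (["TEST1"],), {}]]
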
generator = SchemaGenerator() - assert generator.graphql_types['RPSLAsBlock']['descr'] == '[String!]' - assert generator.graphql_types['RPSLAsBlock']['techCObjs'] == '[RPSLContactUnion!]' - assert generator.graphql_types['RPSLRtrSet']['rtr-set'] == 'String' - assert generator.type_defs == """enum RPKIStatus { + assert generator.graphql_types["RPSLAsBlock"]["descr"] == "[String!]" + assert generator.graphql_types["RPSLAsBlock"]["techCObjs"] == "[RPSLContactUnion!]" + assert generator.graphql_types["RPSLRtrSet"]["rtr-set"] == "String" + assert ( + generator.type_defs + == """enum RPKIStatus { valid invalid not_found @@ -568,3 +570,4 @@ def test_schema_generator(): } union RPSLContactUnion = RPSLPerson | RPSLRole""" + ) diff --git a/irrd/server/http/app.py b/irrd/server/http/app.py index 6f7078cd3..ccf2b86b8 100644 --- a/irrd/server/http/app.py +++ b/irrd/server/http/app.py @@ -14,15 +14,18 @@ from irrd import ENV_MAIN_PROCESS_PID from irrd.conf import config_init from irrd.server.graphql import ENV_UVICORN_WORKER_CONFIG_PATH -from irrd.server.graphql.extensions import error_formatter, QueryMetadataExtension +from irrd.server.graphql.extensions import QueryMetadataExtension, error_formatter from irrd.server.graphql.schema_builder import build_executable_schema from irrd.server.http.endpoints import ( + ObjectSubmissionEndpoint, StatusEndpoint, SuspensionSubmissionEndpoint, WhoisQueryEndpoint, - ObjectSubmissionEndpoint, ) -from irrd.server.http.event_stream import EventStreamEndpoint, EventStreamInitialDownloadEndpoint +from irrd.server.http.event_stream import ( + EventStreamEndpoint, + EventStreamInitialDownloadEndpoint, +) from irrd.storage.database_handler import DatabaseHandler from irrd.storage.preload import Preloader from irrd.utils.process_support import memory_trim, set_traceback_handler @@ -54,8 +57,10 @@ async def startup(): app.state.preloader = Preloader(enable_queries=True) except Exception as e: logger.critical( - "HTTP worker failed to initialise preloader or database, " - f"unable to start, terminating IRRd, traceback follows: {e}", + ( + "HTTP worker failed to initialise preloader or database, " + f"unable to start, terminating IRRd, traceback follows: {e}" + ), exc_info=e, ) main_pid = os.getenv(ENV_MAIN_PROCESS_PID) diff --git a/irrd/server/http/endpoints.py b/irrd/server/http/endpoints.py index 197597147..658de79f6 100644 --- a/irrd/server/http/endpoints.py +++ b/irrd/server/http/endpoints.py @@ -7,14 +7,15 @@ from asgiref.sync import sync_to_async from starlette.endpoints import HTTPEndpoint from starlette.requests import Request -from starlette.responses import PlainTextResponse, Response, JSONResponse +from starlette.responses import JSONResponse, PlainTextResponse, Response from irrd.server.access_check import is_client_permitted from irrd.updates.handler import ChangeSubmissionHandler from irrd.utils.validators import RPSLChangeSubmission, RPSLSuspensionSubmission -from .status_generator import StatusGenerator + from ..whois.query_parser import WhoisQueryParser from ..whois.query_response import WhoisQueryResponseType +from .status_generator import StatusGenerator logger = logging.getLogger(__name__) @@ -22,8 +23,8 @@ class StatusEndpoint(HTTPEndpoint): def get(self, request: Request) -> Response: assert request.client - if not is_client_permitted(request.client.host, 'server.http.status_access_list'): - return PlainTextResponse('Access denied', status_code=403) + if not is_client_permitted(request.client.host, "server.http.status_access_list"): + return 
PlainTextResponse("Access denied", status_code=403) response = StatusGenerator().generate_status() return PlainTextResponse(response) @@ -33,24 +34,22 @@ class WhoisQueryEndpoint(HTTPEndpoint): def get(self, request: Request) -> Response: assert request.client start_time = time.perf_counter() - if 'q' not in request.query_params: + if "q" not in request.query_params: return PlainTextResponse('Missing required query parameter "q"', status_code=400) - client_str = request.client.host + ':' + str(request.client.port) - query = request.query_params['q'] + client_str = request.client.host + ":" + str(request.client.port) + query = request.query_params["q"] parser = WhoisQueryParser( - request.client.host, - client_str, - request.app.state.preloader, - request.app.state.database_handler + request.client.host, client_str, request.app.state.preloader, request.app.state.database_handler ) response = parser.handle_query(query) response.clean_response() elapsed = time.perf_counter() - start_time length = len(response.result) if response.result else 0 - logger.info(f'{client_str}: sent answer to HTTP query, elapsed {elapsed:.9f}s, ' - f'{length} chars: {query}') + logger.info( + f"{client_str}: sent answer to HTTP query, elapsed {elapsed:.9f}s, {length} chars: {query}" + ) if response.response_type == WhoisQueryResponseType.ERROR_INTERNAL: return PlainTextResponse(response.result, status_code=500) @@ -78,13 +77,13 @@ async def _handle_submission(self, request: Request, delete=False): return PlainTextResponse(str(error), status_code=400) try: - meta_json = request.headers['X-irrd-metadata'] + meta_json = request.headers["X-irrd-metadata"] request_meta = json.loads(meta_json) except (JSONDecodeError, KeyError): request_meta = {} - request_meta['HTTP-client-IP'] = request.client.host - request_meta['HTTP-User-Agent'] = request.headers.get('User-Agent') + request_meta["HTTP-client-IP"] = request.client.host + request_meta["HTTP-User-Agent"] = request.headers.get("User-Agent") handler = ChangeSubmissionHandler() await sync_to_async(handler.load_change_submission)( @@ -104,11 +103,9 @@ async def post(self, request: Request) -> Response: return PlainTextResponse(str(error), status_code=400) request_meta = { - 'HTTP-client-IP': request.client.host, - 'HTTP-User-Agent': request.headers.get('User-Agent'), + "HTTP-client-IP": request.client.host, + "HTTP-User-Agent": request.headers.get("User-Agent"), } handler = ChangeSubmissionHandler() - await sync_to_async(handler.load_suspension_submission)( - data=data, request_meta=request_meta - ) + await sync_to_async(handler.load_suspension_submission)(data=data, request_meta=request_meta) return JSONResponse(handler.submitter_report_json()) diff --git a/irrd/server/http/event_stream.py b/irrd/server/http/event_stream.py index bf612f1c4..7ee67bf1d 100644 --- a/irrd/server/http/event_stream.py +++ b/irrd/server/http/event_stream.py @@ -5,30 +5,32 @@ import socket import sys import tempfile -from typing import Any, List, Optional, Callable +from typing import Any, Callable, List, Literal, Optional import pydantic import ujson -from starlette.endpoints import WebSocketEndpoint, HTTPEndpoint +from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint from starlette.requests import Request -from starlette.responses import Response, StreamingResponse, PlainTextResponse +from starlette.responses import PlainTextResponse, Response, StreamingResponse from starlette.status import WS_1003_UNSUPPORTED_DATA, WS_1008_POLICY_VIOLATION from starlette.websockets import WebSocket 
-from typing import Literal from irrd.conf import get_setting +from irrd.routepref.status import RoutePreferenceStatus from irrd.rpki.status import RPKIStatus from irrd.rpsl.rpsl_objects import rpsl_object_from_text -from irrd.routepref.status import RoutePreferenceStatus from irrd.scopefilter.status import ScopeFilterStatus from irrd.server.access_check import is_client_permitted from irrd.storage.database_handler import DatabaseHandler -from irrd.storage.event_stream import AsyncEventStreamRedisClient, REDIS_STREAM_END_IDENTIFIER +from irrd.storage.event_stream import ( + REDIS_STREAM_END_IDENTIFIER, + AsyncEventStreamRedisClient, +) from irrd.storage.queries import ( DatabaseStatusQuery, - RPSLDatabaseQuery, - RPSLDatabaseJournalStatisticsQuery, RPSLDatabaseJournalQuery, + RPSLDatabaseJournalStatisticsQuery, + RPSLDatabaseQuery, ) from irrd.utils.text import remove_auth_hashes from irrd.vendor import postgres_copy @@ -152,7 +154,9 @@ async def on_connect(self, websocket: WebSocket) -> None: await self.send_header() async def send_header(self) -> None: - journaled_sources = [name for name, settings in get_setting("sources").items() if settings.get("keep_journal")] + journaled_sources = [ + name for name, settings in get_setting("sources").items() if settings.get("keep_journal") + ] dh = DatabaseHandler(readonly=True) query = DatabaseStatusQuery().sources(journaled_sources) sources_created = {row["source"]: row["created"].isoformat() for row in dh.execute_query(query)} @@ -253,7 +257,9 @@ def __init__( async def _run_monitor(self) -> None: after_redis_event_id = REDIS_STREAM_END_IDENTIFIER - logger.info(f"event stream {self.host}: sending entries from global serial {self.after_global_serial}") + logger.info( + f"event stream {self.host}: sending entries from global serial {self.after_global_serial}" + ) await self._send_new_journal_entries() logger.debug(f"event stream {self.host}: initial send complete, waiting for new events") while True: @@ -300,7 +306,9 @@ async def _send_new_journal_entries(self): ) self.after_global_serial = max([entry["serial_global"], self.after_global_serial]) - logger.debug(f"event stream {self.host}: sent new changes up to global serial {self.after_global_serial}") + logger.debug( + f"event stream {self.host}: sent new changes up to global serial {self.after_global_serial}" + ) async def close(self): if self.streaming_task.done(): diff --git a/irrd/server/http/server.py b/irrd/server/http/server.py index 9247ec043..a3775f993 100644 --- a/irrd/server/http/server.py +++ b/irrd/server/http/server.py @@ -13,23 +13,22 @@ sys.path.append(str(Path(__file__).resolve().parents[3])) from irrd import __version__ -from irrd.conf import get_setting, get_configuration - +from irrd.conf import get_configuration, get_setting from irrd.server.graphql import ENV_UVICORN_WORKER_CONFIG_PATH def run_http_server(config_path: str): - setproctitle('irrd-http-server-manager') + setproctitle("irrd-http-server-manager") configuration = get_configuration() assert configuration os.environ[ENV_UVICORN_WORKER_CONFIG_PATH] = config_path uvicorn.run( app="irrd.server.http.app:app", - host=get_setting('server.http.interface'), - port=get_setting('server.http.port'), - workers=get_setting('server.http.workers'), - forwarded_allow_ips=get_setting('server.http.forwarded_allowed_ips'), - headers=[('Server', f'IRRd {__version__}')], + host=get_setting("server.http.interface"), + port=get_setting("server.http.port"), + workers=get_setting("server.http.workers"), + 
forwarded_allow_ips=get_setting("server.http.forwarded_allowed_ips"), + headers=[("Server", f"IRRd {__version__}")], log_config=configuration.logging_config, ws_ping_interval=60, ws_ping_timeout=60, diff --git a/irrd/server/http/status_generator.py b/irrd/server/http/status_generator.py index 2a34fad65..fa2c100bf 100644 --- a/irrd/server/http/status_generator.py +++ b/irrd/server/http/status_generator.py @@ -16,7 +16,6 @@ class StatusGenerator: - def generate_status(self) -> str: """ Generate a human-readable overview of database status. @@ -31,44 +30,46 @@ def generate_status(self) -> str: results = [ self._generate_header(), self._generate_statistics_table(), - self._generate_source_detail(database_handler) + self._generate_source_detail(database_handler), ] database_handler.close() - return '\n\n'.join(results) + return "\n\n".join(results) def _generate_header(self) -> str: """ Generate the header of the report, containing basic info like version and time until the next mirror update. """ - return textwrap.dedent(f""" + return textwrap.dedent( + f""" IRRD version {__version__} Listening on {get_setting('server.whois.interface')} port {get_setting('server.whois.port')} - """).lstrip() + """ + ).lstrip() def _generate_statistics_table(self) -> str: """ Generate a table with an overview of basic stats for each database. """ table = BeautifulTable(default_alignment=BeautifulTable.ALIGN_RIGHT) - table.column_headers = ['source', 'total obj', 'rt obj', 'aut-num obj', 'serial', 'last export'] - table.column_alignments['source'] = BeautifulTable.ALIGN_LEFT - table.left_border_char = table.right_border_char = '' - table.right_border_char = table.bottom_border_char = '' - table.row_separator_char = '' - table.column_separator_char = ' ' + table.column_headers = ["source", "total obj", "rt obj", "aut-num obj", "serial", "last export"] + table.column_alignments["source"] = BeautifulTable.ALIGN_LEFT + table.left_border_char = table.right_border_char = "" + table.right_border_char = table.bottom_border_char = "" + table.row_separator_char = "" + table.column_separator_char = " " for status_result in self.status_results: - source = status_result['source'].upper() + source = status_result["source"].upper() total_obj, route_obj, autnum_obj = self._statistics_for_source(source) - serial = status_result['serial_newest_seen'] - last_export = status_result['serial_last_export'] + serial = status_result["serial_newest_seen"] + last_export = status_result["serial_last_export"] if not last_export: - last_export = '' + last_export = "" table.append_row([source, total_obj, route_obj, autnum_obj, serial, last_export]) total_obj, route_obj, autnum_obj = self._statistics_for_source(None) - table.append_row(['TOTAL', total_obj, route_obj, autnum_obj, '', '']) + table.append_row(["TOTAL", total_obj, route_obj, autnum_obj, "", ""]) return str(table) @@ -79,13 +80,13 @@ def _statistics_for_source(self, source: Optional[str]): If source is None, all sources are counted. 
""" if source: - source_statistics = [s for s in self.statistics_results if s['source'] == source] + source_statistics = [s for s in self.statistics_results if s["source"] == source] else: source_statistics = self.statistics_results - total_obj = sum([s['count'] for s in source_statistics]) - route_obj = sum([s['count'] for s in source_statistics if s['object_class'] == 'route']) - autnum_obj = sum([s['count'] for s in source_statistics if s['object_class'] == 'aut-num']) + total_obj = sum([s["count"] for s in source_statistics]) + route_obj = sum([s["count"] for s in source_statistics if s["object_class"] == "route"]) + autnum_obj = sum([s["count"] for s in source_statistics if s["object_class"] == "aut-num"]) return total_obj, route_obj, autnum_obj def _generate_source_detail(self, database_handler: DatabaseHandler) -> str: @@ -97,26 +98,31 @@ def _generate_source_detail(self, database_handler: DatabaseHandler) -> str: queried by _generate_remote_status_info(). :param database_handler: """ - result_txt = '' + result_txt = "" for status_result in self.status_results: - source = status_result['source'].upper() - keep_journal = 'Yes' if get_setting(f'sources.{source}.keep_journal') else 'No' - authoritative = 'Yes' if get_setting(f'sources.{source}.authoritative') else 'No' - object_class_filter = get_setting(f'sources.{source}.object_class_filter') - rpki_enabled = get_setting('rpki.roa_source') and not get_setting(f'sources.{source}.rpki_excluded') - rpki_enabled_str = 'Yes' if rpki_enabled else 'No' - scopefilter_enabled = get_setting('scopefilter') and not get_setting(f'sources.{source}.scopefilter_excluded') - scopefilter_enabled_str = 'Yes' if scopefilter_enabled else 'No' - synchronised_serials_str = 'Yes' if is_serial_synchronised(database_handler, source) else 'No' - route_object_preference = get_setting(f'sources.{source}.route_object_preference') - - nrtm_host = get_setting(f'sources.{source}.nrtm_host') - nrtm_port = int(get_setting(f'sources.{source}.nrtm_port', DEFAULT_SOURCE_NRTM_PORT)) + source = status_result["source"].upper() + keep_journal = "Yes" if get_setting(f"sources.{source}.keep_journal") else "No" + authoritative = "Yes" if get_setting(f"sources.{source}.authoritative") else "No" + object_class_filter = get_setting(f"sources.{source}.object_class_filter") + rpki_enabled = get_setting("rpki.roa_source") and not get_setting( + f"sources.{source}.rpki_excluded" + ) + rpki_enabled_str = "Yes" if rpki_enabled else "No" + scopefilter_enabled = get_setting("scopefilter") and not get_setting( + f"sources.{source}.scopefilter_excluded" + ) + scopefilter_enabled_str = "Yes" if scopefilter_enabled else "No" + synchronised_serials_str = "Yes" if is_serial_synchronised(database_handler, source) else "No" + route_object_preference = get_setting(f"sources.{source}.route_object_preference") + + nrtm_host = get_setting(f"sources.{source}.nrtm_host") + nrtm_port = int(get_setting(f"sources.{source}.nrtm_port", DEFAULT_SOURCE_NRTM_PORT)) remote_information = self._generate_remote_status_info(nrtm_host, nrtm_port, source) - remote_information = textwrap.indent(remote_information, ' ' * 16) + remote_information = textwrap.indent(remote_information, " " * 16) - result_txt += textwrap.dedent(f""" + result_txt += textwrap.dedent( + f""" Status for {source} ------------------- Local information: @@ -137,7 +143,8 @@ def _generate_source_detail(self, database_handler: DatabaseHandler) -> str: Route object preference: {route_object_preference} Remote information:{remote_information} - """) + 
""" + ) return result_txt def _generate_remote_status_info(self, nrtm_host: Optional[str], nrtm_port: int, source: str) -> str: @@ -152,26 +159,34 @@ def _generate_remote_status_info(self, nrtm_host: Optional[str], nrtm_port: int, try: source_status = whois_query_source_status(nrtm_host, nrtm_port, source) mirrorable, mirror_serial_oldest, mirror_serial_newest, mirror_export_serial = source_status - mirrorable_str = 'Yes' if mirrorable else 'No' + mirrorable_str = "Yes" if mirrorable else "No" - return textwrap.dedent(f""" + return textwrap.dedent( + f""" NRTM host: {nrtm_host} port {nrtm_port} Mirrorable: {mirrorable_str} Oldest journal serial number: {mirror_serial_oldest} Newest journal serial number: {mirror_serial_newest} Last export at serial number: {mirror_export_serial} - """) + """ + ) except ValueError: - return textwrap.dedent(f""" + return textwrap.dedent( + f""" NRTM host: {nrtm_host} port {nrtm_port} Remote status query unsupported or query failed - """) + """ + ) except (socket.timeout, ConnectionError): - return textwrap.dedent(f""" + return textwrap.dedent( + f""" NRTM host: {nrtm_host} port {nrtm_port} Unable to reach remote server for status query - """) + """ + ) else: - return textwrap.dedent(""" + return textwrap.dedent( + """ No NRTM host configured. - """) + """ + ) diff --git a/irrd/server/http/tests/test_endpoints.py b/irrd/server/http/tests/test_endpoints.py index 987fd79b7..a22f64d8f 100644 --- a/irrd/server/http/tests/test_endpoints.py +++ b/irrd/server/http/tests/test_endpoints.py @@ -8,109 +8,128 @@ from irrd.storage.preload import Preloader from irrd.updates.handler import ChangeSubmissionHandler from irrd.utils.validators import RPSLChangeSubmission, RPSLSuspensionSubmission + +from ...whois.query_parser import WhoisQueryParser +from ...whois.query_response import ( + WhoisQueryResponse, + WhoisQueryResponseMode, + WhoisQueryResponseType, +) from ..app import app from ..endpoints import StatusEndpoint, WhoisQueryEndpoint from ..status_generator import StatusGenerator -from ...whois.query_parser import WhoisQueryParser -from ...whois.query_response import WhoisQueryResponse, WhoisQueryResponseType, \ - WhoisQueryResponseMode class TestStatusEndpoint: def setup_method(self): - self.mock_request = HTTPConnection({ - 'type': 'http', - 'client': ('127.0.0.1', '8000'), - }) + self.mock_request = HTTPConnection( + { + "type": "http", + "client": ("127.0.0.1", "8000"), + } + ) self.endpoint = StatusEndpoint(scope=self.mock_request, receive=None, send=None) def test_status_no_access_list(self): response = self.endpoint.get(self.mock_request) assert response.status_code == 403 - assert response.body == b'Access denied' + assert response.body == b"Access denied" def test_status_access_list_permitted(self, config_override, monkeypatch): - config_override({ - 'server': { - 'http': { - 'status_access_list': 'test_access_list', - } - }, - 'access_lists': { - 'test_access_list': { - '127.0.0.0/25', - } - }, - }) + config_override( + { + "server": { + "http": { + "status_access_list": "test_access_list", + } + }, + "access_lists": { + "test_access_list": { + "127.0.0.0/25", + } + }, + } + ) mock_database_status_generator = Mock(spec=StatusGenerator) - monkeypatch.setattr('irrd.server.http.endpoints.StatusGenerator', - lambda: mock_database_status_generator) - mock_database_status_generator.generate_status = lambda: 'status' + monkeypatch.setattr( + "irrd.server.http.endpoints.StatusGenerator", lambda: mock_database_status_generator + ) + 
mock_database_status_generator.generate_status = lambda: "status" response = self.endpoint.get(self.mock_request) assert response.status_code == 200 - assert response.body == b'status' + assert response.body == b"status" def test_status_access_list_denied(self, config_override): - config_override({ - 'server': { - 'http': { - 'status_access_list': 'test_access_list', - } - }, - 'access_lists': { - 'test_access_list': { - '192.0.2.0/25', - } - }, - }) + config_override( + { + "server": { + "http": { + "status_access_list": "test_access_list", + } + }, + "access_lists": { + "test_access_list": { + "192.0.2.0/25", + } + }, + } + ) response = self.endpoint.get(self.mock_request) assert response.status_code == 403 - assert response.body == b'Access denied' + assert response.body == b"Access denied" class TestWhoisQueryEndpoint: def test_query_endpoint(self, monkeypatch): mock_query_parser = Mock(spec=WhoisQueryParser) - monkeypatch.setattr('irrd.server.http.endpoints.WhoisQueryParser', - lambda client_ip, client_str, preloader, database_handler: mock_query_parser) - app = Mock(state=Mock( - database_handler=Mock(spec=DatabaseHandler), - preloader=Mock(spec=Preloader), - )) - mock_request = HTTPConnection({ - 'type': 'http', - 'client': ('127.0.0.1', '8000'), - 'app': app, - 'query_string': '', - }) + monkeypatch.setattr( + "irrd.server.http.endpoints.WhoisQueryParser", + lambda client_ip, client_str, preloader, database_handler: mock_query_parser, + ) + app = Mock( + state=Mock( + database_handler=Mock(spec=DatabaseHandler), + preloader=Mock(spec=Preloader), + ) + ) + mock_request = HTTPConnection( + { + "type": "http", + "client": ("127.0.0.1", "8000"), + "app": app, + "query_string": "", + } + ) endpoint = WhoisQueryEndpoint(scope=mock_request, receive=None, send=None) result = endpoint.get(mock_request) assert result.status_code == 400 - assert result.body.startswith(b'Missing required query') - - mock_request = HTTPConnection({ - 'type': 'http', - 'client': ('127.0.0.1', '8000'), - 'app': app, - 'query_string': 'q=query', - }) + assert result.body.startswith(b"Missing required query") + + mock_request = HTTPConnection( + { + "type": "http", + "client": ("127.0.0.1", "8000"), + "app": app, + "query_string": "q=query", + } + ) mock_query_parser.handle_query = lambda query: WhoisQueryResponse( response_type=WhoisQueryResponseType.SUCCESS, mode=WhoisQueryResponseMode.IRRD, # irrelevant - result=f'result {query} 🦄' + result=f"result {query} 🦄", ) result = endpoint.get(mock_request) assert result.status_code == 200 - assert result.body.decode('utf-8') == 'result query 🦄' + assert result.body.decode("utf-8") == "result query 🦄" mock_query_parser.handle_query = lambda query: WhoisQueryResponse( response_type=WhoisQueryResponseType.KEY_NOT_FOUND, mode=WhoisQueryResponseMode.IRRD, # irrelevant - result='', + result="", ) result = endpoint.get(mock_request) assert result.status_code == 204 @@ -119,74 +138,77 @@ def test_query_endpoint(self, monkeypatch): mock_query_parser.handle_query = lambda query: WhoisQueryResponse( response_type=WhoisQueryResponseType.ERROR_USER, mode=WhoisQueryResponseMode.IRRD, # irrelevant - result=f'result {query} 🦄' + result=f"result {query} 🦄", ) result = endpoint.get(mock_request) assert result.status_code == 400 - assert result.body.decode('utf-8') == 'result query 🦄' + assert result.body.decode("utf-8") == "result query 🦄" mock_query_parser.handle_query = lambda query: WhoisQueryResponse( response_type=WhoisQueryResponseType.ERROR_INTERNAL, 
mode=WhoisQueryResponseMode.IRRD, # irrelevant - result=f'result {query} 🦄' + result=f"result {query} 🦄", ) result = endpoint.get(mock_request) assert result.status_code == 500 - assert result.body.decode('utf-8') == 'result query 🦄' + assert result.body.decode("utf-8") == "result query 🦄" class TestObjectSubmissionEndpoint: def test_endpoint(self, monkeypatch): mock_handler = Mock(spec=ChangeSubmissionHandler) - monkeypatch.setattr('irrd.server.http.endpoints.ChangeSubmissionHandler', - lambda: mock_handler) - mock_handler.submitter_report_json = lambda: {'response': True} + monkeypatch.setattr("irrd.server.http.endpoints.ChangeSubmissionHandler", lambda: mock_handler) + mock_handler.submitter_report_json = lambda: {"response": True} client = TestClient(app) data = { - 'objects': [ - {'attributes': [ - {'name': 'person', 'value': 'Placeholder Person Object'}, - {'name': 'nic-hdl', 'value': 'PERSON-TEST'}, - {'name': 'changed', 'value': 'changed@example.com 20190701 # comment'}, - {'name': 'source', 'value': 'TEST'}, - ]}, + "objects": [ + { + "attributes": [ + {"name": "person", "value": "Placeholder Person Object"}, + {"name": "nic-hdl", "value": "PERSON-TEST"}, + {"name": "changed", "value": "changed@example.com 20190701 # comment"}, + {"name": "source", "value": "TEST"}, + ] + }, ], - 'passwords': ['invalid1', 'invalid2'], + "passwords": ["invalid1", "invalid2"], } expected_data = RPSLChangeSubmission.parse_obj(data) - response_post = client.post('/v1/submit/', data=ujson.dumps(data), headers={'X-irrd-metadata': '{"meta": 2}'}) + response_post = client.post( + "/v1/submit/", data=ujson.dumps(data), headers={"X-irrd-metadata": '{"meta": 2}'} + ) assert response_post.status_code == 200 assert response_post.text == '{"response":true}' mock_handler.load_change_submission.assert_called_once_with( data=expected_data, delete=False, - request_meta={'HTTP-client-IP': 'testclient', 'HTTP-User-Agent': 'testclient', 'meta': 2}, + request_meta={"HTTP-client-IP": "testclient", "HTTP-User-Agent": "testclient", "meta": 2}, ) mock_handler.send_notification_target_reports.assert_called_once() mock_handler.reset_mock() - response_delete = client.delete('/v1/submit/', data=ujson.dumps(data)) + response_delete = client.delete("/v1/submit/", data=ujson.dumps(data)) assert response_delete.status_code == 200 assert response_delete.text == '{"response":true}' mock_handler.load_change_submission.assert_called_once_with( data=expected_data, delete=True, - request_meta={'HTTP-client-IP': 'testclient', 'HTTP-User-Agent': 'testclient'}, + request_meta={"HTTP-client-IP": "testclient", "HTTP-User-Agent": "testclient"}, ) mock_handler.send_notification_target_reports.assert_called_once() mock_handler.reset_mock() - response_invalid_format = client.post('/v1/submit/', data='{"invalid": true}') + response_invalid_format = client.post("/v1/submit/", data='{"invalid": true}') assert response_invalid_format.status_code == 400 - assert 'field required' in response_invalid_format.text + assert "field required" in response_invalid_format.text mock_handler.load_change_submission.assert_not_called() mock_handler.send_notification_target_reports.assert_not_called() - response_invalid_json = client.post('/v1/submit/', data='invalid') + response_invalid_json = client.post("/v1/submit/", data="invalid") assert response_invalid_json.status_code == 400 - assert 'expect' in response_invalid_json.text.lower() + assert "expect" in response_invalid_json.text.lower() mock_handler.load_change_submission.assert_not_called() 
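
As an aside, the payloads this endpoint test constructs double as documentation for external clients of the HTTP submission API. A hedged client sketch using only the stdlib; the host and port are assumptions, while the payload layout and the X-irrd-metadata header follow the test data above:

    import json
    from urllib.request import Request, urlopen

    payload = {
        "objects": [
            {
                "attributes": [
                    {"name": "person", "value": "Placeholder Person Object"},
                    {"name": "nic-hdl", "value": "PERSON-TEST"},
                    {"name": "source", "value": "TEST"},
                ]
            },
        ],
        "passwords": ["example-password"],  # placeholder credential
    }
    request = Request(
        "http://localhost:8080/v1/submit/",  # host/port are an assumption
        data=json.dumps(payload).encode("utf-8"),
        headers={"X-irrd-metadata": json.dumps({"ticket": 42})},
        method="POST",
    )
    with urlopen(request) as response:
        print(json.loads(response.read()))  # the submitter report JSON
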
mock_handler.send_notification_target_reports.assert_not_called() @@ -194,32 +216,29 @@ def test_endpoint(self, monkeypatch): class TestSuspensionSubmissionEndpoint: def test_endpoint(self, monkeypatch): mock_handler = Mock(spec=ChangeSubmissionHandler) - monkeypatch.setattr('irrd.server.http.endpoints.ChangeSubmissionHandler', - lambda: mock_handler) - mock_handler.submitter_report_json = lambda: {'response': True} + monkeypatch.setattr("irrd.server.http.endpoints.ChangeSubmissionHandler", lambda: mock_handler) + mock_handler.submitter_report_json = lambda: {"response": True} client = TestClient(app) data = { - "objects": [ - {"mntner": "DASHCARE-MNT", "source": "DASHCARE", "request_type": "reactivate"} - ], + "objects": [{"mntner": "DASHCARE-MNT", "source": "DASHCARE", "request_type": "reactivate"}], "override": "<>", } expected_data = RPSLSuspensionSubmission.parse_obj(data) - response_post = client.post('/v1/suspension/', data=ujson.dumps(data)) + response_post = client.post("/v1/suspension/", data=ujson.dumps(data)) assert response_post.status_code == 200 assert response_post.text == '{"response":true}' mock_handler.load_suspension_submission.assert_called_once_with( data=expected_data, - request_meta={'HTTP-client-IP': 'testclient', 'HTTP-User-Agent': 'testclient'}, + request_meta={"HTTP-client-IP": "testclient", "HTTP-User-Agent": "testclient"}, ) mock_handler.reset_mock() - response_invalid_format = client.post('/v1/suspension/', data='{"invalid": true}') + response_invalid_format = client.post("/v1/suspension/", data='{"invalid": true}') assert response_invalid_format.status_code == 400 - assert 'field required' in response_invalid_format.text + assert "field required" in response_invalid_format.text - response_invalid_json = client.post('/v1/suspension/', data='invalid') + response_invalid_json = client.post("/v1/suspension/", data="invalid") assert response_invalid_json.status_code == 400 - assert 'expect' in response_invalid_json.text.lower() + assert "expect" in response_invalid_json.text.lower() diff --git a/irrd/server/http/tests/test_event_stream.py b/irrd/server/http/tests/test_event_stream.py index a6d7956ea..7d7c4e361 100644 --- a/irrd/server/http/tests/test_event_stream.py +++ b/irrd/server/http/tests/test_event_stream.py @@ -18,13 +18,14 @@ from irrd.scopefilter.status import ScopeFilterStatus from irrd.storage.event_stream import OPERATION_JOURNAL_EXTENDED from irrd.storage.queries import ( - RPSLDatabaseJournalStatisticsQuery, RPSLDatabaseJournalQuery, + RPSLDatabaseJournalStatisticsQuery, RPSLDatabaseQuery, ) from irrd.utils.rpsl_samples import SAMPLE_MNTNER from irrd.utils.test_utils import MockDatabaseHandler from irrd.vendor import postgres_copy + from ..app import app from ..event_stream import AsyncEventStreamFollower @@ -75,7 +76,9 @@ def mock_copy_to_side_effect(source, dest, engine_or_conn, format): monkeypatch.setattr("irrd.server.http.event_stream.postgres_copy.copy_to", mock_copy_to) client = TestClient(app) - response = client.get("/v1/event-stream/initial/", params={"sources": "TEST", "object_classes": "mntner"}) + response = client.get( + "/v1/event-stream/initial/", params={"sources": "TEST", "object_classes": "mntner"} + ) assert response.status_code == 200 header, rpsl_obj = (json.loads(line) for line in response.text.splitlines()) @@ -150,7 +153,9 @@ async def test_endpoint(self, monkeypatch, config_override): mock_dh = MockDatabaseHandler() mock_dh.reset_mock() monkeypatch.setattr("irrd.server.http.event_stream.DatabaseHandler", 
MockDatabaseHandler) - monkeypatch.setattr("irrd.server.http.event_stream.AsyncEventStreamFollower", mock_event_stream_follower) + monkeypatch.setattr( + "irrd.server.http.event_stream.AsyncEventStreamFollower", mock_event_stream_follower + ) client = TestClient(app) with client.websocket_connect("/v1/event-stream/") as websocket: @@ -229,8 +234,12 @@ async def message_callback(message): assert mock_dh.readonly assert mock_dh.closed assert mock_dh.queries[0] == RPSLDatabaseJournalStatisticsQuery() - assert mock_dh.queries[1] == RPSLDatabaseJournalQuery().serial_global_range(expected_serial_starts.pop(0)) - assert mock_dh.queries[2] == RPSLDatabaseJournalQuery().serial_global_range(expected_serial_starts.pop(0)) + assert mock_dh.queries[1] == RPSLDatabaseJournalQuery().serial_global_range( + expected_serial_starts.pop(0) + ) + assert mock_dh.queries[2] == RPSLDatabaseJournalQuery().serial_global_range( + expected_serial_starts.pop(0) + ) msg_journal1, event_journal_extended = messages diff --git a/irrd/server/http/tests/test_status_generator.py b/irrd/server/http/tests/test_status_generator.py index 023f5ce0a..49de72046 100644 --- a/irrd/server/http/tests/test_status_generator.py +++ b/irrd/server/http/tests/test_status_generator.py @@ -1,5 +1,3 @@ -# flake8: noqa: W291,W293 - import socket import textwrap from datetime import datetime, timezone @@ -7,126 +5,136 @@ from irrd import __version__ from irrd.conf import get_setting + from ..status_generator import StatusGenerator class TestStatusGenerator: - def test_request(self, monkeypatch, config_override): mock_database_handler = Mock() - monkeypatch.setattr('irrd.server.http.status_generator.DatabaseHandler', lambda: mock_database_handler) + monkeypatch.setattr( + "irrd.server.http.status_generator.DatabaseHandler", lambda: mock_database_handler + ) mock_status_query = Mock() - monkeypatch.setattr('irrd.server.http.status_generator.DatabaseStatusQuery', lambda: mock_status_query) - monkeypatch.setattr('irrd.server.http.status_generator.is_serial_synchronised', - lambda dh, source: False) + monkeypatch.setattr( + "irrd.server.http.status_generator.DatabaseStatusQuery", lambda: mock_status_query + ) + monkeypatch.setattr( + "irrd.server.http.status_generator.is_serial_synchronised", lambda dh, source: False + ) mock_statistics_query = Mock() - monkeypatch.setattr('irrd.server.http.status_generator.RPSLDatabaseObjectStatisticsQuery', - lambda: mock_statistics_query) + monkeypatch.setattr( + "irrd.server.http.status_generator.RPSLDatabaseObjectStatisticsQuery", + lambda: mock_statistics_query, + ) def mock_whois_query(nrtm_host, nrtm_port, source): - assert source in ['TEST1', 'TEST2', 'TEST3'] - if source == 'TEST1': - assert nrtm_host == 'nrtm1.example.com' + assert source in ["TEST1", "TEST2", "TEST3"] + if source == "TEST1": + assert nrtm_host == "nrtm1.example.com" assert nrtm_port == 43 return True, 142, 143, 144 - elif source == 'TEST2': + elif source == "TEST2": raise ValueError() - elif source == 'TEST3': + elif source == "TEST3": raise socket.timeout() - monkeypatch.setattr('irrd.server.http.status_generator.whois_query_source_status', mock_whois_query) + monkeypatch.setattr("irrd.server.http.status_generator.whois_query_source_status", mock_whois_query) - config_override({ - 'sources': { - 'rpki': { - 'roa_source': 'roa source' - }, - 'TEST1': { - 'authoritative': False, - 'keep_journal': True, - 'nrtm_host': 'nrtm1.example.com', - 'nrtm_port': 43, - 'object_class_filter': 'object-class-filter', - 'rpki_excluded': True, - 
'route_object_preference': 200, - }, - 'TEST2': { - 'authoritative': True, - 'keep_journal': False, - 'nrtm_host': 'nrtm2.example.com', - 'nrtm_port': 44, - }, - 'TEST3': { - 'authoritative': True, - 'keep_journal': False, - 'nrtm_host': 'nrtm3.example.com', - 'nrtm_port': 45, - }, - 'TEST4': { - 'authoritative': False, - 'keep_journal': False, - }, + config_override( + { + "sources": { + "rpki": {"roa_source": "roa source"}, + "TEST1": { + "authoritative": False, + "keep_journal": True, + "nrtm_host": "nrtm1.example.com", + "nrtm_port": 43, + "object_class_filter": "object-class-filter", + "rpki_excluded": True, + "route_object_preference": 200, + }, + "TEST2": { + "authoritative": True, + "keep_journal": False, + "nrtm_host": "nrtm2.example.com", + "nrtm_port": 44, + }, + "TEST3": { + "authoritative": True, + "keep_journal": False, + "nrtm_host": "nrtm3.example.com", + "nrtm_port": 45, + }, + "TEST4": { + "authoritative": False, + "keep_journal": False, + }, + } } - }) + ) - mock_query_result = iter([ - [ - {'source': 'TEST1', 'object_class': 'route', 'count': 10}, - {'source': 'TEST1', 'object_class': 'aut-num', 'count': 10}, - {'source': 'TEST1', 'object_class': 'other', 'count': 5}, - {'source': 'TEST2', 'object_class': 'route', 'count': 42}, - ], + mock_query_result = iter( [ - { - 'source': 'TEST1', - 'serial_oldest_seen': 10, - 'serial_newest_seen': 21, - 'serial_oldest_journal': 15, - 'serial_newest_journal': 20, - 'serial_last_export': 16, - 'serial_newest_mirror': 25, - 'last_error_timestamp': datetime(2018, 1, 1, tzinfo=timezone.utc), - 'updated': datetime(2018, 6, 1, tzinfo=timezone.utc), - }, - { - 'source': 'TEST2', - 'serial_oldest_seen': 210, - 'serial_newest_seen': 221, - 'serial_oldest_journal': None, - 'serial_newest_journal': None, - 'serial_last_export': None, - 'serial_newest_mirror': None, - 'last_error_timestamp': datetime(2019, 1, 1, tzinfo=timezone.utc), - 'updated': datetime(2019, 6, 1, tzinfo=timezone.utc), - }, - { - 'source': 'TEST3', - 'serial_oldest_seen': None, - 'serial_newest_seen': None, - 'serial_oldest_journal': None, - 'serial_newest_journal': None, - 'serial_last_export': None, - 'serial_newest_mirror': None, - 'last_error_timestamp': None, - 'updated': None, - }, - { - 'source': 'TEST4', - 'serial_oldest_seen': None, - 'serial_newest_seen': None, - 'serial_oldest_journal': None, - 'serial_newest_journal': None, - 'serial_last_export': None, - 'serial_newest_mirror': None, - 'last_error_timestamp': None, - 'updated': None, - }, - ], - ]) + [ + {"source": "TEST1", "object_class": "route", "count": 10}, + {"source": "TEST1", "object_class": "aut-num", "count": 10}, + {"source": "TEST1", "object_class": "other", "count": 5}, + {"source": "TEST2", "object_class": "route", "count": 42}, + ], + [ + { + "source": "TEST1", + "serial_oldest_seen": 10, + "serial_newest_seen": 21, + "serial_oldest_journal": 15, + "serial_newest_journal": 20, + "serial_last_export": 16, + "serial_newest_mirror": 25, + "last_error_timestamp": datetime(2018, 1, 1, tzinfo=timezone.utc), + "updated": datetime(2018, 6, 1, tzinfo=timezone.utc), + }, + { + "source": "TEST2", + "serial_oldest_seen": 210, + "serial_newest_seen": 221, + "serial_oldest_journal": None, + "serial_newest_journal": None, + "serial_last_export": None, + "serial_newest_mirror": None, + "last_error_timestamp": datetime(2019, 1, 1, tzinfo=timezone.utc), + "updated": datetime(2019, 6, 1, tzinfo=timezone.utc), + }, + { + "source": "TEST3", + "serial_oldest_seen": None, + "serial_newest_seen": None, + 
"serial_oldest_journal": None, + "serial_newest_journal": None, + "serial_last_export": None, + "serial_newest_mirror": None, + "last_error_timestamp": None, + "updated": None, + }, + { + "source": "TEST4", + "serial_oldest_seen": None, + "serial_newest_seen": None, + "serial_oldest_journal": None, + "serial_newest_journal": None, + "serial_last_export": None, + "serial_newest_mirror": None, + "last_error_timestamp": None, + "updated": None, + }, + ], + ] + ) mock_database_handler.execute_query = lambda query, flush_rpsl_buffer=True: next(mock_query_result) status_report = StatusGenerator().generate_status() - expected_report = textwrap.dedent(f""" + expected_report = textwrap.dedent( + f""" IRRD version {__version__} Listening on ::0 port {get_setting('server.whois.port')} @@ -236,6 +244,7 @@ def mock_whois_query(nrtm_host, nrtm_port, source): Route object preference: None Remote information: - No NRTM host configured.\n\n""").lstrip() + No NRTM host configured.\n\n""" + ).lstrip() assert expected_report == status_report diff --git a/irrd/server/query_resolver.py b/irrd/server/query_resolver.py index 99fe78f40..5351cd562 100644 --- a/irrd/server/query_resolver.py +++ b/irrd/server/query_resolver.py @@ -1,20 +1,23 @@ import logging from collections import OrderedDict from enum import Enum -from typing import Optional, List, Set, Tuple, Any, Dict +from typing import Any, Dict, List, Optional, Set, Tuple from IPy import IP from pytz import timezone -from irrd.conf import get_setting, RPKI_IRR_PSEUDO_SOURCE +from irrd.conf import RPKI_IRR_PSEUDO_SOURCE, get_setting from irrd.routepref.status import RoutePreferenceStatus from irrd.rpki.status import RPKIStatus -from irrd.rpsl.rpsl_objects import (OBJECT_CLASS_MAPPING, lookup_field_names) +from irrd.rpsl.rpsl_objects import OBJECT_CLASS_MAPPING, lookup_field_names from irrd.scopefilter.status import ScopeFilterStatus -from irrd.storage.database_handler import DatabaseHandler, is_serial_synchronised, \ - RPSLDatabaseResponse +from irrd.storage.database_handler import ( + DatabaseHandler, + RPSLDatabaseResponse, + is_serial_synchronised, +) from irrd.storage.preload import Preloader -from irrd.storage.queries import RPSLDatabaseQuery, DatabaseStatusQuery +from irrd.storage.queries import DatabaseStatusQuery, RPSLDatabaseQuery from irrd.utils.validators import parse_as_number logger = logging.getLogger(__name__) @@ -25,10 +28,10 @@ class InvalidQueryException(ValueError): class RouteLookupType(Enum): - EXACT = 'EXACT' - LESS_SPECIFIC_ONE_LEVEL = 'LESS_SPECIFIC_ONE_LEVEL' - LESS_SPECIFIC_WITH_EXACT = 'LESS_SPECIFIC_WITH_EXACT' - MORE_SPECIFIC_WITHOUT_EXACT = 'MORE_SPECIFIC_WITHOUT_EXACT' + EXACT = "EXACT" + LESS_SPECIFIC_ONE_LEVEL = "LESS_SPECIFIC_ONE_LEVEL" + LESS_SPECIFIC_WITH_EXACT = "LESS_SPECIFIC_WITH_EXACT" + MORE_SPECIFIC_WITHOUT_EXACT = "MORE_SPECIFIC_WITHOUT_EXACT" class QueryResolver: @@ -38,18 +41,19 @@ class QueryResolver: Some aspects like setting sources retain state, so a single instance should not be shared across unrelated query sessions. 
""" + lookup_field_names = lookup_field_names() database_handler: DatabaseHandler _current_set_root_object_class: Optional[str] def __init__(self, preloader: Preloader, database_handler: DatabaseHandler) -> None: - self.all_valid_sources = list(get_setting('sources', {}).keys()) - self.sources_default = list(get_setting('sources_default', [])) + self.all_valid_sources = list(get_setting("sources", {}).keys()) + self.sources_default = list(get_setting("sources_default", [])) self.sources: List[str] = self.sources_default if self.sources_default else self.all_valid_sources - if get_setting('rpki.roa_source'): + if get_setting("rpki.roa_source"): self.all_valid_sources.append(RPKI_IRR_PSEUDO_SOURCE) self.object_class_filter: List[str] = [] - self.rpki_aware = bool(get_setting('rpki.roa_source')) + self.rpki_aware = bool(get_setting("rpki.roa_source")) self.rpki_invalid_filter_enabled = self.rpki_aware self.out_scope_filter_enabled = True self.route_preference_filter_enabled = True @@ -64,7 +68,7 @@ def set_query_sources(self, sources: Optional[List[str]]) -> None: if sources is None: sources = self.sources_default if self.sources_default else self.all_valid_sources elif not all([source in self.all_valid_sources for source in sources]): - raise InvalidQueryException('One or more selected sources are unavailable.') + raise InvalidQueryException("One or more selected sources are unavailable.") self.sources = sources def disable_rpki_filter(self) -> None: @@ -91,7 +95,7 @@ def rpsl_text_search(self, value: str) -> RPSLDatabaseResponse: def route_search(self, address: IP, lookup_type: RouteLookupType): """Route(6) object search for an address, supporting exact/less/more specific.""" - query = self._prepare_query(ordered_by_sources=False).object_classes(['route', 'route6']) + query = self._prepare_query(ordered_by_sources=False).object_classes(["route", "route6"]) lookup_queries = { RouteLookupType.EXACT: query.ip_exact, RouteLookupType.LESS_SPECIFIC_ONE_LEVEL: query.ip_less_specific_one_level, @@ -108,14 +112,16 @@ def rpsl_attribute_search(self, attribute: str, value: str) -> RPSLDatabaseRespo as does `!oFOO`. Restricted to designated lookup fields. """ if attribute not in self.lookup_field_names: - readable_lookup_field_names = ', '.join(self.lookup_field_names) - msg = (f'Inverse attribute search not supported for {attribute},' + - f'only supported for attributes: {readable_lookup_field_names}') + readable_lookup_field_names = ", ".join(self.lookup_field_names) + msg = ( + f"Inverse attribute search not supported for {attribute}," + + f"only supported for attributes: {readable_lookup_field_names}" + ) raise InvalidQueryException(msg) query = self._prepare_query(ordered_by_sources=False).lookup_attr(attribute, value) return self._execute_query(query) - def routes_for_origin(self, origin: str, ip_version: Optional[int]=None) -> Set[str]: + def routes_for_origin(self, origin: str, ip_version: Optional[int] = None) -> Set[str]: """ Resolve all route(6)s prefixes for an origin, returning a set of all prefixes. Origin must be in 'ASxxx' format. 
@@ -123,28 +129,32 @@ def routes_for_origin(self, origin: str, ip_version: Optional[int]=None) -> Set[ prefixes = self.preloader.routes_for_origins([origin], self.sources, ip_version=ip_version) return prefixes - def routes_for_as_set(self, set_name: str, ip_version: Optional[int]=None, exclude_sets: Optional[Set[str]]=None) -> Set[str]: + def routes_for_as_set( + self, set_name: str, ip_version: Optional[int] = None, exclude_sets: Optional[Set[str]] = None + ) -> Set[str]: """ Find all originating prefixes for all members of an AS-set. May be restricted to IPv4 or IPv6. Returns a set of all prefixes. """ - self._current_set_root_object_class = 'as-set' + self._current_set_root_object_class = "as-set" self._current_excluded_sets = exclude_sets if exclude_sets else set() self._current_set_maximum_depth = 0 members = self._recursive_set_resolve({set_name}) return self.preloader.routes_for_origins(members, self.sources, ip_version=ip_version) - def members_for_set_per_source(self, parameter: str, exclude_sets: Optional[Set[str]]=None, depth=0, recursive=False) -> Dict[str, List[str]]: + def members_for_set_per_source( + self, parameter: str, exclude_sets: Optional[Set[str]] = None, depth=0, recursive=False + ) -> Dict[str, List[str]]: """ Find all members of an as-set or route-set, possibly recursively, distinguishing between multiple root objects in different sources with the same name. Returns a dict with sources as keys, list of all members, including leaf members, as values. """ - query = self._prepare_query(column_names=['source']) - object_classes = ['as-set', 'route-set'] + query = self._prepare_query(column_names=["source"]) + object_classes = ["as-set", "route-set"] query = query.object_classes(object_classes).rpsl_pk(parameter) - set_sources = [row['source'] for row in self._execute_query(query)] + set_sources = [row["source"] for row in self._execute_query(query)] return { source: self.members_for_set( @@ -157,7 +167,14 @@ def members_for_set_per_source(self, parameter: str, exclude_sets: Optional[Set[ for source in set_sources } - def members_for_set(self, parameter: str, exclude_sets: Optional[Set[str]]=None, depth=0, recursive=False, root_source: Optional[str]=None) -> List[str]: + def members_for_set( + self, + parameter: str, + exclude_sets: Optional[Set[str]] = None, + depth=0, + recursive=False, + root_source: Optional[str] = None, + ) -> List[str]: """ Find all members of an as-set or route-set, possibly recursively. Returns a list of all members, including leaf members. @@ -175,7 +192,7 @@ def members_for_set(self, parameter: str, exclude_sets: Optional[Set[str]]=None, if parameter in members: members.remove(parameter) - if get_setting('compatibility.ipv4_only_route_set_members'): + if get_setting("compatibility.ipv4_only_route_set_members"): original_members = set(members) for member in original_members: try: @@ -191,7 +208,9 @@ def members_for_set(self, parameter: str, exclude_sets: Optional[Set[str]]=None, return sorted(members) - def _recursive_set_resolve(self, members: Set[str], sets_seen=None, root_source: Optional[str]=None) -> Set[str]: + def _recursive_set_resolve( + self, members: Set[str], sets_seen=None, root_source: Optional[str] = None + ) -> Set[str]: """ Resolve all members of a number of sets, recursively. 
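
The sets_seen parameter in the signature above is the recursion's cycle guard: a set that has been expanded once is never queued again, so mutually referencing as-sets terminate. An illustrative stand-in for the pattern, not the IRRd implementation; lookup is a hypothetical callable playing the role of _find_set_members():

    from typing import Callable, Set, Tuple

    def expand_sets(
        roots: Set[str],
        lookup: Callable[[Set[str]], Tuple[Set[str], Set[str]]],
    ) -> Set[str]:
        """Collect all leaf members reachable from the root set names."""
        leaves: Set[str] = set()
        seen: Set[str] = set()
        pending = set(roots)
        while pending:
            seen |= pending
            # lookup() returns (members found, input names that were not sets),
            # mirroring the (members, leaf_members) pair in the code above.
            children, leaf_members = lookup(pending)
            leaves |= leaf_members
            # Never re-expand a name already visited: this is what makes
            # circular set references terminate.
            pending = children - seen
        return leaves
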
@@ -214,9 +233,12 @@ def _recursive_set_resolve(self, members: Set[str], sets_seen=None, root_source: sub_members, leaf_members = self._find_set_members(members, limit_source=root_source) for sub_member in sub_members: - if self._current_set_root_object_class is None or self._current_set_root_object_class == 'route-set': + if ( + self._current_set_root_object_class is None + or self._current_set_root_object_class == "route-set" + ): try: - IP(sub_member.split('^')[0]) + IP(sub_member.split("^")[0]) set_members.add(sub_member) continue except ValueError: @@ -226,9 +248,8 @@ def _recursive_set_resolve(self, members: Set[str], sets_seen=None, root_source: # the prefixes originating from that AS should be added to the response. try: as_number_formatted, _ = parse_as_number(sub_member) - if self._current_set_root_object_class == 'route-set': - set_members.update(self.preloader.routes_for_origins( - [as_number_formatted], self.sources)) + if self._current_set_root_object_class == "route-set": + set_members.update(self.preloader.routes_for_origins([as_number_formatted], self.sources)) resolved_as_members.add(sub_member) else: set_members.add(sub_member) @@ -240,13 +261,17 @@ def _recursive_set_resolve(self, members: Set[str], sets_seen=None, root_source: if self._current_set_maximum_depth == 0: return set_members | sub_members | leaf_members - further_resolving_required = sub_members - set_members - sets_seen - resolved_as_members - self._current_excluded_sets + further_resolving_required = ( + sub_members - set_members - sets_seen - resolved_as_members - self._current_excluded_sets + ) new_members = self._recursive_set_resolve(further_resolving_required, sets_seen) set_members.update(new_members) return set_members - def _find_set_members(self, set_names: Set[str], limit_source: Optional[str]=None) -> Tuple[Set[str], Set[str]]: + def _find_set_members( + self, set_names: Set[str], limit_source: Optional[str] = None + ) -> Tuple[Set[str], Set[str]]: """ Find all members of a number of route-sets or as-sets. Includes both direct members listed in members attribute, but also @@ -263,13 +288,13 @@ def _find_set_members(self, set_names: Set[str], limit_source: Optional[str]=Non members: Set[str] = set() sets_already_resolved: Set[str] = set() - columns = ['parsed_data', 'rpsl_pk', 'source', 'object_class'] + columns = ["parsed_data", "rpsl_pk", "source", "object_class"] query = self._prepare_query(column_names=columns) - object_classes = ['as-set', 'route-set'] + object_classes = ["as-set", "route-set"] # Per RFC 2622 5.3, route-sets can refer to as-sets, # but as-sets can only refer to other as-sets. - if self._current_set_root_object_class == 'as-set': + if self._current_set_root_object_class == "as-set": object_classes = [self._current_set_root_object_class] query = query.object_classes(object_classes).rpsl_pks(set_names) @@ -286,10 +311,10 @@ def _find_set_members(self, set_names: Set[str], limit_source: Optional[str]=Non # on the first run: when the set resolving should be fixed to one # type of set object. 
if not self._current_set_root_object_class: - self._current_set_root_object_class = query_result[0]['object_class'] + self._current_set_root_object_class = query_result[0]["object_class"] for result in query_result: - rpsl_pk = result['rpsl_pk'] + rpsl_pk = result["rpsl_pk"] # The same PK may occur in multiple sources, but we are # only interested in the first matching object, prioritised @@ -299,10 +324,10 @@ def _find_set_members(self, set_names: Set[str], limit_source: Optional[str]=Non continue sets_already_resolved.add(rpsl_pk) - object_class = result['object_class'] - object_data = result['parsed_data'] - mbrs_by_ref = object_data.get('mbrs-by-ref', None) - for members_attr in ['members', 'mp-members']: + object_class = result["object_class"] + object_data = result["parsed_data"] + mbrs_by_ref = object_data.get("mbrs-by-ref", None) + for members_attr in ["members", "mp-members"]: if members_attr in object_data: members.update(set(object_data[members_attr])) @@ -312,23 +337,25 @@ def _find_set_members(self, set_names: Set[str], limit_source: Optional[str]=Non # If mbrs-by-ref is set, find any objects with member-of pointing to the route/as-set # under query, and include a maintainer listed in mbrs-by-ref, unless mbrs-by-ref # is set to ANY. - query_object_class = ['route', 'route6'] if object_class == 'route-set' else ['aut-num'] + query_object_class = ["route", "route6"] if object_class == "route-set" else ["aut-num"] query = self._prepare_query(column_names=columns).object_classes(query_object_class) - query = query.lookup_attrs_in(['member-of'], [rpsl_pk]) + query = query.lookup_attrs_in(["member-of"], [rpsl_pk]) - if 'ANY' not in [m.strip().upper() for m in mbrs_by_ref]: - query = query.lookup_attrs_in(['mnt-by'], mbrs_by_ref) + if "ANY" not in [m.strip().upper() for m in mbrs_by_ref]: + query = query.lookup_attrs_in(["mnt-by"], mbrs_by_ref) referring_objects = self._execute_query(query) for result in referring_objects: - member_object_class = result['object_class'] - members.add(result['parsed_data'][member_object_class]) + member_object_class = result["object_class"] + members.add(result["parsed_data"][member_object_class]) leaf_members = set_names - sets_already_resolved return members, leaf_members - def database_status(self, sources: Optional[List[str]]=None) -> 'OrderedDict[str, OrderedDict[str, Any]]': + def database_status( + self, sources: Optional[List[str]] = None + ) -> "OrderedDict[str, OrderedDict[str, Any]]": """Database status. 
If sources is None, return all valid sources.""" if sources is None: sources = self.sources_default if self.sources_default else self.all_valid_sources @@ -338,24 +365,30 @@ def database_status(self, sources: Optional[List[str]]=None) -> 'OrderedDict[str results: OrderedDict[str, OrderedDict[str, Any]] = OrderedDict() for query_result in query_results: - source = query_result['source'].upper() + source = query_result["source"].upper() results[source] = OrderedDict() - results[source]['authoritative'] = get_setting(f'sources.{source}.authoritative', False) - object_class_filter = get_setting(f'sources.{source}.object_class_filter') - results[source]['object_class_filter'] = list(object_class_filter) if object_class_filter else None - results[source]['rpki_rov_filter'] = bool(get_setting('rpki.roa_source') and not get_setting(f'sources.{source}.rpki_excluded')) - results[source]['scopefilter_enabled'] = bool(get_setting('scopefilter')) and not get_setting(f'sources.{source}.scopefilter_excluded') - results[source]['route_preference'] = get_setting(f'sources.{source}.route_object_preference') - results[source]['local_journal_kept'] = get_setting(f'sources.{source}.keep_journal', False) - results[source]['serial_oldest_journal'] = query_result['serial_oldest_journal'] - results[source]['serial_newest_journal'] = query_result['serial_newest_journal'] - results[source]['serial_last_export'] = query_result['serial_last_export'] - results[source]['serial_newest_mirror'] = query_result['serial_newest_mirror'] - results[source]['last_update'] = query_result['updated'].astimezone(timezone('UTC')).isoformat() - results[source]['synchronised_serials'] = is_serial_synchronised(self.database_handler, source) + results[source]["authoritative"] = get_setting(f"sources.{source}.authoritative", False) + object_class_filter = get_setting(f"sources.{source}.object_class_filter") + results[source]["object_class_filter"] = ( + list(object_class_filter) if object_class_filter else None + ) + results[source]["rpki_rov_filter"] = bool( + get_setting("rpki.roa_source") and not get_setting(f"sources.{source}.rpki_excluded") + ) + results[source]["scopefilter_enabled"] = bool(get_setting("scopefilter")) and not get_setting( + f"sources.{source}.scopefilter_excluded" + ) + results[source]["route_preference"] = get_setting(f"sources.{source}.route_object_preference") + results[source]["local_journal_kept"] = get_setting(f"sources.{source}.keep_journal", False) + results[source]["serial_oldest_journal"] = query_result["serial_oldest_journal"] + results[source]["serial_newest_journal"] = query_result["serial_newest_journal"] + results[source]["serial_last_export"] = query_result["serial_last_export"] + results[source]["serial_newest_mirror"] = query_result["serial_newest_mirror"] + results[source]["last_update"] = query_result["updated"].astimezone(timezone("UTC")).isoformat() + results[source]["synchronised_serials"] = is_serial_synchronised(self.database_handler, source) for invalid_source in invalid_sources: - results[invalid_source.upper()] = OrderedDict({'error': 'Unknown source'}) + results[invalid_source.upper()] = OrderedDict({"error": "Unknown source"}) return results def rpsl_object_template(self, object_class) -> str: @@ -363,7 +396,7 @@ def rpsl_object_template(self, object_class) -> str: try: return OBJECT_CLASS_MAPPING[object_class]().generate_template() except KeyError: - raise InvalidQueryException(f'Unknown object class: {object_class}') + raise InvalidQueryException(f"Unknown object class: 
{object_class}") def enable_sql_trace(self): self.sql_trace = True diff --git a/irrd/server/test_access_check.py b/irrd/server/test_access_check.py index f1609a4ba..0978d549b 100644 --- a/irrd/server/test_access_check.py +++ b/irrd/server/test_access_check.py @@ -2,45 +2,51 @@ class TestIsClientPermitted: - client_ip = '192.0.2.1' + client_ip = "192.0.2.1" def test_no_access_list(self): - assert is_client_permitted(self.client_ip, 'server.whois.access_list', default_deny=False) - assert not is_client_permitted(self.client_ip, 'server.whois.access_list', default_deny=True) + assert is_client_permitted(self.client_ip, "server.whois.access_list", default_deny=False) + assert not is_client_permitted(self.client_ip, "server.whois.access_list", default_deny=True) def test_access_list_permitted(self, config_override): - config_override({ - 'server': { - 'whois': { - 'access_list': 'test-access-list', + config_override( + { + "server": { + "whois": { + "access_list": "test-access-list", + }, }, - }, - 'access_lists': { - 'test-access-list': ['192.0.2.0/25', '2001:db8::/32'], - }, - }) + "access_lists": { + "test-access-list": ["192.0.2.0/25", "2001:db8::/32"], + }, + } + ) - assert is_client_permitted(self.client_ip, 'server.whois.access_list', default_deny=False) - assert is_client_permitted(self.client_ip, 'server.whois.access_list', default_deny=True) - assert is_client_permitted(f'::ffff:{self.client_ip}', 'server.whois.access_list', default_deny=True) - assert is_client_permitted('2001:db8::1', 'server.whois.access_list', default_deny=True) + assert is_client_permitted(self.client_ip, "server.whois.access_list", default_deny=False) + assert is_client_permitted(self.client_ip, "server.whois.access_list", default_deny=True) + assert is_client_permitted(f"::ffff:{self.client_ip}", "server.whois.access_list", default_deny=True) + assert is_client_permitted("2001:db8::1", "server.whois.access_list", default_deny=True) def test_access_list_denied(self, config_override): - config_override({ - 'server': { - 'whois': { - 'access_list': 'test-access-list', + config_override( + { + "server": { + "whois": { + "access_list": "test-access-list", + }, + }, + "access_lists": { + "test-access-list": ["192.0.2.128/25", "2001:db8::/32"], }, - }, - 'access_lists': { - 'test-access-list': ['192.0.2.128/25', '2001:db8::/32'], - }, - }) + } + ) - assert not is_client_permitted(self.client_ip, 'server.whois.access_list', default_deny=False) - assert not is_client_permitted(f'::ffff:{self.client_ip}', 'server.whois.access_list', default_deny=False) - assert not is_client_permitted(self.client_ip, 'server.whois.access_list', default_deny=True) + assert not is_client_permitted(self.client_ip, "server.whois.access_list", default_deny=False) + assert not is_client_permitted( + f"::ffff:{self.client_ip}", "server.whois.access_list", default_deny=False + ) + assert not is_client_permitted(self.client_ip, "server.whois.access_list", default_deny=True) def test_access_list_denied_invalid_ip(self): - assert not is_client_permitted('invalid', 'server.whois.access_list', default_deny=False) - assert not is_client_permitted('invalid', 'server.whois.access_list', default_deny=True) + assert not is_client_permitted("invalid", "server.whois.access_list", default_deny=False) + assert not is_client_permitted("invalid", "server.whois.access_list", default_deny=True) diff --git a/irrd/server/tests/test_query_resolver.py b/irrd/server/tests/test_query_resolver.py index 6763a5f29..2c7d1baa7 100644 --- 
a/irrd/server/tests/test_query_resolver.py +++ b/irrd/server/tests/test_query_resolver.py @@ -7,12 +7,13 @@ from IPy import IP from pytz import timezone -from irrd.rpki.status import RPKIStatus from irrd.routepref.status import RoutePreferenceStatus +from irrd.rpki.status import RPKIStatus from irrd.scopefilter.status import ScopeFilterStatus from irrd.storage.preload import Preloader from irrd.utils.test_utils import flatten_mock_calls -from ..query_resolver import QueryResolver, RouteLookupType, InvalidQueryException + +from ..query_resolver import InvalidQueryException, QueryResolver, RouteLookupType # Note that these mock objects are not entirely valid RPSL objects, # as they are meant to test all the scenarios in the query resolver. @@ -38,21 +39,25 @@ source: TEST2 """ -MOCK_ROUTE_COMBINED = MOCK_ROUTE1 + '\n' + MOCK_ROUTE2 + '\n' + MOCK_ROUTE3.strip() +MOCK_ROUTE_COMBINED = MOCK_ROUTE1 + "\n" + MOCK_ROUTE2 + "\n" + MOCK_ROUTE3.strip() @pytest.fixture() def prepare_resolver(monkeypatch, config_override): - config_override({ - 'rpki': {'roa_source': None}, - 'sources': {'TEST1': {}, 'TEST2': {}}, - 'sources_default': [], - }) + config_override( + { + "rpki": {"roa_source": None}, + "sources": {"TEST1": {}, "TEST2": {}}, + "sources_default": [], + } + ) mock_database_handler = Mock() mock_database_query = Mock() - monkeypatch.setattr('irrd.server.query_resolver.RPSLDatabaseQuery', - lambda columns=None, ordered_by_sources=True: mock_database_query) + monkeypatch.setattr( + "irrd.server.query_resolver.RPSLDatabaseQuery", + lambda columns=None, ordered_by_sources=True: mock_database_query, + ) mock_preloader = Mock(spec=Preloader) resolver = QueryResolver(mock_preloader, mock_database_handler) @@ -61,38 +66,47 @@ def prepare_resolver(monkeypatch, config_override): mock_query_result = [ { - 'pk': uuid.uuid4(), - 'rpsl_pk': '192.0.2.0/25,AS65547', - 'object_class': 'route', - 'parsed_data': { - 'route': '192.0.2.0/25', 'origin': 'AS65547', 'mnt-by': 'MNT-TEST', - 'source': 'TEST1', - 'members': ['AS1, AS2'] + "pk": uuid.uuid4(), + "rpsl_pk": "192.0.2.0/25,AS65547", + "object_class": "route", + "parsed_data": { + "route": "192.0.2.0/25", + "origin": "AS65547", + "mnt-by": "MNT-TEST", + "source": "TEST1", + "members": ["AS1, AS2"], }, - 'object_text': MOCK_ROUTE1, - 'rpki_status': RPKIStatus.not_found, - 'source': 'TEST1', + "object_text": MOCK_ROUTE1, + "rpki_status": RPKIStatus.not_found, + "source": "TEST1", }, { - 'pk': uuid.uuid4(), - - 'rpsl_pk': '192.0.2.0/25,AS65544', - 'object_class': 'route', - 'parsed_data': {'route': '192.0.2.0/25', 'origin': 'AS65544', 'mnt-by': 'MNT-TEST', - 'source': 'TEST2'}, - 'object_text': MOCK_ROUTE2, - 'rpki_status': RPKIStatus.valid, - 'source': 'TEST2', + "pk": uuid.uuid4(), + "rpsl_pk": "192.0.2.0/25,AS65544", + "object_class": "route", + "parsed_data": { + "route": "192.0.2.0/25", + "origin": "AS65544", + "mnt-by": "MNT-TEST", + "source": "TEST2", + }, + "object_text": MOCK_ROUTE2, + "rpki_status": RPKIStatus.valid, + "source": "TEST2", }, { - 'pk': uuid.uuid4(), - 'rpsl_pk': '192.0.2.128/25,AS65545', - 'object_class': 'route', - 'parsed_data': {'route': '192.0.2.128/25', 'origin': 'AS65545', 'mnt-by': 'MNT-TEST', - 'source': 'TEST2'}, - 'object_text': MOCK_ROUTE3, - 'rpki_status': RPKIStatus.valid, - 'source': 'TEST2', + "pk": uuid.uuid4(), + "rpsl_pk": "192.0.2.128/25,AS65545", + "object_class": "route", + "parsed_data": { + "route": "192.0.2.128/25", + "origin": "AS65545", + "mnt-by": "MNT-TEST", + "source": "TEST2", + }, + "object_text": 
MOCK_ROUTE3, + "rpki_status": RPKIStatus.valid, + "source": "TEST2", }, ] mock_database_handler.execute_query = lambda query, refresh_on_error=False: mock_query_result @@ -107,62 +121,64 @@ def test_set_sources(self, prepare_resolver): resolver.set_query_sources(None) assert resolver.sources == resolver.all_valid_sources - resolver.set_query_sources(['TEST1']) - assert resolver.sources == ['TEST1'] + resolver.set_query_sources(["TEST1"]) + assert resolver.sources == ["TEST1"] # With RPKI-aware mode disabled, RPKI is not a valid source with pytest.raises(InvalidQueryException): - resolver.set_query_sources(['RPKI']) + resolver.set_query_sources(["RPKI"]) def test_default_sources(self, prepare_resolver, config_override): mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver mock_dh.reset_mock() - config_override({ - 'sources': {'TEST1': {}, 'TEST2': {}}, - 'sources_default': ['TEST2', 'TEST1'], - }) + config_override( + { + "sources": {"TEST1": {}, "TEST2": {}}, + "sources_default": ["TEST2", "TEST1"], + } + ) resolver = QueryResolver(mock_preloader, mock_dh) - assert list(resolver.sources_default) == ['TEST2', 'TEST1'] + assert list(resolver.sources_default) == ["TEST2", "TEST1"] def test_restrict_object_class(self, prepare_resolver): mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver mock_dh.reset_mock() - resolver.set_object_class_filter_next_query(['route']) - result = resolver.rpsl_attribute_search('mnt-by', 'MNT-TEST') + resolver.set_object_class_filter_next_query(["route"]) + result = resolver.rpsl_attribute_search("mnt-by", "MNT-TEST") assert list(result) == mock_query_result assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['route'],), {}], - ['lookup_attr', ('mnt-by', 'MNT-TEST'), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["route"],), {}], + ["lookup_attr", ("mnt-by", "MNT-TEST"), {}], ] mock_dq.reset_mock() # filter should not persist - result = resolver.rpsl_attribute_search('mnt-by', 'MNT-TEST') + result = resolver.rpsl_attribute_search("mnt-by", "MNT-TEST") assert list(result) == mock_query_result assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['lookup_attr', ('mnt-by', 'MNT-TEST'), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["lookup_attr", ("mnt-by", "MNT-TEST"), {}], ] def test_key_lookup(self, prepare_resolver): mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver - result = resolver.key_lookup('route', '192.0.2.0/25') + result = resolver.key_lookup("route", "192.0.2.0/25") assert list(result) == mock_query_result assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['route'],), {}], - ['rpsl_pk', ('192.0.2.0/25',), {}], - ['first_only', (), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["route"],), {}], + ["rpsl_pk", ("192.0.2.0/25",), {}], + ["first_only", (), {}], ] def test_key_lookup_with_sql_trace(self, prepare_resolver): mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver resolver.enable_sql_trace() - result = resolver.key_lookup('route', '192.0.2.0/25') + result = resolver.key_lookup("route", "192.0.2.0/25") assert list(result) == mock_query_result assert len(resolver.retrieve_sql_trace()) == 1 assert len(resolver.retrieve_sql_trace()) == 0 @@ -170,127 +186,129 @@ def test_key_lookup_with_sql_trace(self, prepare_resolver): def 
test_limit_sources_key_lookup(self, prepare_resolver): mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver - resolver.set_query_sources(['TEST1']) - result = resolver.key_lookup('route', '192.0.2.0/25') + resolver.set_query_sources(["TEST1"]) + result = resolver.key_lookup("route", "192.0.2.0/25") assert list(result) == mock_query_result assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1'],), {}], - ['object_classes', (['route'],), {}], - ['rpsl_pk', ('192.0.2.0/25',), {}], - ['first_only', (), {}], + ["sources", (["TEST1"],), {}], + ["object_classes", (["route"],), {}], + ["rpsl_pk", ("192.0.2.0/25",), {}], + ["first_only", (), {}], ] def test_text_search(self, prepare_resolver): mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver mock_dh.reset_mock() - result = resolver.rpsl_text_search('query') + result = resolver.rpsl_text_search("query") assert list(result) == mock_query_result assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['text_search', ('query',), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["text_search", ("query",), {}], ] def test_route_search_exact(self, prepare_resolver): mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver - result = resolver.route_search(IP('192.0.2.0/25'), RouteLookupType.EXACT) + result = resolver.route_search(IP("192.0.2.0/25"), RouteLookupType.EXACT) assert list(result) == mock_query_result assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['route', 'route6'],), {}], - ['ip_exact', (IP('192.0.2.0/25'),), {}] + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["route", "route6"],), {}], + ["ip_exact", (IP("192.0.2.0/25"),), {}], ] mock_dq.reset_mock() def test_route_search_less_specific_one_level(self, prepare_resolver): mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver - result = resolver.route_search(IP('192.0.2.0/25'), RouteLookupType.LESS_SPECIFIC_ONE_LEVEL) + result = resolver.route_search(IP("192.0.2.0/25"), RouteLookupType.LESS_SPECIFIC_ONE_LEVEL) assert list(result) == mock_query_result assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['route', 'route6'],), {}], - ['ip_less_specific_one_level', (IP('192.0.2.0/25'),), {}] + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["route", "route6"],), {}], + ["ip_less_specific_one_level", (IP("192.0.2.0/25"),), {}], ] def test_route_search_less_specific(self, prepare_resolver): mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver - result = resolver.route_search(IP('192.0.2.0/25'), RouteLookupType.LESS_SPECIFIC_WITH_EXACT) + result = resolver.route_search(IP("192.0.2.0/25"), RouteLookupType.LESS_SPECIFIC_WITH_EXACT) assert list(result) == mock_query_result assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['route', 'route6'],), {}], - ['ip_less_specific', (IP('192.0.2.0/25'),), {}] + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["route", "route6"],), {}], + ["ip_less_specific", (IP("192.0.2.0/25"),), {}], ] def test_route_search_more_specific(self, prepare_resolver): mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver - result = resolver.route_search(IP('192.0.2.0/25'), RouteLookupType.MORE_SPECIFIC_WITHOUT_EXACT) + result = resolver.route_search(IP("192.0.2.0/25"), 
RouteLookupType.MORE_SPECIFIC_WITHOUT_EXACT) assert list(result) == mock_query_result assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['route', 'route6'],), {}], - ['ip_more_specific', (IP('192.0.2.0/25'),), {}] + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["route", "route6"],), {}], + ["ip_more_specific", (IP("192.0.2.0/25"),), {}], ] def test_route_search_exact_rpki_aware(self, prepare_resolver, config_override): mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver - config_override({ - 'sources': {'TEST1': {}, 'TEST2': {}}, - 'sources_default': [], - 'rpki': {'roa_source': 'https://example.com/roa.json'}, - }) + config_override( + { + "sources": {"TEST1": {}, "TEST2": {}}, + "sources_default": [], + "rpki": {"roa_source": "https://example.com/roa.json"}, + } + ) resolver = QueryResolver(mock_preloader, mock_dh) resolver.out_scope_filter_enabled = False resolver.route_preference_filter_enabled = False - result = resolver.route_search(IP('192.0.2.0/25'), RouteLookupType.EXACT) + result = resolver.route_search(IP("192.0.2.0/25"), RouteLookupType.EXACT) assert list(result) == mock_query_result assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2', 'RPKI'],), {}], - ['rpki_status', ([RPKIStatus.not_found, RPKIStatus.valid],), {}], - ['object_classes', (['route', 'route6'],), {}], - ['ip_exact', (IP('192.0.2.0/25'),), {}] + ["sources", (["TEST1", "TEST2", "RPKI"],), {}], + ["rpki_status", ([RPKIStatus.not_found, RPKIStatus.valid],), {}], + ["object_classes", (["route", "route6"],), {}], + ["ip_exact", (IP("192.0.2.0/25"),), {}], ] mock_dq.reset_mock() resolver.disable_rpki_filter() - result = resolver.route_search(IP('192.0.2.0/25'), RouteLookupType.EXACT) + result = resolver.route_search(IP("192.0.2.0/25"), RouteLookupType.EXACT) assert list(result) == mock_query_result assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2', 'RPKI'],), {}], - ['object_classes', (['route', 'route6'],), {}], - ['ip_exact', (IP('192.0.2.0/25'),), {}] + ["sources", (["TEST1", "TEST2", "RPKI"],), {}], + ["object_classes", (["route", "route6"],), {}], + ["ip_exact", (IP("192.0.2.0/25"),), {}], ] mock_dq.reset_mock() - resolver.set_query_sources(['RPKI']) - assert resolver.sources == ['RPKI'] + resolver.set_query_sources(["RPKI"]) + assert resolver.sources == ["RPKI"] def test_route_search_exact_with_scopefilter(self, prepare_resolver, config_override): mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver resolver.out_scope_filter_enabled = True - result = resolver.route_search(IP('192.0.2.0/25'), RouteLookupType.EXACT) + result = resolver.route_search(IP("192.0.2.0/25"), RouteLookupType.EXACT) assert list(result) == mock_query_result assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['scopefilter_status', ([ScopeFilterStatus.in_scope],), {}], - ['object_classes', (['route', 'route6'],), {}], - ['ip_exact', (IP('192.0.2.0/25'),), {}] + ["sources", (["TEST1", "TEST2"],), {}], + ["scopefilter_status", ([ScopeFilterStatus.in_scope],), {}], + ["object_classes", (["route", "route6"],), {}], + ["ip_exact", (IP("192.0.2.0/25"),), {}], ] mock_dq.reset_mock() resolver.disable_out_of_scope_filter() - result = resolver.route_search(IP('192.0.2.0/25'), RouteLookupType.EXACT) + result = resolver.route_search(IP("192.0.2.0/25"), RouteLookupType.EXACT) assert list(result) == mock_query_result assert flatten_mock_calls(mock_dq) == [ 
- ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['route', 'route6'],), {}], - ['ip_exact', (IP('192.0.2.0/25'),), {}] + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["route", "route6"],), {}], + ["ip_exact", (IP("192.0.2.0/25"),), {}], ] mock_dq.reset_mock() @@ -298,53 +316,53 @@ def test_route_search_exact_with_route_preference_filter(self, prepare_resolver, mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver resolver.route_preference_filter_enabled = True - result = resolver.route_search(IP('192.0.2.0/25'), RouteLookupType.EXACT) + result = resolver.route_search(IP("192.0.2.0/25"), RouteLookupType.EXACT) assert list(result) == mock_query_result assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['route_preference_status', ([RoutePreferenceStatus.visible],), {}], - ['object_classes', (['route', 'route6'],), {}], - ['ip_exact', (IP('192.0.2.0/25'),), {}] + ["sources", (["TEST1", "TEST2"],), {}], + ["route_preference_status", ([RoutePreferenceStatus.visible],), {}], + ["object_classes", (["route", "route6"],), {}], + ["ip_exact", (IP("192.0.2.0/25"),), {}], ] mock_dq.reset_mock() resolver.disable_route_preference_filter() - result = resolver.route_search(IP('192.0.2.0/25'), RouteLookupType.EXACT) + result = resolver.route_search(IP("192.0.2.0/25"), RouteLookupType.EXACT) assert list(result) == mock_query_result assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['route', 'route6'],), {}], - ['ip_exact', (IP('192.0.2.0/25'),), {}] + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["route", "route6"],), {}], + ["ip_exact", (IP("192.0.2.0/25"),), {}], ] mock_dq.reset_mock() def test_rpsl_attribute_search(self, prepare_resolver): mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver - result = resolver.rpsl_attribute_search('mnt-by', 'MNT-TEST') + result = resolver.rpsl_attribute_search("mnt-by", "MNT-TEST") assert list(result) == mock_query_result assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['lookup_attr', ('mnt-by', 'MNT-TEST'), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["lookup_attr", ("mnt-by", "MNT-TEST"), {}], ] mock_dh.execute_query = lambda query, refresh_on_error=False: [] with pytest.raises(InvalidQueryException): - resolver.rpsl_attribute_search('invalid-attr', 'MNT-TEST') + resolver.rpsl_attribute_search("invalid-attr", "MNT-TEST") def test_routes_for_origin(self, prepare_resolver): mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver - mock_preloader.routes_for_origins = Mock(return_value={'192.0.2.0/25', '192.0.2.128/25'}) + mock_preloader.routes_for_origins = Mock(return_value={"192.0.2.0/25", "192.0.2.128/25"}) - result = resolver.routes_for_origin('AS65547', 4) - assert result == {'192.0.2.0/25', '192.0.2.128/25'} + result = resolver.routes_for_origin("AS65547", 4) + assert result == {"192.0.2.0/25", "192.0.2.128/25"} assert flatten_mock_calls(mock_preloader.routes_for_origins) == [ - ['', (['AS65547'], ['TEST1', 'TEST2']), {'ip_version': 4}], + ["", (["AS65547"], ["TEST1", "TEST2"]), {"ip_version": 4}], ] mock_preloader.routes_for_origins = Mock(return_value={}) - result = resolver.routes_for_origin('AS65547', 4) + result = resolver.routes_for_origin("AS65547", 4) assert result == {} assert not mock_dq.mock_calls @@ -352,24 +370,24 @@ def test_routes_for_as_set(self, prepare_resolver, monkeypatch): mock_dq, mock_dh, 
mock_preloader, mock_query_result, resolver = prepare_resolver monkeypatch.setattr( - 'irrd.server.query_resolver.QueryResolver._recursive_set_resolve', - lambda self, set_name: {'AS65547', 'AS65548'} + "irrd.server.query_resolver.QueryResolver._recursive_set_resolve", + lambda self, set_name: {"AS65547", "AS65548"}, ) mock_preloader.routes_for_origins = Mock(return_value=[]) - result = resolver.routes_for_as_set('AS65547', 4) + result = resolver.routes_for_as_set("AS65547", 4) assert flatten_mock_calls(mock_preloader.routes_for_origins) == [ - ['', ({'AS65547', 'AS65548'}, resolver.all_valid_sources), {'ip_version': 4}], + ["", ({"AS65547", "AS65548"}, resolver.all_valid_sources), {"ip_version": 4}], ] assert not result - mock_preloader.routes_for_origins = Mock(return_value={'192.0.2.0/25', '192.0.2.128/25'}) + mock_preloader.routes_for_origins = Mock(return_value={"192.0.2.0/25", "192.0.2.128/25"}) - result = resolver.routes_for_as_set('AS65547') - assert resolver._current_set_root_object_class == 'as-set' - assert result == {'192.0.2.0/25', '192.0.2.128/25'} + result = resolver.routes_for_as_set("AS65547") + assert resolver._current_set_root_object_class == "as-set" + assert result == {"192.0.2.0/25", "192.0.2.128/25"} assert flatten_mock_calls(mock_preloader.routes_for_origins) == [ - ['', ({'AS65547', 'AS65548'}, resolver.all_valid_sources), {'ip_version': None}], + ["", ({"AS65547", "AS65548"}, resolver.all_valid_sources), {"ip_version": None}], ] assert not mock_dq.mock_calls @@ -379,107 +397,105 @@ def test_as_set_members(self, prepare_resolver): mock_query_result1 = [ { - 'pk': uuid.uuid4(), - 'rpsl_pk': 'AS-FIRSTLEVEL', - 'parsed_data': {'as-set': 'AS-FIRSTLEVEL', - 'members': [ - 'AS65547', 'AS-FIRSTLEVEL', 'AS-SECONDLEVEL', 'AS-2nd-UNKNOWN' - ]}, - 'object_text': 'text', - 'object_class': 'as-set', - 'source': 'TEST1', + "pk": uuid.uuid4(), + "rpsl_pk": "AS-FIRSTLEVEL", + "parsed_data": { + "as-set": "AS-FIRSTLEVEL", + "members": ["AS65547", "AS-FIRSTLEVEL", "AS-SECONDLEVEL", "AS-2nd-UNKNOWN"], + }, + "object_text": "text", + "object_class": "as-set", + "source": "TEST1", }, ] mock_query_result2 = [ { - 'pk': uuid.uuid4(), - 'rpsl_pk': 'AS-SECONDLEVEL', - 'parsed_data': {'as-set': 'AS-SECONDLEVEL', - 'members': ['AS-THIRDLEVEL', 'AS65544']}, - 'object_text': 'text', - 'object_class': 'as-set', - 'source': 'TEST1', + "pk": uuid.uuid4(), + "rpsl_pk": "AS-SECONDLEVEL", + "parsed_data": {"as-set": "AS-SECONDLEVEL", "members": ["AS-THIRDLEVEL", "AS65544"]}, + "object_text": "text", + "object_class": "as-set", + "source": "TEST1", }, { # Should be ignored - only the first result per PK is accepted. 
- 'pk': uuid.uuid4(), - 'rpsl_pk': 'AS-SECONDLEVEL', - 'parsed_data': {'as-set': 'AS-SECONDLEVEL', 'members': ['AS-IGNOREME']}, - 'object_text': 'text', - 'object_class': 'as-set', - 'source': 'TEST2', + "pk": uuid.uuid4(), + "rpsl_pk": "AS-SECONDLEVEL", + "parsed_data": {"as-set": "AS-SECONDLEVEL", "members": ["AS-IGNOREME"]}, + "object_text": "text", + "object_class": "as-set", + "source": "TEST2", }, ] mock_query_result3 = [ { - 'pk': uuid.uuid4(), - 'rpsl_pk': 'AS-THIRDLEVEL', + "pk": uuid.uuid4(), + "rpsl_pk": "AS-THIRDLEVEL", # Refers back to the first as-set to test infinite recursion issues - 'parsed_data': {'as-set': 'AS-THIRDLEVEL', - 'members': ['AS65545', 'AS-FIRSTLEVEL', 'AS-4th-UNKNOWN']}, - 'object_text': 'text', - 'object_class': 'as-set', - 'source': 'TEST2', + "parsed_data": { + "as-set": "AS-THIRDLEVEL", + "members": ["AS65545", "AS-FIRSTLEVEL", "AS-4th-UNKNOWN"], + }, + "object_text": "text", + "object_class": "as-set", + "source": "TEST2", }, ] mock_dh.execute_query = lambda query, refresh_on_error=False: iter(mock_query_result1) - result = resolver.members_for_set('AS-FIRSTLEVEL', recursive=False) - assert result == ['AS-2nd-UNKNOWN', 'AS-SECONDLEVEL', 'AS65547'] + result = resolver.members_for_set("AS-FIRSTLEVEL", recursive=False) + assert result == ["AS-2nd-UNKNOWN", "AS-SECONDLEVEL", "AS65547"] assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['as-set', 'route-set'],), {}], - ['rpsl_pks', ({'AS-FIRSTLEVEL'},), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["as-set", "route-set"],), {}], + ["rpsl_pks", ({"AS-FIRSTLEVEL"},), {}], ] mock_dq.reset_mock() mock_query_iterator = iter( - [mock_query_result1, mock_query_result2, mock_query_result3, [], mock_query_result1, - []]) + [mock_query_result1, mock_query_result2, mock_query_result3, [], mock_query_result1, []] + ) mock_dh.execute_query = lambda query, refresh_on_error=False: iter(next(mock_query_iterator)) - result = resolver.members_for_set('AS-FIRSTLEVEL', recursive=True) - assert result == ['AS65544', 'AS65545', 'AS65547'] + result = resolver.members_for_set("AS-FIRSTLEVEL", recursive=True) + assert result == ["AS65544", "AS65545", "AS65547"] assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['as-set', 'route-set'],), {}], - ['rpsl_pks', ({'AS-FIRSTLEVEL'},), {}], - - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['as-set'],), {}], - ['rpsl_pks', ({'AS-2nd-UNKNOWN', 'AS-SECONDLEVEL'},), {}], - - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['as-set'],), {}], - ['rpsl_pks', ({'AS-THIRDLEVEL'},), {}], - - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['as-set'],), {}], - ['rpsl_pks', ({'AS-4th-UNKNOWN'},), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["as-set", "route-set"],), {}], + ["rpsl_pks", ({"AS-FIRSTLEVEL"},), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["as-set"],), {}], + ["rpsl_pks", ({"AS-2nd-UNKNOWN", "AS-SECONDLEVEL"},), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["as-set"],), {}], + ["rpsl_pks", ({"AS-THIRDLEVEL"},), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["as-set"],), {}], + ["rpsl_pks", ({"AS-4th-UNKNOWN"},), {}], ] mock_dq.reset_mock() - result = resolver.members_for_set('AS-FIRSTLEVEL', depth=1, recursive=True) - assert result == ['AS-2nd-UNKNOWN', 'AS-SECONDLEVEL', 'AS65547'] + result = 
resolver.members_for_set("AS-FIRSTLEVEL", depth=1, recursive=True) + assert result == ["AS-2nd-UNKNOWN", "AS-SECONDLEVEL", "AS65547"] mock_dq.reset_mock() mock_dh.execute_query = lambda query, refresh_on_error=False: iter([]) - result = resolver.members_for_set('AS-NOTEXIST', recursive=True) + result = resolver.members_for_set("AS-NOTEXIST", recursive=True) assert not result assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['as-set', 'route-set'],), {}], - ['rpsl_pks', ({'AS-NOTEXIST'},), {}] + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["as-set", "route-set"],), {}], + ["rpsl_pks", ({"AS-NOTEXIST"},), {}], ] mock_dq.reset_mock() mock_dh.execute_query = lambda query, refresh_on_error=False: iter([]) - result = resolver.members_for_set('AS-NOTEXIST', recursive=True, root_source='ROOT') + result = resolver.members_for_set("AS-NOTEXIST", recursive=True, root_source="ROOT") assert not result assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['as-set', 'route-set'],), {}], - ['rpsl_pks', ({'AS-NOTEXIST'},), {}], - ['sources', (['ROOT'],), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["as-set", "route-set"],), {}], + ["rpsl_pks", ({"AS-NOTEXIST"},), {}], + ["sources", (["ROOT"],), {}], ] def test_route_set_members(self, prepare_resolver): @@ -487,57 +503,53 @@ def test_route_set_members(self, prepare_resolver): mock_query_result1 = [ { - 'pk': uuid.uuid4(), - 'rpsl_pk': 'RS-FIRSTLEVEL', - 'parsed_data': {'as-set': 'RS-FIRSTLEVEL', - 'members': ['RS-SECONDLEVEL', 'RS-2nd-UNKNOWN']}, - 'object_text': 'text', - 'object_class': 'route-set', - 'source': 'TEST1', + "pk": uuid.uuid4(), + "rpsl_pk": "RS-FIRSTLEVEL", + "parsed_data": {"as-set": "RS-FIRSTLEVEL", "members": ["RS-SECONDLEVEL", "RS-2nd-UNKNOWN"]}, + "object_text": "text", + "object_class": "route-set", + "source": "TEST1", }, ] mock_query_result2 = [ { - 'pk': uuid.uuid4(), - 'rpsl_pk': 'RS-SECONDLEVEL', - 'parsed_data': {'as-set': 'RS-SECONDLEVEL', - 'members': [ - 'AS-REFERRED', '192.0.2.0/25', '192.0.2.0/26^32' - ]}, - 'object_text': 'text', - 'object_class': 'route-set', - 'source': 'TEST1', + "pk": uuid.uuid4(), + "rpsl_pk": "RS-SECONDLEVEL", + "parsed_data": { + "as-set": "RS-SECONDLEVEL", + "members": ["AS-REFERRED", "192.0.2.0/25", "192.0.2.0/26^32"], + }, + "object_text": "text", + "object_class": "route-set", + "source": "TEST1", }, ] mock_query_result3 = [ { - 'pk': uuid.uuid4(), - 'rpsl_pk': 'AS-REFERRED', - 'parsed_data': {'as-set': 'AS-REFERRED', - 'members': ['AS65545']}, - 'object_text': 'text', - 'object_class': 'as-set', - 'source': 'TEST2', + "pk": uuid.uuid4(), + "rpsl_pk": "AS-REFERRED", + "parsed_data": {"as-set": "AS-REFERRED", "members": ["AS65545"]}, + "object_text": "text", + "object_class": "as-set", + "source": "TEST2", }, ] mock_query_iterator = iter([mock_query_result1, mock_query_result2, mock_query_result3, []]) mock_dh.execute_query = lambda query, refresh_on_error=False: iter(next(mock_query_iterator)) - mock_preloader.routes_for_origins = Mock(return_value=['192.0.2.128/25']) + mock_preloader.routes_for_origins = Mock(return_value=["192.0.2.128/25"]) - result = resolver.members_for_set('RS-FIRSTLEVEL', recursive=True) - assert set(result) == {'192.0.2.0/26^32', '192.0.2.0/25', '192.0.2.128/25'} + result = resolver.members_for_set("RS-FIRSTLEVEL", recursive=True) + assert set(result) == {"192.0.2.0/26^32", "192.0.2.0/25", "192.0.2.128/25"} assert 
flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['as-set', 'route-set'],), {}], - ['rpsl_pks', ({'RS-FIRSTLEVEL'},), {}], - - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['as-set', 'route-set'],), {}], - ['rpsl_pks', ({'RS-SECONDLEVEL', 'RS-2nd-UNKNOWN'},), {}], - - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['as-set', 'route-set'],), {}], - ['rpsl_pks', ({'AS-REFERRED'},), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["as-set", "route-set"],), {}], + ["rpsl_pks", ({"RS-FIRSTLEVEL"},), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["as-set", "route-set"],), {}], + ["rpsl_pks", ({"RS-SECONDLEVEL", "RS-2nd-UNKNOWN"},), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["as-set", "route-set"],), {}], + ["rpsl_pks", ({"AS-REFERRED"},), {}], ] def test_as_route_set_mbrs_by_ref(self, prepare_resolver): @@ -546,197 +558,203 @@ def test_as_route_set_mbrs_by_ref(self, prepare_resolver): mock_query_result1 = [ { # This route-set is intentionally misnamed RRS, as invalid names occur in real life. - 'pk': uuid.uuid4(), - 'rpsl_pk': 'RRS-TEST', - 'parsed_data': {'route-set': 'RRS-TEST', 'members': ['192.0.2.0/32'], - 'mp-members': ['2001:db8::/32'], 'mbrs-by-ref': ['MNT-TEST']}, - 'object_text': 'text', - 'object_class': 'route-set', - 'source': 'TEST1', + "pk": uuid.uuid4(), + "rpsl_pk": "RRS-TEST", + "parsed_data": { + "route-set": "RRS-TEST", + "members": ["192.0.2.0/32"], + "mp-members": ["2001:db8::/32"], + "mbrs-by-ref": ["MNT-TEST"], + }, + "object_text": "text", + "object_class": "route-set", + "source": "TEST1", }, ] mock_query_result2 = [ { - 'pk': uuid.uuid4(), - 'rpsl_pk': '192.0.2.0/24,AS65544', - 'parsed_data': {'route': '192.0.2.0/24', 'member-of': 'rrs-test', - 'mnt-by': ['FOO', 'MNT-TEST']}, - 'object_text': 'text', - 'object_class': 'route', - 'source': 'TEST1', + "pk": uuid.uuid4(), + "rpsl_pk": "192.0.2.0/24,AS65544", + "parsed_data": { + "route": "192.0.2.0/24", + "member-of": "rrs-test", + "mnt-by": ["FOO", "MNT-TEST"], + }, + "object_text": "text", + "object_class": "route", + "source": "TEST1", }, ] mock_query_iterator = iter([mock_query_result1, mock_query_result2, [], [], []]) mock_dh.execute_query = lambda query, refresh_on_error=False: iter(next(mock_query_iterator)) - result = resolver.members_for_set('RRS-TEST', recursive=True) - assert result == ['192.0.2.0/24', '192.0.2.0/32', '2001:db8::/32'] + result = resolver.members_for_set("RRS-TEST", recursive=True) + assert result == ["192.0.2.0/24", "192.0.2.0/32", "2001:db8::/32"] assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['as-set', 'route-set'],), {}], - ['rpsl_pks', ({'RRS-TEST'},), {}], - - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['route', 'route6'],), {}], - ['lookup_attrs_in', (['member-of'], ['RRS-TEST']), {}], - ['lookup_attrs_in', (['mnt-by'], ['MNT-TEST']), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["as-set", "route-set"],), {}], + ["rpsl_pks", ({"RRS-TEST"},), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["route", "route6"],), {}], + ["lookup_attrs_in", (["member-of"], ["RRS-TEST"]), {}], + ["lookup_attrs_in", (["mnt-by"], ["MNT-TEST"]), {}], ] mock_dq.reset_mock() # Disable maintainer check - mock_query_result1[0]['parsed_data']['mbrs-by-ref'] = ['ANY'] + mock_query_result1[0]["parsed_data"]["mbrs-by-ref"] = ["ANY"] mock_query_iterator 
= iter([mock_query_result1, mock_query_result2, [], [], []]) - result = resolver.members_for_set('RRS-TEST', recursive=True) - assert result == ['192.0.2.0/24', '192.0.2.0/32', '2001:db8::/32'] + result = resolver.members_for_set("RRS-TEST", recursive=True) + assert result == ["192.0.2.0/24", "192.0.2.0/32", "2001:db8::/32"] assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['as-set', 'route-set'],), {}], - ['rpsl_pks', ({'RRS-TEST'},), {}], - - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['route', 'route6'],), {}], - ['lookup_attrs_in', (['member-of'], ['RRS-TEST']), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["as-set", "route-set"],), {}], + ["rpsl_pks", ({"RRS-TEST"},), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["route", "route6"],), {}], + ["lookup_attrs_in", (["member-of"], ["RRS-TEST"]), {}], ] - def test_route_set_compatibility_ipv4_only_route_set_members(self, prepare_resolver, - config_override): + def test_route_set_compatibility_ipv4_only_route_set_members(self, prepare_resolver, config_override): mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver mock_query_result = [ { - 'pk': uuid.uuid4(), - 'rpsl_pk': 'RS-TEST', - 'parsed_data': { - 'route-set': 'RS-TEST', - 'members': ['192.0.2.0/32'], - 'mp-members': ['192.0.2.1/32', '2001:db8::/32', 'RS-OTHER'] + "pk": uuid.uuid4(), + "rpsl_pk": "RS-TEST", + "parsed_data": { + "route-set": "RS-TEST", + "members": ["192.0.2.0/32"], + "mp-members": ["192.0.2.1/32", "2001:db8::/32", "RS-OTHER"], }, - 'object_text': 'text', - 'object_class': 'route-set', - 'source': 'TEST1', + "object_text": "text", + "object_class": "route-set", + "source": "TEST1", }, ] mock_dh.execute_query = lambda query, refresh_on_error=False: mock_query_result - result = resolver.members_for_set('RS-TEST', recursive=False) - assert result == ['192.0.2.0/32', '192.0.2.1/32', '2001:db8::/32', 'RS-OTHER'] + result = resolver.members_for_set("RS-TEST", recursive=False) + assert result == ["192.0.2.0/32", "192.0.2.1/32", "2001:db8::/32", "RS-OTHER"] - config_override({ - 'compatibility': {'ipv4_only_route_set_members': True}, - }) + config_override( + { + "compatibility": {"ipv4_only_route_set_members": True}, + } + ) - result = resolver.members_for_set('RS-TEST', recursive=False) - assert result == ['192.0.2.0/32', '192.0.2.1/32', 'RS-OTHER'] + result = resolver.members_for_set("RS-TEST", recursive=False) + assert result == ["192.0.2.0/32", "192.0.2.1/32", "RS-OTHER"] def test_members_for_set_per_source(self, prepare_resolver): mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver - mock_query_result = iter([ + mock_query_result = iter( [ - { - 'rpsl_pk': 'AS-TEST', - 'source': 'TEST1', - }, - { - 'rpsl_pk': 'AS-TEST', - 'source': 'TEST2', - }, - ], [ - { - 'pk': uuid.uuid4(), - 'rpsl_pk': 'AS-TEST', - 'parsed_data': {'as-set': 'AS-TEST', - 'members': ['AS65547', 'AS-SECONDLEVEL']}, - 'object_text': 'text', - 'object_class': 'as-set', - 'source': 'TEST1', - }, - ], [ - { - 'pk': uuid.uuid4(), - 'rpsl_pk': 'AS-SECONDLEVEL', - 'parsed_data': {'as-set': 'AS-SECONDLEVEL', - 'members': ['AS65548']}, - 'object_text': 'text', - 'object_class': 'as-set', - 'source': 'TEST1', - }, - ], [ - { - 'pk': uuid.uuid4(), - 'rpsl_pk': 'AS-TEST', - 'parsed_data': {'as-set': 'AS-TEST', - 'members': ['AS65549']}, - 'object_text': 'text', - 'object_class': 'as-set', - 'source': 'TEST2', - }, - ], - [], - ]) + [ + { 
+ "rpsl_pk": "AS-TEST", + "source": "TEST1", + }, + { + "rpsl_pk": "AS-TEST", + "source": "TEST2", + }, + ], + [ + { + "pk": uuid.uuid4(), + "rpsl_pk": "AS-TEST", + "parsed_data": {"as-set": "AS-TEST", "members": ["AS65547", "AS-SECONDLEVEL"]}, + "object_text": "text", + "object_class": "as-set", + "source": "TEST1", + }, + ], + [ + { + "pk": uuid.uuid4(), + "rpsl_pk": "AS-SECONDLEVEL", + "parsed_data": {"as-set": "AS-SECONDLEVEL", "members": ["AS65548"]}, + "object_text": "text", + "object_class": "as-set", + "source": "TEST1", + }, + ], + [ + { + "pk": uuid.uuid4(), + "rpsl_pk": "AS-TEST", + "parsed_data": {"as-set": "AS-TEST", "members": ["AS65549"]}, + "object_text": "text", + "object_class": "as-set", + "source": "TEST2", + }, + ], + [], + ] + ) mock_dh.execute_query = lambda query, refresh_on_error=False: next(mock_query_result) - result = resolver.members_for_set_per_source('AS-TEST', recursive=True) - assert result == {'TEST1': ['AS65547', 'AS65548'], 'TEST2': ['AS65549']} + result = resolver.members_for_set_per_source("AS-TEST", recursive=True) + assert result == {"TEST1": ["AS65547", "AS65548"], "TEST2": ["AS65549"]} assert flatten_mock_calls(mock_dq) == [ - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['as-set', 'route-set'],), {}], - ['rpsl_pk', ('AS-TEST',), {}], - - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['as-set', 'route-set'],), {}], - ['rpsl_pks', ({'AS-TEST'},), {}], - ['sources', (['TEST1'],), {}], - - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['as-set'],), {}], - ['rpsl_pks', ({'AS-SECONDLEVEL'},), {}], - - ['sources', (['TEST1', 'TEST2'],), {}], - ['object_classes', (['as-set', 'route-set'],), {}], - ['rpsl_pks', ({'AS-TEST'},), {}], - ['sources', (['TEST2'],), {}] + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["as-set", "route-set"],), {}], + ["rpsl_pk", ("AS-TEST",), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["as-set", "route-set"],), {}], + ["rpsl_pks", ({"AS-TEST"},), {}], + ["sources", (["TEST1"],), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["as-set"],), {}], + ["rpsl_pks", ({"AS-SECONDLEVEL"},), {}], + ["sources", (["TEST1", "TEST2"],), {}], + ["object_classes", (["as-set", "route-set"],), {}], + ["rpsl_pks", ({"AS-TEST"},), {}], + ["sources", (["TEST2"],), {}], ] mock_dq.reset_mock() def test_database_status(self, monkeypatch, prepare_resolver, config_override): - config_override({ - 'rpki': {'roa_source': 'http://example.com/'}, - 'scopefilter': {'prefixes': ['192.0.2.0/24']}, - 'sources': { - 'TEST1': { - 'authoritative': True, - 'object_class_filter': ['route'], - 'scopefilter_excluded': True, - 'route_preference': 200, + config_override( + { + "rpki": {"roa_source": "http://example.com/"}, + "scopefilter": {"prefixes": ["192.0.2.0/24"]}, + "sources": { + "TEST1": { + "authoritative": True, + "object_class_filter": ["route"], + "scopefilter_excluded": True, + "route_preference": 200, + }, + "TEST2": {"rpki_excluded": True, "keep_journal": True}, }, - 'TEST2': {'rpki_excluded': True, 'keep_journal': True} - }, - 'sources_default': [], - }) + "sources_default": [], + } + ) mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver mock_dsq = Mock() - monkeypatch.setattr('irrd.server.query_resolver.DatabaseStatusQuery', lambda: mock_dsq) - monkeypatch.setattr('irrd.server.query_resolver.is_serial_synchronised', - lambda dh, s: False) + monkeypatch.setattr("irrd.server.query_resolver.DatabaseStatusQuery", 
lambda: mock_dsq) + monkeypatch.setattr("irrd.server.query_resolver.is_serial_synchronised", lambda dh, s: False) mock_query_result = [ { - 'source': 'TEST1', - 'serial_oldest_journal': 10, - 'serial_newest_journal': 10, - 'serial_last_export': 10, - 'serial_newest_mirror': 500, - 'updated': datetime.datetime(2020, 1, 1, tzinfo=timezone('UTC')), + "source": "TEST1", + "serial_oldest_journal": 10, + "serial_newest_journal": 10, + "serial_last_export": 10, + "serial_newest_mirror": 500, + "updated": datetime.datetime(2020, 1, 1, tzinfo=timezone("UTC")), }, { - 'source': 'TEST2', - 'serial_oldest_journal': None, - 'serial_newest_journal': None, - 'serial_last_export': None, - 'serial_newest_mirror': 20, - 'updated': datetime.datetime(2020, 1, 1, tzinfo=timezone('UTC')), + "source": "TEST2", + "serial_oldest_journal": None, + "serial_newest_journal": None, + "serial_last_export": None, + "serial_newest_mirror": 20, + "updated": datetime.datetime(2020, 1, 1, tzinfo=timezone("UTC")), }, ] mock_dh.execute_query = lambda query, refresh_on_error=False: mock_query_result @@ -785,30 +803,26 @@ def test_database_status(self, monkeypatch, prepare_resolver, config_override): ), ] ) - assert flatten_mock_calls(mock_dsq) == [ - ['sources', (['TEST1', 'TEST2'],), {}] - ] + assert flatten_mock_calls(mock_dsq) == [["sources", (["TEST1", "TEST2"],), {}]] mock_dsq.reset_mock() mock_query_result = mock_query_result[:1] - result = resolver.database_status(['TEST1', 'TEST-INVALID']) + result = resolver.database_status(["TEST1", "TEST-INVALID"]) assert result == OrderedDict( [ expected_test1_result, ("TEST-INVALID", OrderedDict([("error", "Unknown source")])), ] ) - assert flatten_mock_calls(mock_dsq) == [ - ['sources', (['TEST1', 'TEST-INVALID'],), {}] - ] + assert flatten_mock_calls(mock_dsq) == [["sources", (["TEST1", "TEST-INVALID"],), {}]] def test_object_template(self, prepare_resolver): mock_dq, mock_dh, mock_preloader, mock_query_result, resolver = prepare_resolver mock_dh.reset_mock() - result = resolver.rpsl_object_template('aut-num') - assert 'aut-num:[mandatory][single][primary/look-upkey]' in result.replace(' ', '') + result = resolver.rpsl_object_template("aut-num") + assert "aut-num:[mandatory][single][primary/look-upkey]" in result.replace(" ", "") mock_dh.reset_mock() with pytest.raises(InvalidQueryException): - resolver.rpsl_object_template('does-not-exist') + resolver.rpsl_object_template("does-not-exist") diff --git a/irrd/server/whois/query_parser.py b/irrd/server/whois/query_parser.py index d478c7935..fe467a185 100644 --- a/irrd/server/whois/query_parser.py +++ b/irrd/server/whois/query_parser.py @@ -7,17 +7,26 @@ from ordered_set import OrderedSet from irrd import __version__ -from irrd.conf import get_setting, RPKI_IRR_PSEUDO_SOURCE, SOCKET_DEFAULT_TIMEOUT +from irrd.conf import RPKI_IRR_PSEUDO_SOURCE, SOCKET_DEFAULT_TIMEOUT, get_setting from irrd.mirroring.nrtm_generator import NRTMGenerator, NRTMGeneratorException from irrd.rpki.status import RPKIStatus -from irrd.rpsl.rpsl_objects import (OBJECT_CLASS_MAPPING, RPKI_RELEVANT_OBJECT_CLASSES) -from irrd.server.query_resolver import QueryResolver, RouteLookupType, InvalidQueryException +from irrd.rpsl.rpsl_objects import OBJECT_CLASS_MAPPING, RPKI_RELEVANT_OBJECT_CLASSES +from irrd.server.query_resolver import ( + InvalidQueryException, + QueryResolver, + RouteLookupType, +) from irrd.storage.database_handler import DatabaseHandler, RPSLDatabaseResponse from irrd.storage.preload import Preloader from irrd.storage.queries import 
DatabaseStatusQuery -from irrd.utils.validators import parse_as_number, ValidationError -from .query_response import WhoisQueryResponseType, WhoisQueryResponseMode, WhoisQueryResponse +from irrd.utils.validators import ValidationError, parse_as_number + from ..access_check import is_client_permitted +from .query_response import ( + WhoisQueryResponse, + WhoisQueryResponseMode, + WhoisQueryResponseType, +) logger = logging.getLogger(__name__) @@ -35,8 +44,9 @@ class WhoisQueryParser: handle_query() being called for each individual query. """ - def __init__(self, client_ip: str, client_str: str, preloader: Preloader, - database_handler: DatabaseHandler) -> None: + def __init__( + self, client_ip: str, client_str: str, preloader: Preloader, database_handler: DatabaseHandler + ) -> None: self.multiple_command_mode = False self.timeout = SOCKET_DEFAULT_TIMEOUT self.key_fields_only = False @@ -59,25 +69,29 @@ def handle_query(self, query: str) -> WhoisQueryResponse: return WhoisQueryResponse( response_type=WhoisQueryResponseType.ERROR_USER, mode=WhoisQueryResponseMode.IRRD, - result='Queries may not contain null bytes', + result="Queries may not contain null bytes", ) - if query.startswith('!'): + if query.startswith("!"): try: return self.handle_irrd_command(query[1:]) except InvalidQueryException as exc: - logger.info(f'{self.client_str}: encountered parsing error while parsing query "{query}": {exc}') + logger.info( + f'{self.client_str}: encountered parsing error while parsing query "{query}": {exc}' + ) return WhoisQueryResponse( response_type=WhoisQueryResponseType.ERROR_USER, mode=WhoisQueryResponseMode.IRRD, - result=str(exc) + result=str(exc), ) except Exception as exc: - logger.error(f'An exception occurred while processing whois query "{query}": {exc}', exc_info=exc) + logger.error( + f'An exception occurred while processing whois query "{query}": {exc}', exc_info=exc + ) return WhoisQueryResponse( response_type=WhoisQueryResponseType.ERROR_INTERNAL, mode=WhoisQueryResponseMode.IRRD, - result='An internal error occurred while processing this query.' + result="An internal error occurred while processing this query.", ) try: @@ -87,88 +101,94 @@ def handle_query(self, query: str) -> WhoisQueryResponse: return WhoisQueryResponse( response_type=WhoisQueryResponseType.ERROR_USER, mode=WhoisQueryResponseMode.RIPE, - result=str(exc) + result=str(exc), ) except Exception as exc: logger.error(f'An exception occurred while processing whois query "{query}": {exc}', exc_info=exc) return WhoisQueryResponse( response_type=WhoisQueryResponseType.ERROR_INTERNAL, mode=WhoisQueryResponseMode.RIPE, - result='An internal error occurred while processing this query.' + result="An internal error occurred while processing this query.", ) def handle_irrd_command(self, full_command: str) -> WhoisQueryResponse: - """Handle an IRRD-style query. full_command should not include the first exclamation mark. """ + """Handle an IRRD-style query. 
full_command should not include the first exclamation mark.""" if not full_command: - raise InvalidQueryException('Missing IRRD command') + raise InvalidQueryException("Missing IRRD command") command = full_command[0] parameter = full_command[1:] response_type = WhoisQueryResponseType.SUCCESS result = None # A is not tested here because it is already handled in handle_irrd_routes_for_as_set - queries_with_parameter = list('tg6ijmnors') + queries_with_parameter = list("tg6ijmnors") if command in queries_with_parameter and not parameter: - raise InvalidQueryException(f'Missing parameter for {command} query') + raise InvalidQueryException(f"Missing parameter for {command} query") - if command == '!': + if command == "!": self.multiple_command_mode = True result = None response_type = WhoisQueryResponseType.NO_RESPONSE - elif full_command.upper() == 'FNO-RPKI-FILTER': + elif full_command.upper() == "FNO-RPKI-FILTER": self.query_resolver.disable_rpki_filter() - result = 'Filtering out RPKI invalids is disabled for !r and RIPE style ' \ - 'queries for the rest of this connection.' - elif full_command.upper() == 'FNO-SCOPE-FILTER': + result = ( + "Filtering out RPKI invalids is disabled for !r and RIPE style " + "queries for the rest of this connection." + ) + elif full_command.upper() == "FNO-SCOPE-FILTER": self.query_resolver.disable_out_of_scope_filter() - result = 'Filtering out out-of-scope objects is disabled for !r and RIPE style ' \ - 'queries for the rest of this connection.' - elif full_command.upper() == 'FNO-ROUTE-PREFERENCE-FILTER': + result = ( + "Filtering out out-of-scope objects is disabled for !r and RIPE style " + "queries for the rest of this connection." + ) + elif full_command.upper() == "FNO-ROUTE-PREFERENCE-FILTER": self.query_resolver.disable_route_preference_filter() - result = 'Filtering out objects suppressed due to route preference is disabled for ' \ - '!r and RIPE style queries for the rest of this connection.' - elif command == 'v': + result = ( + "Filtering out objects suppressed due to route preference is disabled for " + "!r and RIPE style queries for the rest of this connection." 
+ ) + elif command == "v": result = self.handle_irrd_version() - elif command == 't': + elif command == "t": self.handle_irrd_timeout_update(parameter) - elif command == 'g': + elif command == "g": result = self.handle_irrd_routes_for_origin_v4(parameter) if not result: response_type = WhoisQueryResponseType.KEY_NOT_FOUND - elif command == '6': + elif command == "6": result = self.handle_irrd_routes_for_origin_v6(parameter) if not result: response_type = WhoisQueryResponseType.KEY_NOT_FOUND - elif command == 'a': + elif command == "a": result = self.handle_irrd_routes_for_as_set(parameter) if not result: response_type = WhoisQueryResponseType.KEY_NOT_FOUND - elif command == 'i': + elif command == "i": result = self.handle_irrd_set_members(parameter) if not result: response_type = WhoisQueryResponseType.KEY_NOT_FOUND - elif command == 'j': + elif command == "j": result = self.handle_irrd_database_serial_range(parameter) - elif command == 'J': + elif command == "J": result = self.handle_irrd_database_status(parameter) - elif command == 'm': + elif command == "m": result = self.handle_irrd_exact_key(parameter) if not result: response_type = WhoisQueryResponseType.KEY_NOT_FOUND - elif command == 'n': + elif command == "n": self.handle_user_agent(parameter) - elif command == 'o': - result = self.handle_inverse_attr_search('mnt-by', parameter) + elif command == "o": + result = self.handle_inverse_attr_search("mnt-by", parameter) if not result: response_type = WhoisQueryResponseType.KEY_NOT_FOUND - elif command == 'r': + elif command == "r": result = self.handle_irrd_route_search(parameter) if not result: response_type = WhoisQueryResponseType.KEY_NOT_FOUND - elif command == 's': + elif command == "s": result = self.handle_irrd_sources_list(parameter) else: - raise InvalidQueryException(f'Unrecognised command: {command}') + raise InvalidQueryException(f"Unrecognised command: {command}") return WhoisQueryResponse( response_type=response_type, @@ -181,12 +201,12 @@ def handle_irrd_timeout_update(self, timeout: str) -> None: try: timeout_value = int(timeout) except ValueError: - raise InvalidQueryException(f'Invalid value for timeout: {timeout}') + raise InvalidQueryException(f"Invalid value for timeout: {timeout}") if timeout_value > 0 and timeout_value <= 1000: self.timeout = timeout_value else: - raise InvalidQueryException(f'Invalid value for timeout: {timeout}') + raise InvalidQueryException(f"Invalid value for timeout: {timeout}") def handle_irrd_routes_for_origin_v4(self, origin: str) -> str: """!g query - find all originating IPv4 prefixes from an origin, e.g. !gAS65537""" @@ -196,7 +216,7 @@ def handle_irrd_routes_for_origin_v6(self, origin: str) -> str: """!6 query - find all originating IPv6 prefixes from an origin, e.g. !6as65537""" return self._routes_for_origin(origin, 6) - def _routes_for_origin(self, origin: str, ip_version: Optional[int]=None) -> str: + def _routes_for_origin(self, origin: str, ip_version: Optional[int] = None) -> str: """ Resolve all route(6)s prefixes for an origin, returning a space-separated list of all originating prefixes, not including duplicates. @@ -207,25 +227,25 @@ def _routes_for_origin(self, origin: str, ip_version: Optional[int]=None) -> str raise InvalidQueryException(str(ve)) prefixes = self.query_resolver.routes_for_origin(origin_formatted, ip_version) - return ' '.join(prefixes) + return " ".join(prefixes) def handle_irrd_routes_for_as_set(self, set_name: str) -> str: """ !a query - find all originating prefixes for all members of an AS-set, e.g. 
!a4AS-FOO or !a6AS-FOO """ ip_version: Optional[int] = None - if set_name.startswith('4'): + if set_name.startswith("4"): set_name = set_name[1:] ip_version = 4 - elif set_name.startswith('6'): + elif set_name.startswith("6"): set_name = set_name[1:] ip_version = 6 if not set_name: - raise InvalidQueryException('Missing required set name for A query') + raise InvalidQueryException("Missing required set name for A query") prefixes = self.query_resolver.routes_for_as_set(set_name, ip_version) - return ' '.join(prefixes) + return " ".join(prefixes) def handle_irrd_set_members(self, parameter: str) -> str: """ @@ -233,12 +253,12 @@ def handle_irrd_set_members(self, parameter: str) -> str: e.g. !iAS-FOO for non-recursive, !iAS-FOO,1 for recursive """ recursive = False - if parameter.endswith(',1'): + if parameter.endswith(",1"): recursive = True parameter = parameter[:-2] members = self.query_resolver.members_for_set(parameter, recursive=recursive) - return ' '.join(members) + return " ".join(members) def handle_irrd_database_serial_range(self, parameter: str) -> str: """ @@ -246,50 +266,54 @@ def handle_irrd_database_serial_range(self, parameter: str) -> str: This query is legacy and only available in whois, so resolved directly here instead of in the query resolver. """ - if parameter == '-*': - sources = self.query_resolver.sources_default if self.query_resolver.sources_default else self.query_resolver.all_valid_sources + if parameter == "-*": + sources = ( + self.query_resolver.sources_default + if self.query_resolver.sources_default + else self.query_resolver.all_valid_sources + ) else: - sources = [s.upper() for s in parameter.split(',')] + sources = [s.upper() for s in parameter.split(",")] invalid_sources = [s for s in sources if s not in self.query_resolver.all_valid_sources] query = DatabaseStatusQuery().sources(sources) query_results = self.database_handler.execute_query(query, refresh_on_error=True) - result_txt = '' + result_txt = "" for query_result in query_results: - source = query_result['source'].upper() - keep_journal = 'Y' if get_setting(f'sources.{source}.keep_journal') else 'N' - serial_newest = query_result['serial_newest_mirror'] + source = query_result["source"].upper() + keep_journal = "Y" if get_setting(f"sources.{source}.keep_journal") else "N" + serial_newest = query_result["serial_newest_mirror"] fields = [ source, keep_journal, - f'0-{serial_newest}' if serial_newest else '-', + f"0-{serial_newest}" if serial_newest else "-", ] - if query_result['serial_last_export']: - fields.append(str(query_result['serial_last_export'])) - result_txt += ':'.join(fields) + '\n' + if query_result["serial_last_export"]: + fields.append(str(query_result["serial_last_export"])) + result_txt += ":".join(fields) + "\n" for invalid_source in invalid_sources: - result_txt += f'{invalid_source.upper()}:X:Database unknown\n' + result_txt += f"{invalid_source.upper()}:X:Database unknown\n" return result_txt.strip() def handle_irrd_database_status(self, parameter: str) -> str: """!J query - database status""" - if parameter == '-*': + if parameter == "-*": sources = None else: - sources = [s.upper() for s in parameter.split(',')] + sources = [s.upper() for s in parameter.split(",")] results = self.query_resolver.database_status(sources) return ujson.dumps(results, indent=4) def handle_irrd_exact_key(self, parameter: str): """!m query - exact object key lookup, e.g. 
!maut-num,AS65537""" try: - object_class, rpsl_pk = parameter.split(',', maxsplit=1) + object_class, rpsl_pk = parameter.split(",", maxsplit=1) except ValueError: - raise InvalidQueryException(f'Invalid argument for object lookup: {parameter}') + raise InvalidQueryException(f"Invalid argument for object lookup: {parameter}") - if object_class in ['route', 'route6']: - rpsl_pk = rpsl_pk.upper().replace(' ', '').replace('-', '') + if object_class in ["route", "route6"]: + rpsl_pk = rpsl_pk.upper().replace(" ", "").replace("-", "") query = self.query_resolver.key_lookup(object_class, rpsl_pk) return self._flatten_query_output(query) @@ -303,31 +327,31 @@ def handle_irrd_route_search(self, parameter: str): !r192.0.2.0/24,M returns all more specific objects, not including exact """ option: Optional[str] = None - if ',' in parameter: - address, option = parameter.split(',') + if "," in parameter: + address, option = parameter.split(",") else: address = parameter try: address = IP(address) except ValueError: - raise InvalidQueryException(f'Invalid input for route search: {parameter}') + raise InvalidQueryException(f"Invalid input for route search: {parameter}") lookup_types = { None: RouteLookupType.EXACT, - 'o': RouteLookupType.EXACT, - 'l': RouteLookupType.LESS_SPECIFIC_ONE_LEVEL, - 'L': RouteLookupType.LESS_SPECIFIC_WITH_EXACT, - 'M': RouteLookupType.MORE_SPECIFIC_WITHOUT_EXACT, + "o": RouteLookupType.EXACT, + "l": RouteLookupType.LESS_SPECIFIC_ONE_LEVEL, + "L": RouteLookupType.LESS_SPECIFIC_WITH_EXACT, + "M": RouteLookupType.MORE_SPECIFIC_WITHOUT_EXACT, } try: lookup_type = lookup_types[option] except KeyError: - raise InvalidQueryException(f'Invalid route search option: {option}') + raise InvalidQueryException(f"Invalid route search option: {option}") result = self.query_resolver.route_search(address, lookup_type) - if option == 'o': - prefixes = [r['parsed_data']['origin'] for r in result] - return ' '.join(prefixes) + if option == "o": + prefixes = [r["parsed_data"]["origin"] for r in result] + return " ".join(prefixes) return self._flatten_query_output(result) def handle_irrd_sources_list(self, parameter: str) -> Optional[str]: @@ -336,67 +360,67 @@ def handle_irrd_sources_list(self, parameter: str) -> Optional[str]: !s-lc returns all enabled sources, space separated !sripe,nttcom limits sources to ripe and nttcom """ - if parameter == '-lc': - return ','.join(self.query_resolver.sources) + if parameter == "-lc": + return ",".join(self.query_resolver.sources) - sources = parameter.upper().split(',') + sources = parameter.upper().split(",") self.query_resolver.set_query_sources(sources) return None def handle_irrd_version(self): """!v query - return version""" - return f'IRRd -- version {__version__}' + return f"IRRd -- version {__version__}" def handle_ripe_command(self, full_query: str) -> WhoisQueryResponse: """ Process RIPE-style queries. Any query that is not explicitly an IRRD-style query (i.e. starts with exclamation mark) is presumed to be a RIPE query. 
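        Illustrative queries in this style (assumed syntax, derived from the
        flag handling below; names and values are examples only, not taken
        from this changeset):

            -x 192.0.2.0/24                    exact route search
            -i mnt-by MNT-EXAMPLE              inverse attribute search
            -s RIPE,NTTCOM -T route AS65537    restrict sources and classes, then free text search
            -g RIPE:3:1-LAST                   NRTM mirror request (source:version:serial range)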
""" - full_query = re.sub(' +', ' ', full_query) - components = full_query.strip().split(' ') + full_query = re.sub(" +", " ", full_query) + components = full_query.strip().split(" ") result = None response_type = WhoisQueryResponseType.SUCCESS remove_auth_hashes = True while len(components): component = components.pop(0) - if component.startswith('-'): + if component.startswith("-"): command = component[1:] try: - if command == 'k': + if command == "k": self.multiple_command_mode = True - elif command in ['l', 'L', 'M', 'x']: + elif command in ["l", "L", "M", "x"]: result = self.handle_ripe_route_search(command, components.pop(0)) if not result: response_type = WhoisQueryResponseType.KEY_NOT_FOUND break - elif command == 'i': + elif command == "i": result = self.handle_inverse_attr_search(components.pop(0), components.pop(0)) if not result: response_type = WhoisQueryResponseType.KEY_NOT_FOUND break - elif command == 's': + elif command == "s": self.handle_ripe_sources_list(components.pop(0)) - elif command == 'a': + elif command == "a": self.handle_ripe_sources_list(None) - elif command == 'T': + elif command == "T": self.handle_ripe_restrict_object_class(components.pop(0)) - elif command == 't': + elif command == "t": result = self.handle_ripe_request_object_template(components.pop(0)) break - elif command == 'K': + elif command == "K": self.handle_ripe_key_fields_only() - elif command == 'V': + elif command == "V": self.handle_user_agent(components.pop(0)) - elif command == 'g': + elif command == "g": result = self.handle_nrtm_request(components.pop(0)) remove_auth_hashes = False - elif command in ['F', 'r']: + elif command in ["F", "r"]: continue # These flags disable recursion, but IRRd never performs recursion anyways else: - raise InvalidQueryException(f'Unrecognised flag/search: {command}') + raise InvalidQueryException(f"Unrecognised flag/search: {command}") except IndexError: - raise InvalidQueryException(f'Missing argument for flag/search: {command}') + raise InvalidQueryException(f"Missing argument for flag/search: {command}") else: # assume query to be a free text search result = self.handle_ripe_text_search(component) @@ -418,29 +442,29 @@ def handle_ripe_route_search(self, command: str, parameter: str) -> str: try: address = IP(parameter) except ValueError: - raise InvalidQueryException(f'Invalid input for route search: {parameter}') + raise InvalidQueryException(f"Invalid input for route search: {parameter}") lookup_types = { - 'x': RouteLookupType.EXACT, - 'l': RouteLookupType.LESS_SPECIFIC_ONE_LEVEL, - 'L': RouteLookupType.LESS_SPECIFIC_WITH_EXACT, - 'M': RouteLookupType.MORE_SPECIFIC_WITHOUT_EXACT, + "x": RouteLookupType.EXACT, + "l": RouteLookupType.LESS_SPECIFIC_ONE_LEVEL, + "L": RouteLookupType.LESS_SPECIFIC_WITH_EXACT, + "M": RouteLookupType.MORE_SPECIFIC_WITHOUT_EXACT, } lookup_type = lookup_types[command] result = self.query_resolver.route_search(address, lookup_type) return self._flatten_query_output(result) def handle_ripe_sources_list(self, sources_list: Optional[str]) -> None: - """-s/-a parameter - set sources list. Empty list enables all sources. """ + """-s/-a parameter - set sources list. 
Empty list enables all sources."""
         if sources_list:
-            sources = sources_list.upper().split(',')
+            sources = sources_list.upper().split(",")
             self.query_resolver.set_query_sources(sources)
         else:
             self.query_resolver.set_query_sources(None)

     def handle_ripe_restrict_object_class(self, object_classes) -> None:
         """-T parameter - restrict object classes for this query, comma-separated"""
-        self.query_resolver.set_object_class_filter_next_query(object_classes.split(','))
+        self.query_resolver.set_object_class_filter_next_query(object_classes.split(","))

     def handle_ripe_request_object_template(self, object_class) -> str:
         """-t query - return the RPSL template for an object class"""
@@ -457,40 +481,47 @@ def handle_ripe_text_search(self, value: str) -> str:

     def handle_user_agent(self, user_agent: str):
         """-V/!n parameter/query - set a user agent for the client"""
         self.query_resolver.user_agent = user_agent
-        logger.info(f'{self.client_str}: user agent set to: {user_agent}')
+        logger.info(f"{self.client_str}: user agent set to: {user_agent}")

     def handle_nrtm_request(self, param):
         try:
-            source, version, serial_range = param.split(':')
+            source, version, serial_range = param.split(":")
         except ValueError:
-            raise InvalidQueryException('Invalid parameter: must contain three elements')
+            raise InvalidQueryException("Invalid parameter: must contain three elements")

         try:
-            serial_start, serial_end = serial_range.split('-')
+            serial_start, serial_end = serial_range.split("-")
             serial_start = int(serial_start)
-            if serial_end == 'LAST':
+            if serial_end == "LAST":
                 serial_end = None
             else:
                 serial_end = int(serial_end)
         except ValueError:
-            raise InvalidQueryException(f'Invalid serial range: {serial_range}')
+            raise InvalidQueryException(f"Invalid serial range: {serial_range}")

-        if version not in ['1', '3']:
-            raise InvalidQueryException(f'Invalid NRTM version: {version}')
+        if version not in ["1", "3"]:
+            raise InvalidQueryException(f"Invalid NRTM version: {version}")

         source = source.upper()
         if source not in self.query_resolver.all_valid_sources:
-            raise InvalidQueryException(f'Unknown source: {source}')
+            raise InvalidQueryException(f"Unknown source: {source}")

-        in_access_list = is_client_permitted(self.client_ip, f'sources.{source}.nrtm_access_list', log=False)
-        in_unfiltered_access_list = is_client_permitted(self.client_ip, f'sources.{source}.nrtm_access_list_unfiltered', log=False)
+        in_access_list = is_client_permitted(self.client_ip, f"sources.{source}.nrtm_access_list", log=False)
+        in_unfiltered_access_list = is_client_permitted(
+            self.client_ip, f"sources.{source}.nrtm_access_list_unfiltered", log=False
+        )
         if not in_access_list and not in_unfiltered_access_list:
-            raise InvalidQueryException('Access denied')
+            raise InvalidQueryException("Access denied")

         try:
             return NRTMGenerator().generate(
-                source, version, serial_start, serial_end, self.database_handler,
-                remove_auth_hashes=not in_unfiltered_access_list)
+                source,
+                version,
+                serial_start,
+                serial_end,
+                self.database_handler,
+                remove_auth_hashes=not in_unfiltered_access_list,
+            )
         except NRTMGeneratorException as nge:
             raise InvalidQueryException(str(nge))
@@ -511,35 +542,35 @@ def _flatten_query_output(self, query_response: RPSLDatabaseResponse) -> str:
         if self.key_fields_only:
             result = self._filter_key_fields(query_response)
         else:
-            result = ''
+            result = ""
             for obj in query_response:
-                result += obj['object_text']
+                result += obj["object_text"]
                 if (
-                    self.query_resolver.rpki_aware and
-                    obj['source'] != RPKI_IRR_PSEUDO_SOURCE and
-
obj['object_class'] in RPKI_RELEVANT_OBJECT_CLASSES + self.query_resolver.rpki_aware + and obj["source"] != RPKI_IRR_PSEUDO_SOURCE + and obj["object_class"] in RPKI_RELEVANT_OBJECT_CLASSES ): - comment = '' - if obj['rpki_status'] == RPKIStatus.not_found: - comment = ' # No ROAs found, or RPKI validation not enabled for source' + comment = "" + if obj["rpki_status"] == RPKIStatus.not_found: + comment = " # No ROAs found, or RPKI validation not enabled for source" result += f'rpki-ov-state: {obj["rpki_status"].name}{comment}\n' - result += '\n' - return result.strip('\n\r') + result += "\n" + return result.strip("\n\r") def _filter_key_fields(self, query_response) -> str: results: OrderedSet[str] = OrderedSet() for obj in query_response: - result = '' - rpsl_object_class = OBJECT_CLASS_MAPPING[obj['object_class']] - fields_included = rpsl_object_class.pk_fields + ['members', 'mp-members'] + result = "" + rpsl_object_class = OBJECT_CLASS_MAPPING[obj["object_class"]] + fields_included = rpsl_object_class.pk_fields + ["members", "mp-members"] for field_name in fields_included: - field_data = obj['parsed_data'].get(field_name) + field_data = obj["parsed_data"].get(field_name) if field_data: if isinstance(field_data, list): for item in field_data: - result += f'{field_name}: {item}\n' + result += f"{field_name}: {item}\n" else: - result += f'{field_name}: {field_data}\n' + result += f"{field_name}: {field_data}\n" results.add(result) - return '\n'.join(results) + return "\n".join(results) diff --git a/irrd/server/whois/query_response.py b/irrd/server/whois/query_response.py index 9290edf35..d42ccd1a8 100644 --- a/irrd/server/whois/query_response.py +++ b/irrd/server/whois/query_response.py @@ -10,11 +10,12 @@ class WhoisQueryResponseType(Enum): KEY_NOT_FOUND is specific to IRRD-style. NO_RESPONSE means no response should be sent at all. """ - SUCCESS = 'success' - ERROR_INTERNAL = 'error_internal' - ERROR_USER = 'error_user' - KEY_NOT_FOUND = 'key_not_found' - NO_RESPONSE = 'no_response' + + SUCCESS = "success" + ERROR_INTERNAL = "error_internal" + ERROR_USER = "error_user" + KEY_NOT_FOUND = "key_not_found" + NO_RESPONSE = "no_response" ERROR_TYPES = [WhoisQueryResponseType.ERROR_INTERNAL, WhoisQueryResponseType.ERROR_USER] @@ -22,8 +23,9 @@ class WhoisQueryResponseType(Enum): class WhoisQueryResponseMode(Enum): """Response mode for queries - IRRD and RIPE queries have different output.""" - IRRD = 'irrd' - RIPE = 'ripe' + + IRRD = "irrd" + RIPE = "ripe" class WhoisQueryResponse: @@ -33,16 +35,17 @@ class WhoisQueryResponse: Based on the response_type and mode, can render a string of the complete response to send back to the user. 
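    For illustration (derived from the generator methods below, not
    normative): IRRD mode frames a result as "A<length>\n<result>\nC\n",
    returns "D\n" for a missing key and "F <message>\n" for errors; RIPE
    mode ends output with blank lines and prefixes errors with "%% ERROR:".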
""" + response_type: WhoisQueryResponseType = WhoisQueryResponseType.SUCCESS mode: WhoisQueryResponseMode = WhoisQueryResponseMode.RIPE result: Optional[str] = None def __init__( - self, - response_type: WhoisQueryResponseType, - mode: WhoisQueryResponseMode, - result: Optional[str], - remove_auth_hashes=True, + self, + response_type: WhoisQueryResponseType, + mode: WhoisQueryResponseMode, + result: Optional[str], + remove_auth_hashes=True, ) -> None: self.response_type = response_type self.mode = mode @@ -62,7 +65,9 @@ def generate_response(self) -> str: if response is not None: return response - raise RuntimeError(f'Unable to formulate response for {self.response_type} / {self.mode}: {self.result}') + raise RuntimeError( + f"Unable to formulate response for {self.response_type} / {self.mode}: {self.result}" + ) def clean_response(self): if self.remove_auth_hashes: @@ -72,15 +77,15 @@ def _generate_response_irrd(self) -> Optional[str]: if self.response_type == WhoisQueryResponseType.SUCCESS: if self.result: result_len = len(self.result) + 1 - return f'A{result_len}\n{self.result}\nC\n' + return f"A{result_len}\n{self.result}\nC\n" else: - return 'C\n' + return "C\n" elif self.response_type == WhoisQueryResponseType.KEY_NOT_FOUND: - return 'D\n' + return "D\n" elif self.response_type in ERROR_TYPES: - return f'F {self.result}\n' + return f"F {self.result}\n" elif self.response_type == WhoisQueryResponseType.NO_RESPONSE: - return '' + return "" return None def _generate_response_ripe(self) -> Optional[str]: @@ -89,10 +94,10 @@ def _generate_response_ripe(self) -> Optional[str]: # # https://www.ripe.net/manage-ips-and-asns/db/support/documentation/ripe-database-query-reference-manual#2-0-querying-the-ripe-database if self.response_type == WhoisQueryResponseType.SUCCESS: if self.result: - return self.result + '\n\n\n' - return '% No entries found for the selected source(s).\n\n\n' + return self.result + "\n\n\n" + return "% No entries found for the selected source(s).\n\n\n" elif self.response_type == WhoisQueryResponseType.KEY_NOT_FOUND: - return '% No entries found for the selected source(s).\n\n\n' + return "% No entries found for the selected source(s).\n\n\n" elif self.response_type in ERROR_TYPES: - return f'%% ERROR: {self.result}\n\n\n' + return f"%% ERROR: {self.result}\n\n\n" return None diff --git a/irrd/server/whois/server.py b/irrd/server/whois/server.py index c3ee5f5e4..cb65167a7 100644 --- a/irrd/server/whois/server.py +++ b/irrd/server/whois/server.py @@ -7,8 +7,8 @@ import threading import time -from IPy import IP from daemon.daemon import change_process_owner +from IPy import IP from setproctitle import setproctitle from irrd import ENV_MAIN_PROCESS_PID @@ -29,9 +29,9 @@ def start_whois_server(uid, gid): # pragma: no cover Start the whois server, listening forever. This function does not return, except after SIGTERM is received. 
""" - setproctitle('irrd-whois-server-listener') - address = (get_setting('server.whois.interface'), get_setting('server.whois.port')) - logger.info(f'Starting whois server on TCP {address}') + setproctitle("irrd-whois-server-listener") + address = (get_setting("server.whois.interface"), get_setting("server.whois.port")) + logger.info(f"Starting whois server on TCP {address}") server = WhoisTCPServer( server_address=address, uid=uid, @@ -43,11 +43,13 @@ def sigterm_handler(signum, frame): nonlocal server def shutdown(server): - logging.info('Whois server shutting down') + logging.info("Whois server shutting down") server.shutdown() server.server_close() + # Shutdown must be called from a thread to prevent blocking. threading.Thread(target=shutdown, args=(server,)).start() + signal.signal(signal.SIGTERM, sigterm_handler) server.serve_forever() @@ -62,6 +64,7 @@ class WhoisTCPServer(socketserver.TCPServer): # pragma: no cover from which a worker picks it up. The workers are responsible for the connection from then on. """ + allow_reuse_address = True request_queue_size = 50 @@ -73,7 +76,7 @@ def __init__(self, server_address, uid, gid, bind_and_activate=True): # noqa: N self.connection_queue = mp.Queue() self.workers = [] - for i in range(int(get_setting('server.whois.max_connections'))): + for i in range(int(get_setting("server.whois.max_connections"))): worker = WhoisWorker(self.connection_queue) worker.start() self.workers.append(worker) @@ -83,7 +86,7 @@ def process_request(self, request, client_address): self.connection_queue.put((request, client_address)) def handle_error(self, request, client_address): - logger.error(f'Error while handling request from {client_address}', exc_info=True) + logger.error(f"Error while handling request from {client_address}", exc_info=True) def shutdown(self): """ @@ -105,6 +108,7 @@ class WhoisWorker(mp.Process, socketserver.StreamRequestHandler): which are retrieved from a queue. 
After handling a connection, the process waits for the next connection from the queue.
     """
+
     def __init__(self, connection_queue, *args, **kwargs):
         self.connection_queue = connection_queue
         # Note that StreamRequestHandler.__init__ is not called - the
@@ -126,19 +130,23 @@ def run(self, keep_running=True) -> None:
             self.preloader = Preloader()
             self.database_handler = DatabaseHandler(readonly=True)
         except Exception as e:
-            logger.critical(f'Whois worker failed to initialise preloader or database, '
-                            f'unable to start, terminating IRRd, traceback follows: {e}',
-                            exc_info=e)
+            logger.critical(
+                (
+                    "Whois worker failed to initialise preloader or database, "
+                    f"unable to start, terminating IRRd, traceback follows: {e}"
+                ),
+                exc_info=e,
+            )
             main_pid = os.getenv(ENV_MAIN_PROCESS_PID)
             if main_pid:  # pragma: no cover
                 os.kill(int(main_pid), signal.SIGTERM)
             else:
-                logger.error('Failed to terminate IRRd, unable to find main process PID')
+                logger.error("Failed to terminate IRRd, unable to find main process PID")
             return

         while True:
             try:
-                setproctitle('irrd-whois-worker')
+                setproctitle("irrd-whois-worker")
                 self.request, self.client_address = self.connection_queue.get()
                 self.setup()
                 self.handle_connection()
@@ -150,8 +158,7 @@ def run(self, keep_running=True) -> None:
                     self.close_request()
                 except Exception:  # pragma: no cover
                     pass
-                logger.error(f'Failed to handle whois connection, traceback follows: {e}',
-                             exc_info=e)
+                logger.error(f"Failed to handle whois connection, traceback follows: {e}", exc_info=e)
             if not keep_running:
                 break
@@ -176,15 +183,16 @@ def handle_connection(self):
         When this method returns, the connection is closed.
         """
         client_ip = self.client_address[0]
-        self.client_str = client_ip + ':' + str(self.client_address[1])
-        setproctitle(f'irrd-whois-worker-{self.client_str}')
+        self.client_str = client_ip + ":" + str(self.client_address[1])
+        setproctitle(f"irrd-whois-worker-{self.client_str}")

         if not self.is_client_permitted(client_ip):
-            self.wfile.write(b'%% Access denied')
+            self.wfile.write(b"%% Access denied")
             return

-        self.query_parser = WhoisQueryParser(client_ip, self.client_str, self.preloader,
-                                             self.database_handler)
+        self.query_parser = WhoisQueryParser(
+            client_ip, self.client_str, self.preloader, self.database_handler
+        )

         data = True
         while data:
@@ -193,11 +201,11 @@
             data = self.rfile.readline()
             timer.cancel()

-            query = data.decode('utf-8', errors='backslashreplace').strip()
+            query = data.decode("utf-8", errors="backslashreplace").strip()
             if not query:
                 continue

-            logger.debug(f'{self.client_str}: processing query: {query}')
+            logger.debug(f"{self.client_str}: processing query: {query}")

             if not self.handle_query(query):
                 return
@@ -209,23 +217,25 @@ def handle_query(self, query: str) -> bool:
        True when more queries should be read.
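        Illustrative exchange (assumed client input; response framing as
        generated by WhoisQueryResponse):

            client: !!           server sends no response, keeps the connection open
            client: !gAS65537    server answers with A<length>, the prefixes, then C
            client: !q           server closes the connection without a response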
""" start_time = time.perf_counter() - if query.upper() == '!Q': - logger.debug(f'{self.client_str}: closed connection per request') + if query.upper() == "!Q": + logger.debug(f"{self.client_str}: closed connection per request") return False response = self.query_parser.handle_query(query) - response_bytes = response.generate_response().encode('utf-8') + response_bytes = response.generate_response().encode("utf-8") try: self.wfile.write(response_bytes) except OSError: return False elapsed = time.perf_counter() - start_time - logger.info(f'{self.client_str}: sent answer to query, elapsed {elapsed:.9f}s, ' - f'{len(response_bytes)} bytes: {query}') + logger.info( + f"{self.client_str}: sent answer to query, elapsed {elapsed:.9f}s, " + f"{len(response_bytes)} bytes: {query}" + ) if not self.query_parser.multiple_command_mode: - logger.debug(f'{self.client_str}: auto-closed connection') + logger.debug(f"{self.client_str}: auto-closed connection") return False return True @@ -233,4 +243,4 @@ def is_client_permitted(self, ip: str) -> bool: """ Check whether a client is permitted. """ - return is_client_permitted(ip, 'server.whois.access_list', default_deny=False) + return is_client_permitted(ip, "server.whois.access_list", default_deny=False) diff --git a/irrd/server/whois/tests/test_query_parser.py b/irrd/server/whois/tests/test_query_parser.py index b96aba003..08a0c1a4f 100644 --- a/irrd/server/whois/tests/test_query_parser.py +++ b/irrd/server/whois/tests/test_query_parser.py @@ -6,11 +6,16 @@ from irrd.mirroring.nrtm_generator import NRTMGeneratorException from irrd.rpki.status import RPKIStatus -from irrd.server.query_resolver import QueryResolver, RouteLookupType, InvalidQueryException +from irrd.server.query_resolver import ( + InvalidQueryException, + QueryResolver, + RouteLookupType, +) from irrd.storage.database_handler import DatabaseHandler from irrd.utils.test_utils import flatten_mock_calls + from ..query_parser import WhoisQueryParser -from ..query_response import WhoisQueryResponseType, WhoisQueryResponseMode +from ..query_response import WhoisQueryResponseMode, WhoisQueryResponseType # Note that these mock objects are not entirely valid RPSL objects, # as they are meant to test all the scenarios in the query parser. 
@@ -36,8 +41,15 @@ source: TEST2 """ -MOCK_ROUTE_COMBINED = MOCK_ROUTE1 + '\n' + MOCK_ROUTE2 + '\n' + MOCK_ROUTE3.strip() -MOCK_ROUTE_COMBINED_WITH_RPKI = MOCK_ROUTE1 + 'rpki-ov-state: not_found # No ROAs found, or RPKI validation not enabled for source\n\n' + MOCK_ROUTE2 + 'rpki-ov-state: valid\n\n' + MOCK_ROUTE3 + 'rpki-ov-state: valid' +MOCK_ROUTE_COMBINED = MOCK_ROUTE1 + "\n" + MOCK_ROUTE2 + "\n" + MOCK_ROUTE3.strip() +MOCK_ROUTE_COMBINED_WITH_RPKI = ( + MOCK_ROUTE1 + + "rpki-ov-state: not_found # No ROAs found, or RPKI validation not enabled for source\n\n" + + MOCK_ROUTE2 + + "rpki-ov-state: valid\n\n" + + MOCK_ROUTE3 + + "rpki-ov-state: valid" +) MOCK_ROUTE_COMBINED_KEY_FIELDS = """route: 192.0.2.0/25 @@ -52,37 +64,47 @@ MOCK_DATABASE_RESPONSE = [ { - 'pk': uuid.uuid4(), - 'rpsl_pk': '192.0.2.0/25,AS65547', - 'object_class': 'route', - 'parsed_data': { - 'route': '192.0.2.0/25', 'origin': 'AS65547', 'mnt-by': 'MNT-TEST', 'source': 'TEST1', - 'members': ['AS1, AS2'] + "pk": uuid.uuid4(), + "rpsl_pk": "192.0.2.0/25,AS65547", + "object_class": "route", + "parsed_data": { + "route": "192.0.2.0/25", + "origin": "AS65547", + "mnt-by": "MNT-TEST", + "source": "TEST1", + "members": ["AS1, AS2"], }, - 'object_text': MOCK_ROUTE1, - 'rpki_status': RPKIStatus.not_found, - 'source': 'TEST1', + "object_text": MOCK_ROUTE1, + "rpki_status": RPKIStatus.not_found, + "source": "TEST1", }, { - 'pk': uuid.uuid4(), - - 'rpsl_pk': '192.0.2.0/25,AS65544', - 'object_class': 'route', - 'parsed_data': {'route': '192.0.2.0/25', 'origin': 'AS65544', 'mnt-by': 'MNT-TEST', - 'source': 'TEST2'}, - 'object_text': MOCK_ROUTE2, - 'rpki_status': RPKIStatus.valid, - 'source': 'TEST2', + "pk": uuid.uuid4(), + "rpsl_pk": "192.0.2.0/25,AS65544", + "object_class": "route", + "parsed_data": { + "route": "192.0.2.0/25", + "origin": "AS65544", + "mnt-by": "MNT-TEST", + "source": "TEST2", + }, + "object_text": MOCK_ROUTE2, + "rpki_status": RPKIStatus.valid, + "source": "TEST2", }, { - 'pk': uuid.uuid4(), - 'rpsl_pk': '192.0.2.128/25,AS65545', - 'object_class': 'route', - 'parsed_data': {'route': '192.0.2.128/25', 'origin': 'AS65545', 'mnt-by': 'MNT-TEST', - 'source': 'TEST2'}, - 'object_text': MOCK_ROUTE3, - 'rpki_status': RPKIStatus.valid, - 'source': 'TEST2', + "pk": uuid.uuid4(), + "rpsl_pk": "192.0.2.128/25,AS65545", + "object_class": "route", + "parsed_data": { + "route": "192.0.2.128/25", + "origin": "AS65545", + "mnt-by": "MNT-TEST", + "source": "TEST2", + }, + "object_text": MOCK_ROUTE3, + "rpki_status": RPKIStatus.valid, + "source": "TEST2", }, ] @@ -91,11 +113,13 @@ def prepare_parser(monkeypatch, config_override): mock_query_resolver = Mock(spec=QueryResolver) mock_query_resolver.rpki_aware = False - monkeypatch.setattr('irrd.server.whois.query_parser.QueryResolver', - lambda preloader, database_handler: mock_query_resolver) + monkeypatch.setattr( + "irrd.server.whois.query_parser.QueryResolver", + lambda preloader, database_handler: mock_query_resolver, + ) mock_dh = Mock(spec=DatabaseHandler) - parser = WhoisQueryParser('127.0.0.1', '127.0.0.1:99999', None, mock_dh) + parser = WhoisQueryParser("127.0.0.1", "127.0.0.1:99999", None, mock_dh) yield mock_query_resolver, mock_dh, parser @@ -105,23 +129,23 @@ class TestWhoisQueryParserRIPE: def test_invalid_flag(self, prepare_parser): mock_query_resolver, mock_dh, parser = prepare_parser - response = parser.handle_query('-e foo') + response = parser.handle_query("-e foo") assert response.response_type == WhoisQueryResponseType.ERROR_USER assert response.mode == 
WhoisQueryResponseMode.RIPE - assert response.result == 'Unrecognised flag/search: e' + assert response.result == "Unrecognised flag/search: e" def test_null_bytes(self, prepare_parser): # #581 mock_query_resolver, mock_dh, parser = prepare_parser - response = parser.handle_query('\x00 foo') + response = parser.handle_query("\x00 foo") assert response.response_type == WhoisQueryResponseType.ERROR_USER assert response.mode == WhoisQueryResponseMode.IRRD - assert response.result == 'Queries may not contain null bytes' + assert response.result == "Queries may not contain null bytes" def test_keepalive(self, prepare_parser): mock_query_resolver, mock_dh, parser = prepare_parser - response = parser.handle_query('-k') + response = parser.handle_query("-k") assert response.response_type == WhoisQueryResponseType.SUCCESS assert response.mode == WhoisQueryResponseMode.RIPE assert not response.result @@ -135,19 +159,20 @@ def test_route_search_exact(self, prepare_parser): mock_query_resolver.route_search = Mock(return_value=MOCK_DATABASE_RESPONSE) parser.key_fields_only = True - response = parser.handle_query('-r -x 192.0.2.0/25') + response = parser.handle_query("-r -x 192.0.2.0/25") assert not parser.key_fields_only assert response.response_type == WhoisQueryResponseType.SUCCESS assert response.mode == WhoisQueryResponseMode.RIPE assert response.result == MOCK_ROUTE_COMBINED mock_query_resolver.route_search.assert_called_once_with( - IP('192.0.2.0/25'), RouteLookupType.EXACT, + IP("192.0.2.0/25"), + RouteLookupType.EXACT, ) mock_query_resolver.reset_mock() mock_query_resolver.route_search = Mock(return_value=[]) - response = parser.handle_query('-x 192.0.2.0/32') + response = parser.handle_query("-x 192.0.2.0/32") assert response.response_type == WhoisQueryResponseType.KEY_NOT_FOUND assert response.mode == WhoisQueryResponseMode.RIPE assert not response.result @@ -157,58 +182,61 @@ def test_route_search_less_specific_one_level(self, prepare_parser): mock_query_resolver, mock_dh, parser = prepare_parser mock_query_resolver.route_search = Mock(return_value=MOCK_DATABASE_RESPONSE) - response = parser.handle_query('-l 192.0.2.0/25') + response = parser.handle_query("-l 192.0.2.0/25") assert response.response_type == WhoisQueryResponseType.SUCCESS assert response.mode == WhoisQueryResponseMode.RIPE assert response.result == MOCK_ROUTE_COMBINED mock_query_resolver.route_search.assert_called_once_with( - IP('192.0.2.0/25'), RouteLookupType.LESS_SPECIFIC_ONE_LEVEL, + IP("192.0.2.0/25"), + RouteLookupType.LESS_SPECIFIC_ONE_LEVEL, ) def test_route_search_less_specific(self, prepare_parser): mock_query_resolver, mock_dh, parser = prepare_parser mock_query_resolver.route_search = Mock(return_value=MOCK_DATABASE_RESPONSE) - response = parser.handle_query('-L 192.0.2.0/25') + response = parser.handle_query("-L 192.0.2.0/25") assert response.response_type == WhoisQueryResponseType.SUCCESS assert response.mode == WhoisQueryResponseMode.RIPE assert response.result == MOCK_ROUTE_COMBINED mock_query_resolver.route_search.assert_called_once_with( - IP('192.0.2.0/25'), RouteLookupType.LESS_SPECIFIC_WITH_EXACT, + IP("192.0.2.0/25"), + RouteLookupType.LESS_SPECIFIC_WITH_EXACT, ) def test_route_search_more_specific(self, prepare_parser): mock_query_resolver, mock_dh, parser = prepare_parser mock_query_resolver.route_search = Mock(return_value=MOCK_DATABASE_RESPONSE) - response = parser.handle_query('-M 192.0.2.0/25') + response = parser.handle_query("-M 192.0.2.0/25") assert response.response_type == 
WhoisQueryResponseType.SUCCESS assert response.mode == WhoisQueryResponseMode.RIPE assert response.result == MOCK_ROUTE_COMBINED mock_query_resolver.route_search.assert_called_once_with( - IP('192.0.2.0/25'), RouteLookupType.MORE_SPECIFIC_WITHOUT_EXACT, + IP("192.0.2.0/25"), + RouteLookupType.MORE_SPECIFIC_WITHOUT_EXACT, ) def test_route_search_invalid_parameter(self, prepare_parser): mock_query_resolver, mock_dh, parser = prepare_parser - response = parser.handle_query('-x not-a-prefix') + response = parser.handle_query("-x not-a-prefix") assert response.response_type == WhoisQueryResponseType.ERROR_USER assert response.mode == WhoisQueryResponseMode.RIPE - assert response.result == 'Invalid input for route search: not-a-prefix' + assert response.result == "Invalid input for route search: not-a-prefix" def test_inverse_attribute_search(self, prepare_parser): mock_query_resolver, mock_dh, parser = prepare_parser mock_query_resolver.rpsl_attribute_search = Mock(return_value=MOCK_DATABASE_RESPONSE) - response = parser.handle_query('-i mnt-by MNT-TEST') + response = parser.handle_query("-i mnt-by MNT-TEST") assert response.response_type == WhoisQueryResponseType.SUCCESS assert response.mode == WhoisQueryResponseMode.RIPE assert response.result == MOCK_ROUTE_COMBINED - mock_query_resolver.rpsl_attribute_search.assert_called_once_with('mnt-by', 'MNT-TEST') + mock_query_resolver.rpsl_attribute_search.assert_called_once_with("mnt-by", "MNT-TEST") mock_query_resolver.rpsl_attribute_search = Mock(return_value=[]) - response = parser.handle_query('-i mnt-by MNT-NOT-EXISTING') + response = parser.handle_query("-i mnt-by MNT-NOT-EXISTING") assert response.response_type == WhoisQueryResponseType.KEY_NOT_FOUND assert response.mode == WhoisQueryResponseMode.RIPE assert not response.result @@ -217,17 +245,17 @@ def test_sources_list(self, prepare_parser): mock_query_resolver, mock_dh, parser = prepare_parser mock_query_resolver.set_query_sources = Mock() - response = parser.handle_query('-s test1') + response = parser.handle_query("-s test1") assert response.response_type == WhoisQueryResponseType.SUCCESS assert response.mode == WhoisQueryResponseMode.RIPE assert not response.result - mock_query_resolver.rpsl_attribute_search.set_query_sources(['TEST1']) + mock_query_resolver.rpsl_attribute_search.set_query_sources(["TEST1"]) def test_sources_all(self, prepare_parser): mock_query_resolver, mock_dh, parser = prepare_parser mock_query_resolver.set_query_sources = Mock() - response = parser.handle_query('-a') + response = parser.handle_query("-a") assert response.response_type == WhoisQueryResponseType.SUCCESS assert response.mode == WhoisQueryResponseMode.RIPE assert not response.result @@ -238,40 +266,41 @@ def test_restrict_object_class(self, prepare_parser): mock_query_resolver.set_object_class_filter_next_query = Mock() mock_query_resolver.rpsl_attribute_search = Mock(return_value=MOCK_DATABASE_RESPONSE) - response = parser.handle_query('-T route -i mnt-by MNT-TEST') + response = parser.handle_query("-T route -i mnt-by MNT-TEST") assert response.response_type == WhoisQueryResponseType.SUCCESS assert response.mode == WhoisQueryResponseMode.RIPE assert response.result == MOCK_ROUTE_COMBINED assert response.remove_auth_hashes - mock_query_resolver.rpsl_attribute_search.set_object_class_filter_next_query(['route']) - mock_query_resolver.rpsl_attribute_search.rpsl_attribute_search('mnt-by', 'MNT-TEST') + mock_query_resolver.rpsl_attribute_search.set_object_class_filter_next_query(["route"]) + 
mock_query_resolver.rpsl_attribute_search.rpsl_attribute_search("mnt-by", "MNT-TEST") def test_object_template(self, prepare_parser): mock_query_resolver, mock_dh, parser = prepare_parser - mock_query_resolver.rpsl_object_template = Mock(return_value='