Skip to content

Commit

Permalink
Wire up retry count config to NVD provider (#738)
Browse files Browse the repository at this point in the history
* adjust display behavior of http.get retry logging

Signed-off-by: Alex Goodman <wagoodman@users.noreply.github.com>

* wire up retry count config to nvd provider

Signed-off-by: Alex Goodman <wagoodman@users.noreply.github.com>

* fix linting

Signed-off-by: Alex Goodman <wagoodman@users.noreply.github.com>

* limit nvd to default backoff

Signed-off-by: Alex Goodman <wagoodman@users.noreply.github.com>

---------

Signed-off-by: Alex Goodman <wagoodman@users.noreply.github.com>
  • Loading branch information
wagoodman authored Nov 26, 2024
1 parent 3b15c63 commit d426c26
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 33 deletions.
2 changes: 2 additions & 0 deletions src/vunnel/providers/nvd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class Config:
),
)
request_timeout: int = 125
request_retry_count: int = 10
api_key: Optional[str] = "env:NVD_API_KEY" # noqa: UP007
overrides_url: str = "https://github.com/anchore/nvd-data-overrides/archive/refs/heads/main.tar.gz"
overrides_enabled: bool = False
Expand Down Expand Up @@ -71,6 +72,7 @@ def __init__(self, root: str, config: Config | None = None):
workspace=self.workspace,
schema=self.__schema__,
download_timeout=self.config.request_timeout,
download_retry_count=self.config.request_retry_count,
api_key=self.config.api_key,
logger=self.logger,
overrides_enabled=self.config.overrides_enabled,
Expand Down
18 changes: 16 additions & 2 deletions src/vunnel/providers/nvd/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,16 @@ class NvdAPI:
_max_results_per_page_: int = 2000
max_date_range_days: int = 120

def __init__(self, api_key: str | None = None, logger: logging.Logger | None = None, timeout: int = 30):
def __init__(
self,
api_key: str | None = None,
logger: logging.Logger | None = None,
timeout: int = 30,
retries: int = 10,
) -> None:
self.api_key = api_key
self.timeout = timeout
self.retries = retries

if not logger:
logger = logging.getLogger(self.__class__.__name__)
Expand Down Expand Up @@ -154,7 +161,14 @@ def _request(self, url: str, parameters: dict[str, str], headers: dict[str, str]

# NVD rate-limiting is detailed at https://nvd.nist.gov/developers/start-here and currently resets on a 30 second
# rolling window, so setting retry to start trying again after 30 seconds.
response = http.get(url, self.logger, backoff_in_seconds=30, params=payload_str, headers=headers, timeout=self.timeout)
response = http.get(
url,
self.logger,
params=payload_str,
headers=headers,
timeout=self.timeout,
retries=self.retries,
)
response.encoding = "utf-8"

return response
Expand Down
4 changes: 3 additions & 1 deletion src/vunnel/providers/nvd/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__( # noqa: PLR0913
overrides_url: str,
logger: logging.Logger | None = None,
download_timeout: int = 125,
download_retry_count: int = 10,
api_key: str | None = None,
overrides_enabled: bool = False,
) -> None:
Expand All @@ -35,14 +36,15 @@ def __init__( # noqa: PLR0913
logger = logging.getLogger(self.__class__.__name__)
self.logger = logger

self.api = NvdAPI(api_key=api_key, logger=logger, timeout=download_timeout)
self.api = NvdAPI(api_key=api_key, logger=logger, timeout=download_timeout, retries=download_retry_count)

self.overrides = NVDOverrides(
enabled=overrides_enabled,
url=overrides_url,
workspace=workspace,
logger=logger,
download_timeout=download_timeout,
retries=download_retry_count,
)

self.urls = [self.api._cve_api_url_]
Expand Down
6 changes: 4 additions & 2 deletions src/vunnel/providers/nvd/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,20 @@ class NVDOverrides:
__file_name__ = "nvd-overrides.tar.gz"
__extract_name__ = "nvd-overrides"

def __init__(
def __init__( # noqa: PLR0913
self,
enabled: bool,
url: str,
workspace: Workspace,
logger: logging.Logger | None = None,
download_timeout: int = 125,
retries: int = 5,
):
self.enabled = enabled
self.__url__ = url
self.workspace = workspace
self.download_timeout = download_timeout
self.retries = retries
if not logger:
logger = logging.getLogger(self.__class__.__name__)
self.logger = logger
Expand All @@ -43,7 +45,7 @@ def download(self) -> None:
self.logger.debug("overrides are not enabled, skipping download...")
return

req = http.get(self.__url__, self.logger, stream=True, timeout=self.download_timeout)
req = http.get(self.__url__, self.logger, stream=True, timeout=self.download_timeout, retries=self.retries)

file_path = os.path.join(self.workspace.input_path, self.__file_name__)
with open(file_path, "wb") as fp:
Expand Down
30 changes: 18 additions & 12 deletions src/vunnel/utils/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def get( # noqa: PLR0913
backoff_in_seconds: int = 3,
timeout: int = DEFAULT_TIMEOUT,
status_handler: Optional[Callable[[requests.Response], None]] = None, # noqa: UP007 - python 3.9
max_interval: int = 600,
**kwargs: Any,
) -> requests.Response:
"""
Expand All @@ -45,15 +46,15 @@ def get( # noqa: PLR0913
status_handler= lambda response: None if response.status_code in [200, 201, 405] else response.raise_for_status())
"""
logger.debug(f"http GET {url}")
last_exception: Exception | None = None
sleep_interval = backoff_in_seconds
for attempt in range(retries + 1):
if last_exception:
sleep_interval = backoff_sleep_interval(backoff_in_seconds, attempt - 1, max_value=max_interval)
logger.warning(f"will retry in {int(sleep_interval)} seconds...")
time.sleep(sleep_interval)
sleep_interval = backoff_in_seconds * 2**attempt + random.uniform(0, 1) # noqa: S311
# explanation of S311 disable: rng is not used cryptographically

try:
logger.debug(f"http GET {url} timeout={timeout} retries={retries} backoff={backoff_in_seconds}")
response = requests.get(url, timeout=timeout, **kwargs)
if status_handler:
status_handler(response)
Expand All @@ -62,20 +63,25 @@ def get( # noqa: PLR0913
return response
except requests.exceptions.HTTPError as e:
last_exception = e
will_retry = ""
if attempt < retries:
will_retry = f" (will retry in {int(backoff_in_seconds)} seconds) "
# HTTPError includes the attempted request, so don't include it redundantly here
logger.warning(f"attempt {attempt + 1} of {retries + 1} failed:{will_retry}{e}")
logger.warning(f"attempt {attempt + 1} of {retries + 1} failed: {e}")
except Exception as e:
last_exception = e
will_retry = ""
if attempt < retries:
will_retry = f" (will retry in {int(sleep_interval)} seconds) "
# this is an unexpected exception type, so include the attempted request in case the
# message from the unexpected exception doesn't.
logger.warning(f"attempt {attempt + 1} of {retries + 1}{will_retry}: unexpected exception during GET {url}: {e}")
logger.warning(f"attempt {attempt + 1} of {retries + 1}: unexpected exception during GET {url}: {e}")
if last_exception:
logger.error(f"last retry of GET {url} failed with {last_exception}")
raise last_exception
raise Exception("unreachable")


def backoff_sleep_interval(interval: int, attempt: int, max_value: None | int = None, jitter: bool = True) -> float:
# this is an exponential backoff
val = interval * 2**attempt
if max_value and val > max_value:
val = max_value
if jitter:
val += random.uniform(0, 1) # noqa: S311
# explanation of S311 disable: rng is not used cryptographically
return val
1 change: 1 addition & 0 deletions tests/unit/cli/test-fixtures/full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ providers:
nvd:
runtime: *runtime
request_timeout: 20
request_retry_count: 50
overrides_enabled: true
overrides_url: https://github.com/anchore/nvd-data-overrides/SOMEWHEREELSE/main.tar.gz
oracle:
Expand Down
1 change: 1 addition & 0 deletions tests/unit/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def test_config(monkeypatch) -> None:
api_key: secret
overrides_enabled: false
overrides_url: https://github.com/anchore/nvd-data-overrides/archive/refs/heads/main.tar.gz
request_retry_count: 10
request_timeout: 125
runtime:
existing_input: keep
Expand Down
1 change: 1 addition & 0 deletions tests/unit/cli/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def test_full_config(helpers):
nvd=providers.nvd.Config(
runtime=runtime_cfg,
request_timeout=20,
request_retry_count=50,
overrides_enabled=True,
overrides_url="https://github.com/anchore/nvd-data-overrides/SOMEWHEREELSE/main.tar.gz",
),
Expand Down
18 changes: 9 additions & 9 deletions tests/unit/providers/nvd/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_cve_no_api_key(self, simple_mock, mocker):
mocker.call(
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
backoff_in_seconds=30,
retries=10,
params="cveId=CVE-2020-0000",
headers={"content-type": "application/json"},
timeout=1,
Expand All @@ -59,7 +59,7 @@ def test_cve_single_cve(self, simple_mock, mocker):
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="cveId=CVE-2020-0000",
backoff_in_seconds=30,
retries=10,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
Expand Down Expand Up @@ -118,23 +118,23 @@ def test_cve_multi_page(self, mocker):
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="",
backoff_in_seconds=30,
retries=10,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
mocker.call(
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="resultsPerPage=3&startIndex=3",
backoff_in_seconds=30,
retries=10,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
mocker.call(
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="resultsPerPage=3&startIndex=6",
backoff_in_seconds=30,
retries=10,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
Expand All @@ -156,7 +156,7 @@ def test_cve_pub_date_range(self, simple_mock, mocker):
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="pubStartDate=2019-12-04T00:00:00&pubEndDate=2019-12-05T00:00:00",
backoff_in_seconds=30,
retries=10,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
Expand All @@ -178,7 +178,7 @@ def test_cve_last_modified_date_range(self, simple_mock, mocker):
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="lastModStartDate=2019-12-04T00:00:00&lastModEndDate=2019-12-05T00:00:00",
backoff_in_seconds=30,
retries=10,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
Expand All @@ -197,7 +197,7 @@ def test_results_per_page(self, simple_mock, mocker):
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="resultsPerPage=5",
backoff_in_seconds=30,
retries=10,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
Expand All @@ -214,7 +214,7 @@ def test_cve_history(self, simple_mock, mocker):
"https://services.nvd.nist.gov/rest/json/cvehistory/2.0",
subject.logger,
params="cveId=CVE-2020-0000",
backoff_in_seconds=30,
retries=10,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
Expand Down
54 changes: 47 additions & 7 deletions tests/unit/utils/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,14 @@ def test_correct_number_of_retria(self, mock_requests, mock_sleep, mock_logger,

@patch("time.sleep")
@patch("requests.get")
def test_succeeds_if_retries_succeed(self, mock_requests, mock_sleep, mock_logger, error_response, success_response):
@patch("random.uniform")
def test_succeeds_if_retries_succeed(
self, mock_uniform_random, mock_requests, mock_sleep, mock_logger, error_response, success_response
):
mock_uniform_random.side_effect = [0.1]
mock_requests.side_effect = [error_response, success_response]
http.get("http://example.com/some-path", mock_logger, retries=1, backoff_in_seconds=22)
mock_sleep.assert_called_with(22)
mock_sleep.assert_called_with(22.1)
mock_logger.warning.assert_called()
mock_logger.error.assert_not_called()
mock_requests.assert_called_with("http://example.com/some-path", timeout=http.DEFAULT_TIMEOUT)
Expand All @@ -74,7 +78,7 @@ def test_exponential_backoff_and_jitter(
mock_requests.side_effect = [error_response, error_response, error_response, success_response]
mock_uniform_random.side_effect = [0.5, 0.4, 0.1]
http.get("http://example.com/some-path", mock_logger, backoff_in_seconds=10, retries=3)
assert mock_sleep.call_args_list == [call(10), call(10 * 2 + 0.5), call(10 * 4 + 0.4)]
assert mock_sleep.call_args_list == [call(10 + 0.5), call(10 * 2 + 0.4), call(10 * 4 + 0.1)]

@patch("time.sleep")
@patch("requests.get")
Expand All @@ -91,8 +95,13 @@ def test_it_logs_the_url_on_failure(self, mock_requests, mock_sleep, mock_logger
def test_it_log_warns_errors(self, mock_requests, mock_sleep, mock_logger, error_response, success_response):
mock_requests.side_effect = [error_response, success_response]
http.get("http://example.com/some-path", mock_logger, retries=1, backoff_in_seconds=33)
assert "HTTP ERROR" in mock_logger.warning.call_args.args[0]
assert "will retry in 33 seconds" in mock_logger.warning.call_args.args[0]

logged_warnings = [call.args[0] for call in mock_logger.warning.call_args_list]

assert any("HTTP ERROR" in message for message in logged_warnings), "Expected 'HTTP ERROR' in logged warnings."
assert any(
"will retry in 33 seconds" in message for message in logged_warnings
), "Expected retry message in logged warnings."

@patch("time.sleep")
@patch("requests.get")
Expand All @@ -109,16 +118,47 @@ def test_it_calls_status_handler(self, mock_requests, mock_sleep, mock_logger, e

@patch("time.sleep")
@patch("requests.get")
@patch("random.uniform")
def test_it_retries_when_status_handler_raises(
self, mock_requests, mock_sleep, mock_logger, error_response, success_response
self, mock_uniform_random, mock_requests, mock_sleep, mock_logger, error_response, success_response
):
mock_uniform_random.side_effect = [0.25]
mock_requests.side_effect = [success_response, error_response]
status_handler = MagicMock()
status_handler.side_effect = [Exception("custom exception"), None]
result = http.get(
"http://example.com/some-path", mock_logger, status_handler=status_handler, retries=1, backoff_in_seconds=33
)
mock_sleep.assert_called_with(33)
mock_sleep.assert_called_with(33.25)
# custom status handler raised the first time it was called,
# so we expect the second mock response to be returned overall
assert result == error_response


@pytest.mark.parametrize(
"interval, jitter, max_value, expected",
[
(
30, # interval
False, # jitter
None, # max_value
[30, 60, 120, 240, 480, 960, 1920, 3840, 7680, 15360, 30720, 61440, 122880, 245760, 491520], # expected
),
(
3, # interval
False, # jitter
1000, # max_value
[3, 6, 12, 24, 48, 96, 192, 384, 768, 1000, 1000, 1000, 1000, 1000, 1000], # expected
),
],
)
def test_backoff_sleep_interval(interval, jitter, max_value, expected):
actual = [
http.backoff_sleep_interval(interval, attempt, jitter=jitter, max_value=max_value) for attempt in range(len(expected))
]

if not jitter:
assert actual == expected
else:
for i, (a, e) in enumerate(zip(actual, expected)):
assert a >= e and a <= e + 1, f"Jittered value out of bounds at attempt {i}: {a} (expected ~{e})"

0 comments on commit d426c26

Please sign in to comment.