diff --git a/stripe/__init__.py b/stripe/__init__.py index c3d561b0a..c0f14086e 100644 --- a/stripe/__init__.py +++ b/stripe/__init__.py @@ -19,6 +19,7 @@ proxy = None default_http_client = None app_info = None +max_network_retries = 0 # Set to either 'debug' or 'info', controls console logging log = None diff --git a/stripe/api_requestor.py b/stripe/api_requestor.py index 8625835c5..7a9fe1d54 100644 --- a/stripe/api_requestor.py +++ b/stripe/api_requestor.py @@ -5,6 +5,7 @@ import json import platform import time +import uuid import stripe from stripe import error, oauth_error, http_client, version, util, six @@ -220,6 +221,7 @@ def request_headers(self, api_key, method): if method == 'post': headers['Content-Type'] = 'application/x-www-form-urlencoded' + headers.setdefault('Idempotency-Key', str(uuid.uuid4())) if self.api_version is not None: headers['Stripe-Version'] = self.api_version @@ -271,7 +273,6 @@ def request_raw(self, method, url, params=None, supplied_headers=None): 'assistance.' % (method,)) headers = self.request_headers(my_api_key, method) - if supplied_headers is not None: for key, value in six.iteritems(supplied_headers): headers[key] = value @@ -281,7 +282,7 @@ def request_raw(self, method, url, params=None, supplied_headers=None): 'Post details', post_data=encoded_params, api_version=self.api_version) - rbody, rcode, rheaders = self._client.request( + rbody, rcode, rheaders = self._client.request_with_retries( method, abs_url, headers, post_data) util.log_info( diff --git a/stripe/error.py b/stripe/error.py index 66d8a648f..ae0830d1a 100644 --- a/stripe/error.py +++ b/stripe/error.py @@ -53,7 +53,12 @@ class APIError(StripeError): class APIConnectionError(StripeError): - pass + def __init__(self, message, http_body=None, http_status=None, + json_body=None, headers=None, code=None, should_retry=False): + super(APIConnectionError, self).__init__(message, http_body, + http_status, + json_body, headers, code) + self.should_retry = should_retry class StripeErrorWithParamCode(StripeError): diff --git a/stripe/http_client.py b/stripe/http_client.py index f0a590c53..504eec657 100644 --- a/stripe/http_client.py +++ b/stripe/http_client.py @@ -5,8 +5,10 @@ import textwrap import warnings import email +import time +import random -from stripe import error, util, six +from stripe import error, util, six, max_network_retries # - Requests is the preferred HTTP library # - Google App Engine has urlfetch @@ -77,6 +79,9 @@ def new_default_http_client(*args, **kwargs): class HTTPClient(object): + MAX_DELAY = 2 + INITIAL_DELAY = 0.5 + def __init__(self, verify_ssl_certs=True, proxy=None): self._verify_ssl_certs = verify_ssl_certs if proxy: @@ -89,10 +94,74 @@ def __init__(self, verify_ssl_certs=True, proxy=None): " ""https"" and/or ""http"" keys.") self._proxy = proxy.copy() if proxy else None + def request_with_retries(self, method, url, headers, post_data=None): + num_retries = 0 + + while True: + try: + num_retries += 1 + response = self.request(method, url, headers, post_data) + connection_error = None + except error.APIConnectionError as e: + connection_error = e + response = None + + if self._should_retry(response, connection_error, num_retries): + if connection_error: + util.log_info("Encountered a retryable error %s" % + connection_error.user_message) + + sleep_time = self._sleep_time_seconds(num_retries) + util.log_info(("Initiating retry %i for request %s %s after " + "sleeping %.2f seconds." % + (num_retries, method, url, sleep_time))) + time.sleep(sleep_time) + else: + if response is not None: + return response + else: + raise connection_error + def request(self, method, url, headers, post_data=None): raise NotImplementedError( 'HTTPClient subclasses must implement `request`') + def _should_retry(self, response, api_connection_error, num_retries): + if response is not None: + _, status_code, _ = response + should_retry = status_code == 409 + else: + # We generally want to retry on timeout and connection + # exceptions, but defer this decision to underlying subclass + # implementations. They should evaluate the driver-specific + # errors worthy of retries, and set flag on the error returned. + should_retry = api_connection_error.should_retry + return should_retry and num_retries < self._max_network_retries() + + def _max_network_retries(self): + # Configured retries, isolated here for tests + return max_network_retries + + def _sleep_time_seconds(self, num_retries): + # Apply exponential backoff with initial_network_retry_delay on the + # number of num_retries so far as inputs. + # Do not allow the number to exceed max_network_retry_delay. + sleep_seconds = min( + HTTPClient.INITIAL_DELAY * (2 ** (num_retries - 1)), + HTTPClient.MAX_DELAY) + + sleep_seconds = self._add_jitter_time(sleep_seconds) + + # But never sleep less than the base sleep seconds. + sleep_seconds = max(HTTPClient.INITIAL_DELAY, sleep_seconds) + return sleep_seconds + + def _add_jitter_time(self, sleep_seconds): + # Randomize the value in [(sleep_seconds/ 2) to (sleep_seconds)] + # Also separated method here to isolate randomness for tests + sleep_seconds *= (0.5 * (1 + random.uniform(0, 1))) + return sleep_seconds + def close(self): raise NotImplementedError( 'HTTPClient subclasses must implement `close`') @@ -146,11 +215,31 @@ def request(self, method, url, headers, post_data=None): return content, status_code, result.headers def _handle_request_error(self, e): - if isinstance(e, requests.exceptions.RequestException): + + # Catch SSL error first as it belongs to ConnectionError, + # but we don't want to retry + if isinstance(e, requests.exceptions.SSLError): + msg = ("Could not verify Stripe's SSL certificate. Please make " + "sure that your network is not intercepting certificates. " + "If this problem persists, let us know at " + "support@stripe.com.") + err = "%s: %s" % (type(e).__name__, str(e)) + should_retry = False + # Retry only timeout and connect errors; similar to urllib3 Retry + elif isinstance(e, requests.exceptions.Timeout) or \ + isinstance(e, requests.exceptions.ConnectionError): msg = ("Unexpected error communicating with Stripe. " "If this problem persists, let us know at " "support@stripe.com.") err = "%s: %s" % (type(e).__name__, str(e)) + should_retry = True + # Catch remaining request exceptions + elif isinstance(e, requests.exceptions.RequestException): + msg = ("Unexpected error communicating with Stripe. " + "If this problem persists, let us know at " + "support@stripe.com.") + err = "%s: %s" % (type(e).__name__, str(e)) + should_retry = False else: msg = ("Unexpected error communicating with Stripe. " "It looks like there's probably a configuration " @@ -161,8 +250,10 @@ def _handle_request_error(self, e): err += " with error message %s" % (str(e),) else: err += " with no error message" + should_retry = False + msg = textwrap.fill(msg) + "\n\n(Network error: %s)" % (err,) - raise error.APIConnectionError(msg) + raise error.APIConnectionError(msg, should_retry=should_retry) def close(self): if self._session is not None: diff --git a/tests/test_api_requestor.py b/tests/test_api_requestor.py index 98e1347cf..ab1190004 100644 --- a/tests/test_api_requestor.py +++ b/tests/test_api_requestor.py @@ -3,6 +3,7 @@ import datetime import json import tempfile +import uuid import pytest @@ -35,28 +36,32 @@ class APIHeaderMatcher(object): 'User-Agent', 'X-Stripe-Client-User-Agent', ] - METHOD_EXTRA_KEYS = {"post": ["Content-Type"]} + METHOD_EXTRA_KEYS = {"post": ["Content-Type", "Idempotency-Key"]} def __init__(self, api_key=None, extra={}, request_method=None, - user_agent=None, app_info=None): + user_agent=None, app_info=None, idempotency_key=None): self.request_method = request_method self.api_key = api_key or stripe.api_key self.extra = extra self.user_agent = user_agent self.app_info = app_info + self.idempotency_key = idempotency_key def __eq__(self, other): return (self._keys_match(other) and self._auth_match(other) and self._user_agent_match(other) and self._x_stripe_ua_contains_app_info(other) and + self._idempotency_key_match(other) and self._extra_match(other)) def __repr__(self): return ("APIHeaderMatcher(request_method=%s, api_key=%s, extra=%s, " - "user_agent=%s, app_info=%s)" % + "user_agent=%s, app_info=%s, idempotency_key=%s)" % (repr(self.request_method), repr(self.api_key), - repr(self.extra), repr(self.user_agent), repr(self.app_info))) + repr(self.extra), repr(self.user_agent), repr(self.app_info), + repr(self.idempotency_key)) + ) def _keys_match(self, other): expected_keys = list(set(self.EXP_KEYS + list(self.extra.keys()))) @@ -74,6 +79,11 @@ def _user_agent_match(self, other): return True + def _idempotency_key_match(self, other): + if self.idempotency_key is not None: + return other['Idempotency-Key'] == self.idempotency_key + return True + def _x_stripe_ua_contains_app_info(self, other): if self.app_info: ua = json.loads(other['X-Stripe-Client-User-Agent']) @@ -129,6 +139,19 @@ def __repr__(self): return ("UrlMatcher(exp_parts=%s)" % (repr(self.exp_parts))) +class AnyUUID4Matcher(object): + + def __eq__(self, other): + try: + uuid.UUID(other, version=4) + except ValueError: + return False + return True + + def __repr__(self): + return "AnyUUID4Matcher()" + + class TestAPIRequestor(object): ENCODE_INPUTS = { 'dict': { @@ -198,7 +221,7 @@ def requestor(self, http_client): def mock_response(self, mocker, http_client): def mock_response(return_body, return_code, headers=None): print(return_code) - http_client.request = mocker.Mock( + http_client.request_with_retries = mocker.Mock( return_value=(return_body, return_code, headers or {})) return mock_response @@ -211,7 +234,7 @@ def check_call(method, abs_url=None, headers=None, if not headers: headers = APIHeaderMatcher(request_method=method) - http_client.request.assert_called_with( + http_client.request_with_retries.assert_called_with( method, abs_url, headers, post_data) return check_call @@ -417,6 +440,31 @@ def test_uses_app_info(self, requestor, mock_response, check_call): finally: stripe.app_info = old + def test_uses_given_idempotency_key(self, requestor, mock_response, + check_call): + mock_response('{}', 200) + meth = 'post' + requestor.request(meth, self.valid_path, {}, + {'Idempotency-Key': '123abc'}) + + header_matcher = APIHeaderMatcher( + request_method=meth, + idempotency_key='123abc' + ) + check_call(meth, headers=header_matcher, post_data='') + + def test_uuid4_idempotency_key_when_not_given(self, requestor, + mock_response, check_call): + mock_response('{}', 200) + meth = 'post' + requestor.request(meth, self.valid_path, {}) + + header_matcher = APIHeaderMatcher( + request_method=meth, + idempotency_key=AnyUUID4Matcher() + ) + check_call(meth, headers=header_matcher, post_data='') + def test_fails_without_api_key(self, requestor): stripe.api_key = None @@ -535,12 +583,12 @@ def test_default_http_client_called(self, mocker): hc = mocker.Mock(stripe.http_client.HTTPClient) hc._verify_ssl_certs = True hc.name = 'mockclient' - hc.request = mocker.Mock(return_value=("{}", 200, {})) + hc.request_with_retries = mocker.Mock(return_value=("{}", 200, {})) stripe.default_http_client = hc stripe.Charge.list(limit=3) - hc.request.assert_called_with( + hc.request_with_retries.assert_called_with( 'get', 'https://api.stripe.com/v1/charges?limit=3', mocker.ANY, diff --git a/tests/test_error.py b/tests/test_error.py index 00f072cde..3f5d9d10a 100644 --- a/tests/test_error.py +++ b/tests/test_error.py @@ -62,3 +62,12 @@ def test_repr(self): assert repr(err) == \ "CardError(message='öre', param='cparam', code='ccode', " \ "http_status=403, request_id='123')" + + +class TestApiConnectionError(object): + def test_default_no_retry(self): + err = error.APIConnectionError('msg') + assert err.should_retry is False + + err = error.APIConnectionError('msg', should_retry=True) + assert err.should_retry diff --git a/tests/test_http_client.py b/tests/test_http_client.py index 6271bef20..f58acd434 100644 --- a/tests/test_http_client.py +++ b/tests/test_http_client.py @@ -52,6 +52,108 @@ def test_new_default_http_client_urllib2(self): stripe.http_client.Urllib2Client) +class TestRetrySleepTimeDefaultHttpClient(StripeClientTestCase): + from contextlib import contextmanager + + def assert_sleep_times(self, client, expected): + until = len(expected) + actual = list( + map(lambda i: client._sleep_time_seconds(i + 1), range(until))) + assert expected == actual + + @contextmanager + def mock_max_delay(self, new_value): + original_value = stripe.http_client.HTTPClient.MAX_DELAY + stripe.http_client.HTTPClient.MAX_DELAY = new_value + try: + yield self + finally: + stripe.http_client.HTTPClient.MAX_DELAY = original_value + + def test_sleep_time_exponential_back_off(self): + client = stripe.http_client.new_default_http_client() + client._add_jitter_time = lambda t: t + with self.mock_max_delay(10): + self.assert_sleep_times(client, [0.5, 1.0, 2.0, 4.0, 8.0]) + + def test_initial_delay_as_minimum(self): + client = stripe.http_client.new_default_http_client() + client._add_jitter_time = lambda t: t * 0.001 + initial_delay = stripe.http_client.HTTPClient.INITIAL_DELAY + self.assert_sleep_times(client, [initial_delay] * 5) + + def test_maximum_delay(self): + client = stripe.http_client.new_default_http_client() + client._add_jitter_time = lambda t: t + max_delay = stripe.http_client.HTTPClient.MAX_DELAY + expected = [0.5, 1.0, max_delay, max_delay, max_delay] + self.assert_sleep_times(client, expected) + + def test_randomness_added(self): + client = stripe.http_client.new_default_http_client() + random_value = 0.8 + client._add_jitter_time = lambda t: t * random_value + base_value = stripe.http_client.HTTPClient.INITIAL_DELAY * random_value + + with self.mock_max_delay(10): + expected = [stripe.http_client.HTTPClient.INITIAL_DELAY, + base_value * 2, + base_value * 4, + base_value * 8, + base_value * 16] + self.assert_sleep_times(client, expected) + + def test_jitter_has_randomness_but_within_range(self): + client = stripe.http_client.new_default_http_client() + + jittered_ones = set( + map(lambda _: client._add_jitter_time(1), list(range(100))) + ) + + assert len(jittered_ones) > 1 + assert all(0.5 <= val <= 1 for val in jittered_ones) + + +class TestRetryConditionsDefaultHttpClient(StripeClientTestCase): + + def test_should_retry_on_codes(self): + one_xx = list(range(100, 104)) + two_xx = list(range(200, 209)) + three_xx = list(range(300, 308)) + four_xx = list(range(400, 431)) + five_xx = list(range(500, 512)) + + client = stripe.http_client.new_default_http_client() + codes = one_xx + two_xx + three_xx + four_xx + five_xx + codes.remove(409) + + for code in codes: + assert client._should_retry((None, code, None), None, 1) is False + + def test_should_retry_on_error(self, mocker): + client = stripe.http_client.new_default_http_client() + client._max_network_retries = lambda: 1 + api_connection_error = mocker.Mock() + + api_connection_error.should_retry = True + assert client._should_retry(None, api_connection_error, 0) is True + + api_connection_error.should_retry = False + assert client._should_retry(None, api_connection_error, 0) is False + + def test_should_retry_on_num_retries(self, mocker): + client = stripe.http_client.new_default_http_client() + max_test_retries = 10 + client._max_network_retries = lambda: max_test_retries + api_connection_error = mocker.Mock() + api_connection_error.should_retry = True + + assert client._should_retry( + None, api_connection_error, max_test_retries + 1) is False + assert client._should_retry( + (None, 409, None), None, max_test_retries + 1) is False + + class ClientTestBase(object): @pytest.fixture def request_mock(self, request_mocks): @@ -63,7 +165,7 @@ def valid_url(self, path='/foo'): def make_request(self, method, url, headers, post_data): client = self.REQUEST_CLIENT(verify_ssl_certs=True) - return client.request(method, url, headers, post_data) + return client.request_with_retries(method, url, headers, post_data) @pytest.fixture def mock_response(self): @@ -139,8 +241,9 @@ def mock_response(mock, body, code): @pytest.fixture def mock_error(self, mocker, session): def mock_error(mock): - mock.exceptions.RequestException = Exception - session.request.side_effect = mock.exceptions.RequestException() + # The first kind of request exceptions we catch + mock.exceptions.SSLError = Exception + session.request.side_effect = mock.exceptions.SSLError() mock.Session = mocker.MagicMock(return_value=session) return mock_error @@ -149,22 +252,26 @@ def mock_error(mock): # session. @pytest.fixture def check_call(self, session): - def check_call(mock, method, url, post_data, headers, timeout=80): - session.request. \ - assert_called_with(method, url, - headers=headers, - data=post_data, - verify=RequestsVerify(), - proxies={"http": "http://slap/", - "https": "http://slap/"}, - timeout=timeout) + def check_call(mock, method, url, post_data, headers, timeout=80, + times=None): + times = times or 1 + args = (method, url) + kwargs = {'headers': headers, + 'data': post_data, + 'verify': RequestsVerify(), + 'proxies': {"http": "http://slap/", + "https": "http://slap/"}, + 'timeout': timeout} + calls = [(args, kwargs) for _ in range(times)] + session.request.assert_has_calls(calls) + return check_call def make_request(self, method, url, headers, post_data, timeout=80): client = self.REQUEST_CLIENT(verify_ssl_certs=True, timeout=timeout, proxy='http://slap/') - return client.request(method, url, headers, post_data) + return client.request_with_retries(method, url, headers, post_data) def test_timeout(self, request_mock, mock_response, check_call): headers = {'my-header': 'header val'} @@ -176,6 +283,151 @@ def test_timeout(self, request_mock, mock_response, check_call): check_call(None, 'POST', self.valid_url, data, headers, timeout=5) +class TestRequestClientRetryBehavior(TestRequestsClient): + + @pytest.fixture + def response(self, mocker): + def response(code=200): + result = mocker.Mock() + result.content = '{}' + result.status_code = code + return result + return response + + @pytest.fixture + def mock_retry(self, mocker, session, request_mock): + def mock_retry(retry_error_num=0, + no_retry_error_num=0, + responses=[]): + + # Mocking classes of exception we catch. Any group of exceptions + # with the same inheritance pattern will work + request_root_error_class = Exception + request_mock.exceptions.RequestException = request_root_error_class + + no_retry_parent_class = LookupError + no_retry_child_class = KeyError + request_mock.exceptions.SSLError = no_retry_parent_class + no_retry_errors = [no_retry_child_class()] * no_retry_error_num + + retry_parent_class = EnvironmentError + retry_child_class = IOError + request_mock.exceptions.Timeout = retry_parent_class + request_mock.exceptions.ConnectionError = retry_parent_class + retry_errors = [retry_child_class()] * retry_error_num + + # Include mock responses as possible side-effects + # to simulate returning proper results after some exceptions + session.request.side_effect = retry_errors + no_retry_errors + \ + responses + + request_mock.Session = mocker.MagicMock(return_value=session) + return request_mock + return mock_retry + + @pytest.fixture + def check_call_numbers(self, check_call): + valid_url = self.valid_url + + def check_call_numbers(times): + check_call(None, 'GET', valid_url, None, {}, times=times) + return check_call_numbers + + def max_retries(self): + return 3 + + def make_request(self): + client = self.REQUEST_CLIENT(verify_ssl_certs=True, + timeout=80, + proxy='http://slap/') + # Override sleep time to speed up tests + client._sleep_time = lambda _: 0.0001 + # Override configured max retries + client._max_network_retries = lambda: self.max_retries() + return client.request_with_retries('GET', self.valid_url, {}, None) + + def test_retry_error_until_response(self, mock_retry, response, + check_call_numbers): + mock_retry(retry_error_num=1, responses=[response(code=202)]) + _, code, _ = self.make_request() + assert code == 202 + check_call_numbers(2) + + def test_retry_error_until_exceeded(self, mock_retry, response, + check_call_numbers): + mock_retry(retry_error_num=self.max_retries()) + with pytest.raises(stripe.error.APIConnectionError): + self.make_request() + + check_call_numbers(self.max_retries()) + + def test_no_retry_error(self, mock_retry, response, check_call_numbers): + mock_retry(no_retry_error_num=self.max_retries()) + with pytest.raises(stripe.error.APIConnectionError): + self.make_request() + check_call_numbers(1) + + def test_retry_codes(self, mock_retry, response, check_call_numbers): + mock_retry(responses=[response(code=409), response(code=202)]) + _, code, _ = self.make_request() + assert code == 202 + check_call_numbers(2) + + def test_retry_codes_until_exceeded(self, mock_retry, response, + check_call_numbers): + mock_retry(responses=[response(code=409)] * self.max_retries()) + _, code, _ = self.make_request() + assert code == 409 + check_call_numbers(self.max_retries()) + + @pytest.fixture + def connection_error(self, session): + client = self.REQUEST_CLIENT() + + def connection_error(given_exception): + with pytest.raises(stripe.error.APIConnectionError) as error: + client._handle_request_error(given_exception) + return error.value + return connection_error + + def test_handle_request_error_should_retry(self, connection_error, + mock_retry): + request_mock = mock_retry() + + error = connection_error(request_mock.exceptions.Timeout()) + assert error.should_retry + + error = connection_error(request_mock.exceptions.ConnectionError()) + assert error.should_retry + + def test_handle_request_error_should_not_retry(self, connection_error, + mock_retry): + request_mock = mock_retry() + + error = connection_error(request_mock.exceptions.SSLError()) + assert error.should_retry is False + assert 'not verify Stripe\'s SSL certificate' in error.user_message + + error = connection_error(request_mock.exceptions.RequestException()) + assert error.should_retry is False + + # Mimic non-requests exception as not being children of Exception, + # See mock_retry for the exceptions setup + error = connection_error(BaseException("")) + assert error.should_retry is False + assert 'configuration issue locally' in error.user_message + + # Skip inherited basic requests client tests + def test_request(self, request_mock, mock_response, check_call): + pass + + def test_exception(self, request_mock, mock_error): + pass + + def test_timeout(self, request_mock, mock_response, check_call): + pass + + class TestUrlFetchClient(StripeClientTestCase, ClientTestBase): REQUEST_CLIENT = stripe.http_client.UrlFetchClient @@ -217,7 +469,8 @@ def make_request(self, method, url, headers, post_data, proxy=None): self.client = self.REQUEST_CLIENT(verify_ssl_certs=True, proxy=proxy) self.proxy = proxy - return self.client.request(method, url, headers, post_data) + return self.client.request_with_retries(method, url, headers, + post_data) @pytest.fixture def mock_response(self, mocker): @@ -292,7 +545,8 @@ def make_request(self, method, url, headers, post_data, proxy=None): self.client = self.REQUEST_CLIENT(verify_ssl_certs=True, proxy=proxy) self.proxy = proxy - return self.client.request(method, url, headers, post_data) + return self.client.request_with_retries(method, url, headers, + post_data) @pytest.fixture def curl_mock(self, mocker):