Skip to content

Commit

Permalink
Merge pull request #371 from sedouard/success-response-headers
Browse files Browse the repository at this point in the history
Allow access to request metadata on success
  • Loading branch information
brandur-stripe authored Dec 22, 2017
2 parents 67e6afc + 70dc0f4 commit 2f719d6
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 16 deletions.
6 changes: 4 additions & 2 deletions stripe/api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from stripe import error, oauth_error, http_client, version, util, six
from stripe.multipart_data_generator import MultipartDataGenerator
from stripe.six.moves.urllib.parse import urlencode, urlsplit, urlunsplit
from stripe.stripe_response import StripeResponse


def _encode_datetime(dttime):
Expand Down Expand Up @@ -354,14 +355,15 @@ def interpret_response(self, rbody, rcode, rheaders):
try:
if hasattr(rbody, 'decode'):
rbody = rbody.decode('utf-8')
resp = util.json.loads(rbody)
resp = StripeResponse(rbody, rcode, rheaders)
except Exception:
raise error.APIError(
"Invalid response body from API: %s "
"(HTTP response code was %d)" % (rbody, rcode),
rbody, rcode, rheaders)
if not (200 <= rcode < 300):
self.handle_error_response(rbody, rcode, resp, rheaders)
self.handle_error_response(rbody, rcode, resp.data, rheaders)

return resp

# Deprecated request handling. Will all be removed in 2.0
Expand Down
1 change: 1 addition & 0 deletions stripe/api_resources/abstract/createable_api_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ def create(cls, api_key=None, idempotency_key=None,
url = cls.class_url()
headers = util.populate_headers(idempotency_key)
response, api_key = requestor.request('post', url, params, headers)

return util.convert_to_stripe_object(response, api_key, stripe_version,
stripe_account)
19 changes: 14 additions & 5 deletions stripe/stripe_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,12 @@ def default(self, obj):
return super(StripeObject.ReprJSONEncoder, self).default(obj)

def __init__(self, id=None, api_key=None, stripe_version=None,
stripe_account=None, **params):
stripe_account=None, last_response=None, **params):
super(StripeObject, self).__init__()

self._unsaved_values = set()
self._transient_values = set()
self._last_response = last_response

self._retrieve_params = params
self._previous = None
Expand All @@ -57,6 +58,10 @@ def __init__(self, id=None, api_key=None, stripe_version=None,
if id:
self['id'] = id

@property
def last_response(self):
return self._last_response

def update(self, update_dict):
for k in update_dict:
self._unsaved_values.add(k)
Expand Down Expand Up @@ -148,22 +153,26 @@ def __reduce__(self):

@classmethod
def construct_from(cls, values, key, stripe_version=None,
stripe_account=None):
stripe_account=None, last_response=None):
instance = cls(values.get('id'), api_key=key,
stripe_version=stripe_version,
stripe_account=stripe_account)
stripe_account=stripe_account,
last_response=last_response)
instance.refresh_from(values, api_key=key,
stripe_version=stripe_version,
stripe_account=stripe_account)
stripe_account=stripe_account,
last_response=last_response)
return instance

def refresh_from(self, values, api_key=None, partial=False,
stripe_version=None, stripe_account=None):
stripe_version=None, stripe_account=None,
last_response=None):
self.api_key = api_key or getattr(values, 'api_key', None)
self.stripe_version = \
stripe_version or getattr(values, 'stripe_version', None)
self.stripe_account = \
stripe_account or getattr(values, 'stripe_account', None)
self._last_response = last_response

# Wipe old state before setting new. This is useful for e.g.
# updating a customer, where there is no persistent card
Expand Down
24 changes: 24 additions & 0 deletions stripe/stripe_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from stripe import util


class StripeResponse:

def __init__(self, body, code, headers):
self.body = body
self.code = code
self.headers = headers
self.data = util.json.loads(body)

@property
def idempotency_key(self):
try:
return self.headers['idempotency-key']
except KeyError:
return None

@property
def request_id(self):
try:
return self.headers['request-id']
except KeyError:
return None
14 changes: 12 additions & 2 deletions stripe/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import stripe
from stripe import six


STRIPE_LOG = os.environ.get('STRIPE_LOG')

logger = logging.getLogger('stripe')
Expand Down Expand Up @@ -232,6 +231,15 @@ def convert_to_stripe_object(resp, api_key=None, stripe_version=None,
load_object_classes()
types = OBJECT_CLASSES.copy()

# If we get a StripeResponse, we'll want to return a
# StripeObject with the last_response field filled out with
# the raw API response information
stripe_response = None

if isinstance(resp, stripe.stripe_response.StripeResponse):
stripe_response = resp
resp = stripe_response.data

if isinstance(resp, list):
return [convert_to_stripe_object(i, api_key, stripe_version,
stripe_account) for i in resp]
Expand All @@ -243,9 +251,11 @@ def convert_to_stripe_object(resp, api_key=None, stripe_version=None,
klass = types.get(klass_name, stripe.stripe_object.StripeObject)
else:
klass = stripe.stripe_object.StripeObject

return klass.construct_from(resp, api_key,
stripe_version=stripe_version,
stripe_account=stripe_account)
stripe_account=stripe_account,
last_response=stripe_response)
else:
return resp

Expand Down
7 changes: 7 additions & 0 deletions tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@
'source': 'tok_visa'
}

DUMMY_CHARGE_IDEMPOTENT = {
'amount': 100,
'currency': 'usd',
'source': 'tok_visa',
'idempotency_key': '12345'
}

DUMMY_PLAN = {
'amount': 2000,
'interval': 'month',
Expand Down
22 changes: 17 additions & 5 deletions tests/test_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import stripe
from stripe import six
from stripe.stripe_response import StripeResponse
from stripe import util

from tests.helper import StripeUnitTestCase

Expand Down Expand Up @@ -274,15 +276,18 @@ def test_empty_methods(self):
for meth in VALID_API_METHODS:
self.mock_response('{}', 200)

body, key = self.requestor.request(meth, self.valid_path, {})
resp, key = self.requestor.request(meth, self.valid_path, {})

if meth == 'post':
post_data = ''
else:
post_data = None

self.check_call(meth, post_data=post_data)
self.assertEqual({}, body)
self.assertTrue(isinstance(resp, StripeResponse))

self.assertEqual({}, resp.data)
self.assertEqual(util.json.loads(resp.body), resp.data)

def test_methods_with_params_and_response(self):
for meth in VALID_API_METHODS:
Expand All @@ -296,9 +301,16 @@ def test_methods_with_params_and_response(self):
encoded = ('adict%5Bfrobble%5D=bits&adatetime=1356994800&'
'alist%5B%5D=1&alist%5B%5D=2&alist%5B%5D=3')

body, key = self.requestor.request(meth, self.valid_path,
resp, key = self.requestor.request(meth, self.valid_path,
params)
self.assertEqual({'foo': 'bar', 'baz': 6}, body)
self.assertTrue(isinstance(resp, StripeResponse))

self.assertEqual({
'foo': 'bar',
'baz': 6 },
resp.data)
self.assertEqual(util.json.loads(resp.body),
resp.data)

if meth == 'post':
self.check_call(
Expand All @@ -321,7 +333,7 @@ def test_uses_instance_key(self):

self.mock_response('{}', 200, requestor=requestor)

body, used_key = requestor.request('get', self.valid_path, {})
resp, used_key = requestor.request('get', self.valid_path, {})

self.check_call('get', headers=APIHeaderMatcher(
key, request_method='get'), requestor=requestor)
Expand Down
32 changes: 30 additions & 2 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

from mock import patch

from stripe import six
from stripe import six, util

from tests.helper import (StripeTestCase, DUMMY_CHARGE)
from tests.helper import (StripeTestCase, DUMMY_CHARGE, DUMMY_CHARGE_IDEMPOTENT)


class FunctionalTests(StripeTestCase):
Expand Down Expand Up @@ -55,6 +55,7 @@ def test_refresh(self):
charge2.junk = 'junk'
charge2.refresh()
self.assertRaises(AttributeError, lambda: charge2.junk)


def test_list_accessors(self):
customer = stripe.Customer.create(source='tok_visa')
Expand All @@ -75,6 +76,33 @@ def test_response_headers(self):
except stripe.error.CardError as e:
self.assertTrue(e.request_id.startswith('req_'))

def test_success_response_headers(self):
charge = stripe.Charge.create(**DUMMY_CHARGE_IDEMPOTENT)
self.assertTrue(charge.last_response != None)
self.assertTrue(charge.last_response.headers != None)
self.assertEqual(charge.last_response.code, 200)
self.assertEqual(charge.last_response.headers['idempotency-key'], '12345')
self.assertTrue(charge.last_response.headers['request-id'].startswith('req_'))
# ensure helper keys
self.assertEqual(charge.last_response.idempotency_key, '12345')
self.assertTrue(charge.last_response.request_id.startswith('req_'))
# verify the response body is available
parsed_body = util.json.loads(charge.last_response.body)
self.assertEqual(parsed_body['amount'], DUMMY_CHARGE_IDEMPOTENT['amount'])
self.assertEqual(parsed_body['currency'], DUMMY_CHARGE_IDEMPOTENT['currency'])

def test_success_list_response_headers(self):
charges = stripe.Charge.list()
self.assertTrue(charges.last_response != None)
self.assertTrue(charges.last_response.headers != None)
self.assertEqual(charges.last_response.code, 200)
self.assertTrue(charges.last_response.headers['request-id'].startswith('req_'))
self.assertTrue(charges.last_response.request_id.startswith('req_'))
# verify the response body is available
parsed_body = util.json.loads(charges.last_response.body)
self.assertTrue(isinstance(parsed_body['data'], list))
self.assertTrue(parsed_body['object'], 'list')

def test_unicode(self):
# Make sure unicode requests can be sent
self.assertRaises(stripe.error.InvalidRequestError,
Expand Down
50 changes: 50 additions & 0 deletions tests/test_stripe_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from __future__ import absolute_import, division, print_function

import time

from stripe import util
from stripe.stripe_response import StripeResponse
from tests.helper import StripeUnitTestCase

class StripeResponseTests(StripeUnitTestCase):
def test_idempotency_key(self):
response, headers, body, code = StripeResponseTests.mock_stripe_response()
self.assertEqual(response.idempotency_key, headers['idempotency-key'])

def test_request_id(self):
response, headers, body, code = StripeResponseTests.mock_stripe_response()
self.assertEqual(response.request_id, headers['request-id'])

def test_code(self):
response, headers, body, code = StripeResponseTests.mock_stripe_response()
self.assertEqual(response.code, code)

def test_headers(self):
response, headers, body, code = StripeResponseTests.mock_stripe_response()
self.assertEqual(response.headers, headers)

def test_body(self):
response, headers, body, code = StripeResponseTests.mock_stripe_response()
self.assertEqual(response.body, body)

def test_data(self):
response, headers, body, code = StripeResponseTests.mock_stripe_response()
self.assertEqual(response.data, util.json.loads(body))

@staticmethod
def mock_stripe_response():
code = 200
headers = StripeResponseTests.mock_headers()
body = StripeResponseTests.mock_body()
response = StripeResponse(body, code, headers)
return response, headers, body, code
@staticmethod
def mock_headers():
return {
'idempotency-key': '123456',
'request-id': 'req_123456'
}

@staticmethod
def mock_body():
return '{ "id": "ch_12345", "object": "charge", "amount": 1 }'

0 comments on commit 2f719d6

Please sign in to comment.