Skip to content

Commit

Permalink
Pass 'user_project' if set for BucketACL/BlobACL API requests (google…
Browse files Browse the repository at this point in the history
  • Loading branch information
tseaver authored and landrito committed Aug 22, 2017
1 parent c34ff42 commit da2c11f
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 36 deletions.
25 changes: 24 additions & 1 deletion storage/google/cloud/storage/acl.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ class ACL(object):
# as properties).
reload_path = None
save_path = None
user_project = None

def __init__(self):
self.entities = {}
Expand Down Expand Up @@ -405,10 +406,18 @@ def reload(self, client=None):
"""
path = self.reload_path
client = self._require_client(client)
query_params = {}

if self.user_project is not None:
query_params['userProject'] = self.user_project

self.entities.clear()

found = client._connection.api_request(method='GET', path=path)
found = client._connection.api_request(
method='GET',
path=path,
query_params=query_params,
)
self.loaded = True
for entry in found.get('items', ()):
self.add_entity(self.entity_from_dict(entry))
Expand All @@ -435,8 +444,12 @@ def _save(self, acl, predefined, client):
acl = []
query_params[self._PREDEFINED_QUERY_PARAM] = predefined

if self.user_project is not None:
query_params['userProject'] = self.user_project

path = self.save_path
client = self._require_client(client)

result = client._connection.api_request(
method='PATCH',
path=path,
Expand Down Expand Up @@ -532,6 +545,11 @@ def save_path(self):
"""Compute the path for PATCH API requests for this ACL."""
return self.bucket.path

@property
def user_project(self):
"""Compute the user project charged for API requests for this ACL."""
return self.bucket.user_project


class DefaultObjectACL(BucketACL):
"""A class representing the default object ACL for a bucket."""
Expand Down Expand Up @@ -565,3 +583,8 @@ def reload_path(self):
def save_path(self):
"""Compute the path for PATCH API requests for this ACL."""
return self.blob.path

@property
def user_project(self):
"""Compute the user project charged for API requests for this ACL."""
return self.blob.user_project
135 changes: 100 additions & 35 deletions storage/tests/unit/test_acl.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,11 @@ def test_reload_missing(self):
self.assertEqual(list(acl), [])
kw = connection._requested
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'GET')
self.assertEqual(kw[0]['path'], '/testing/acl')
self.assertEqual(kw[0], {
'method': 'GET',
'path': '/testing/acl',
'query_params': {},
})

def test_reload_empty_result_clears_local(self):
ROLE = 'role'
Expand All @@ -543,29 +546,41 @@ def test_reload_empty_result_clears_local(self):
acl.reload_path = '/testing/acl'
acl.loaded = True
acl.entity('allUsers', ROLE)

acl.reload(client=client)

self.assertTrue(acl.loaded)
self.assertEqual(list(acl), [])
kw = connection._requested
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'GET')
self.assertEqual(kw[0]['path'], '/testing/acl')
self.assertEqual(kw[0], {
'method': 'GET',
'path': '/testing/acl',
'query_params': {},
})

def test_reload_nonempty_result(self):
def test_reload_nonempty_result_w_user_project(self):
ROLE = 'role'
USER_PROJECT = 'user-project-123'
connection = _Connection(
{'items': [{'entity': 'allUsers', 'role': ROLE}]})
client = _Client(connection)
acl = self._make_one()
acl.reload_path = '/testing/acl'
acl.loaded = True
acl.user_project = USER_PROJECT

acl.reload(client=client)

self.assertTrue(acl.loaded)
self.assertEqual(list(acl), [{'entity': 'allUsers', 'role': ROLE}])
kw = connection._requested
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'GET')
self.assertEqual(kw[0]['path'], '/testing/acl')
self.assertEqual(kw[0], {
'method': 'GET',
'path': '/testing/acl',
'query_params': {'userProject': USER_PROJECT},
})

def test_save_none_set_none_passed(self):
connection = _Connection()
Expand Down Expand Up @@ -606,30 +621,43 @@ def test_save_no_acl(self):
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'PATCH')
self.assertEqual(kw[0]['path'], '/testing')
self.assertEqual(kw[0]['data'], {'acl': AFTER})
self.assertEqual(kw[0]['query_params'], {'projection': 'full'})

def test_save_w_acl(self):
self.assertEqual(kw[0], {
'method': 'PATCH',
'path': '/testing',
'query_params': {'projection': 'full'},
'data': {'acl': AFTER},
})

def test_save_w_acl_w_user_project(self):
ROLE1 = 'role1'
ROLE2 = 'role2'
STICKY = {'entity': 'allUsers', 'role': ROLE2}
USER_PROJECT = 'user-project-123'
new_acl = [{'entity': 'allUsers', 'role': ROLE1}]
connection = _Connection({'acl': [STICKY] + new_acl})
client = _Client(connection)
acl = self._make_one()
acl.save_path = '/testing'
acl.loaded = True
acl.user_project = USER_PROJECT

acl.save(new_acl, client=client)

entries = list(acl)
self.assertEqual(len(entries), 2)
self.assertTrue(STICKY in entries)
self.assertTrue(new_acl[0] in entries)
kw = connection._requested
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'PATCH')
self.assertEqual(kw[0]['path'], '/testing')
self.assertEqual(kw[0]['data'], {'acl': new_acl})
self.assertEqual(kw[0]['query_params'], {'projection': 'full'})
self.assertEqual(kw[0], {
'method': 'PATCH',
'path': '/testing',
'query_params': {
'projection': 'full',
'userProject': USER_PROJECT,
},
'data': {'acl': new_acl},
})

def test_save_prefefined_invalid(self):
connection = _Connection()
Expand All @@ -652,11 +680,15 @@ def test_save_predefined_valid(self):
self.assertEqual(len(entries), 0)
kw = connection._requested
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'PATCH')
self.assertEqual(kw[0]['path'], '/testing')
self.assertEqual(kw[0]['data'], {'acl': []})
self.assertEqual(kw[0]['query_params'],
{'projection': 'full', 'predefinedAcl': PREDEFINED})
self.assertEqual(kw[0], {
'method': 'PATCH',
'path': '/testing',
'query_params': {
'projection': 'full',
'predefinedAcl': PREDEFINED,
},
'data': {'acl': []},
})

def test_save_predefined_w_XML_alias(self):
PREDEFINED_XML = 'project-private'
Expand All @@ -671,12 +703,15 @@ def test_save_predefined_w_XML_alias(self):
self.assertEqual(len(entries), 0)
kw = connection._requested
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'PATCH')
self.assertEqual(kw[0]['path'], '/testing')
self.assertEqual(kw[0]['data'], {'acl': []})
self.assertEqual(kw[0]['query_params'],
{'projection': 'full',
'predefinedAcl': PREDEFINED_JSON})
self.assertEqual(kw[0], {
'method': 'PATCH',
'path': '/testing',
'query_params': {
'projection': 'full',
'predefinedAcl': PREDEFINED_JSON,
},
'data': {'acl': []},
})

def test_save_predefined_valid_w_alternate_query_param(self):
# Cover case where subclass overrides _PREDEFINED_QUERY_PARAM
Expand All @@ -692,11 +727,15 @@ def test_save_predefined_valid_w_alternate_query_param(self):
self.assertEqual(len(entries), 0)
kw = connection._requested
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'PATCH')
self.assertEqual(kw[0]['path'], '/testing')
self.assertEqual(kw[0]['data'], {'acl': []})
self.assertEqual(kw[0]['query_params'],
{'projection': 'full', 'alternate': PREDEFINED})
self.assertEqual(kw[0], {
'method': 'PATCH',
'path': '/testing',
'query_params': {
'projection': 'full',
'alternate': PREDEFINED,
},
'data': {'acl': []},
})

def test_clear(self):
ROLE1 = 'role1'
Expand All @@ -712,10 +751,12 @@ def test_clear(self):
self.assertEqual(list(acl), [STICKY])
kw = connection._requested
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'PATCH')
self.assertEqual(kw[0]['path'], '/testing')
self.assertEqual(kw[0]['data'], {'acl': []})
self.assertEqual(kw[0]['query_params'], {'projection': 'full'})
self.assertEqual(kw[0], {
'method': 'PATCH',
'path': '/testing',
'query_params': {'projection': 'full'},
'data': {'acl': []},
})


class Test_BucketACL(unittest.TestCase):
Expand All @@ -739,6 +780,15 @@ def test_ctor(self):
self.assertEqual(acl.reload_path, '/b/%s/acl' % NAME)
self.assertEqual(acl.save_path, '/b/%s' % NAME)

def test_user_project(self):
NAME = 'name'
USER_PROJECT = 'user-project-123'
bucket = _Bucket(NAME)
acl = self._make_one(bucket)
self.assertIsNone(acl.user_project)
bucket.user_project = USER_PROJECT
self.assertEqual(acl.user_project, USER_PROJECT)


class Test_DefaultObjectACL(unittest.TestCase):

Expand Down Expand Up @@ -785,9 +835,22 @@ def test_ctor(self):
self.assertEqual(acl.reload_path, '/b/%s/o/%s/acl' % (NAME, BLOB_NAME))
self.assertEqual(acl.save_path, '/b/%s/o/%s' % (NAME, BLOB_NAME))

def test_user_project(self):
NAME = 'name'
BLOB_NAME = 'blob-name'
USER_PROJECT = 'user-project-123'
bucket = _Bucket(NAME)
blob = _Blob(bucket, BLOB_NAME)
acl = self._make_one(blob)
self.assertIsNone(acl.user_project)
blob.user_project = USER_PROJECT
self.assertEqual(acl.user_project, USER_PROJECT)


class _Blob(object):

user_project = None

def __init__(self, bucket, blob):
self.bucket = bucket
self.blob = blob
Expand All @@ -799,6 +862,8 @@ def path(self):

class _Bucket(object):

user_project = None

def __init__(self, name):
self.name = name

Expand Down

0 comments on commit da2c11f

Please sign in to comment.