Skip to content

Commit

Permalink
fix(add-bearer-token-header) add authorization header for Bearer token
Browse files Browse the repository at this point in the history
for submission
  • Loading branch information
Andy Gee committed May 4, 2021
1 parent e33a137 commit dd55435
Show file tree
Hide file tree
Showing 7 changed files with 243 additions and 11 deletions.
36 changes: 35 additions & 1 deletion docs/howto/devTest.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
## Dev-Test

### Set up Python Virtual Environment

You can set up a Python development environment with a virtual environment:

```bash
python3 -m venv py3
```

Make sure that you have the virtual environment activated:

```bash
. py3/bin/activate
```

### Install poetry

To use the latest code in this repo (or to develop new features) you can clone this repo, install `poetry`:
Expand All @@ -22,11 +36,31 @@ Local development like this:
```
poetry shell
poetry install -vv
python -m pytest
python3 -m pytest
```

There are various ways to select a subset of python unit-tests - see: https://stackoverflow.com/questions/36456920/is-there-a-way-to-specify-which-pytest-tests-to-run-from-a-file

### Manual Testing

You can also set up credentials to submit data to the graph in your data commons. This assumes that you can get API access by downloading your [credentials.json](https://gen3.org/resources/user/using-api/#credentials-to-send-api-requests).

> Make sure that your python virtual environment and dependencies are updated. Also, check that your credentials have appropriate permissions to make the service calls too.
```python
COMMONS_URL = "https://mycommons.azurefd.net"
PROGRAM_NAME = "MyProgram"
PROJECT_NAME = "MyProject"
CREDENTIALS_FILE_PATH = "credentials.json"
gen3_node_json = {
"projects": {"code": PROJECT_NAME},
"type": "core_metadata_collection",
"submitter_id": "core_metadata_collection_myid123456",
}
auth = Gen3Auth(endpoint=COMMONS_URL, refresh_file=CREDENTIALS_FILE_PATH)
sheepdog_client = Gen3Submission(COMMONS_URL, auth)
json_result = sheepdog_client.submit_record(PROGRAM_NAME, PROJECT_NAME, gen3_node_json)
```

### Smoke test

Most of the SDK functionality requires a backend Gen3 environment
Expand Down
29 changes: 24 additions & 5 deletions gen3/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import time
import logging
from urllib.parse import urlparse
import backoff

from gen3.utils import raise_for_status
from gen3.utils import DEFAULT_BACKOFF_SETTINGS, raise_for_status


class Gen3AuthError(Exception):
Expand Down Expand Up @@ -267,18 +268,34 @@ def refresh_access_token(self):
cache_file = token_cache_file(
self._refresh_token and self._refresh_token["api_key"] or self._wts_idp
)

try:
self._write_to_file(cache_file, self._access_token)
except Exception as e:
logging.warning(
f"Exceeded number of retries, unable to write to cache file."
)

return self._access_token

@backoff.on_exception(
wait_gen=backoff.expo, exception=Exception, **DEFAULT_BACKOFF_SETTINGS
)
def _write_to_file(self, cache_file, content):
# write a temp file, then rename - to avoid
# simultaneous writes to same file race condition
temp = cache_file + (
".tmp_eraseme_%d_%d" % (random.randrange(100000), time.time())
)
try:
with open(temp, "w") as f:
f.write(self._access_token)
f.write(content)
os.rename(temp, cache_file)
except:
return True
except Exception as e:
logging.warning("failed to write token cache file: " + cache_file)
return self._access_token
logging.warning(str(e))
raise e

def get_access_token(self):
""" Get the access token - auto refresh if within 5 minutes of expiration """
Expand All @@ -291,10 +308,12 @@ def get_access_token(self):
with open(cache_file) as f:
self._access_token = f.read()
self._access_token_info = decode_token(self._access_token)
except:
except Exception as e:
logging.warning("ignoring invalid token cache: " + cache_file)
self._access_token = None
self._access_token_info = None
logging.warning(str(e))

need_new_token = (
not self._access_token
or not self._access_token_info
Expand Down
5 changes: 4 additions & 1 deletion gen3/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import requests
import pandas as pd
import os
import logging

from gen3.utils import raise_for_status

Expand Down Expand Up @@ -198,8 +199,10 @@ def submit_record(self, program, project, json):
"""
api_url = "{}/api/v0/submission/{}/{}".format(self._endpoint, program, project)
logging.info("\nUsing the Sheepdog API URL {}\n".format(api_url))

output = requests.put(api_url, auth=self._auth_provider, json=json)
raise_for_status(output)
output.raise_for_status()
return output.json()

def delete_record(self, program, project, uuid):
Expand Down
3 changes: 3 additions & 0 deletions gen3/tools/indexing/index_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ def index_object_manifest(
auth(Gen3Auth): Gen3 auth or tuple with basic auth name and password
replace_urls(bool): flag to indicate if replace urls or not
manifest_file_delimiter(str): manifest's delimiter
output_filename(str): output file name for manifest
Returns:
files(list(dict)): list of file info
Expand All @@ -520,6 +521,8 @@ def index_object_manifest(
if not commons_url.endswith(service_location):
commons_url += "/" + service_location

logging.info("\nUsing URL {}\n".format(commons_url))

indexclient = client.IndexClient(commons_url, "v0", auth=auth)

# if delimter not specified, try to get based on file ext
Expand Down
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
from gen3.index import Gen3Index
from gen3.submission import Gen3Submission
from gen3.query import Gen3Query
from gen3.auth import Gen3Auth
import pytest
from drsclient.client import DrsClient
from unittest.mock import call, MagicMock, patch


class MockAuth:
def __init__(self):
self.endpoint = "https://example.commons.com"
self.refresh_token = {"api_key": "123"}


@pytest.fixture
Expand All @@ -26,6 +29,17 @@ def gen3_auth():
return MockAuth()


@pytest.fixture
def mock_gen3_auth():
mock_auth = MockAuth()
# patch as __init__ has method call
with patch("gen3.auth.endpoint_from_token") as mock_endpoint_from_token:
mock_endpoint_from_token().return_value = mock_auth.endpoint
return Gen3Auth(
endpoint=mock_auth.endpoint, refresh_token=mock_auth.refresh_token
)


# for unittest with mock server
@pytest.fixture
def index_client(indexd_server):
Expand Down
97 changes: 97 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,103 @@ def test_token_cache():
assert cache_file == expected


def test_refresh_access_token(mock_gen3_auth):
"""
Make sure that access token ends up in header when refresh is called
"""
with patch("gen3.auth.get_access_token_with_key") as mock_access_token:
mock_access_token.return_value = "new_access_token"
with patch("gen3.auth.decode_token") as mock_decode_token:
mock_decode_token().return_value = {"aud": "123"}
with patch("gen3.auth.Gen3Auth._write_to_file") as mock_write_to_file:
mock_write_to_file().return_value = True
with patch(
"gen3.auth.Gen3Auth.__call__",
return_value=MagicMock(
headers={"Authorization": "Bearer new_access_token"}
),
) as mock_call:
access_token = mock_gen3_auth.refresh_access_token()
assert (
"Bearer " + access_token == mock_call().headers["Authorization"]
)


def test_refresh_access_token_no_cache_file(mock_gen3_auth):
"""
Make sure that access token ends up in header when refresh is called after failing to write to cache file
"""
with patch("gen3.auth.get_access_token_with_key") as mock_access_token:
mock_access_token.return_value = "new_access_token"
with patch("gen3.auth.decode_token") as mock_decode_token:
mock_decode_token().return_value = {"aud": "123"}
with patch("gen3.auth.Gen3Auth._write_to_file") as mock_write_to_file:
mock_write_to_file().return_value = False
with patch(
"gen3.auth.Gen3Auth.__call__",
return_value=MagicMock(
headers={"Authorization": "Bearer new_access_token"}
),
) as mock_call:
access_token = mock_gen3_auth.refresh_access_token()
assert (
"Bearer " + access_token == mock_call().headers["Authorization"]
)


def test_write_to_file_success(mock_gen3_auth):
"""
Make sure that you can write content to a file
"""
with patch("builtins.open", create=True) as mock_open_file:
mock_open_file.return_value = MagicMock()
with patch("builtins.open.write") as mock_file_write:
mock_file_write.return_value = True
with patch("os.rename") as mock_os_rename:
mock_os_rename.return_value = True
result = mock_gen3_auth._write_to_file("some_file", "content")
assert result == True


def test_write_to_file_permission_error(mock_gen3_auth):
"""
Check that the file isn't written when there's a PermissionError
"""
with patch("builtins.open", create=True) as mock_open_file:
mock_open_file.return_value = MagicMock()
with patch(
"builtins.open.write", side_effect=PermissionError
) as mock_file_write:
with pytest.raises(FileNotFoundError):
result = mock_gen3_auth._write_to_file("some_file", "content")


def test_write_to_file_rename_permission_error(mock_gen3_auth):
"""
Check that the file isn't written when there's a PermissionError for renaming
"""
with patch("builtins.open", create=True) as mock_open_file:
mock_open_file.return_value = MagicMock()
with patch("builtins.open.write") as mock_file_write:
mock_file_write.return_value = True
with patch("os.rename", side_effect=PermissionError) as mock_os_rename:
with pytest.raises(PermissionError):
result = mock_gen3_auth._write_to_file("some_file", "content")


def test_write_to_file_rename_file_not_found_error(mock_gen3_auth):
"""
Check that the file isn't renamed when there's a FileNotFoundError
"""
with patch("builtins.open", create=True) as mock_open_file:
mock_open_file.return_value = MagicMock()
with patch("builtins.open.write") as mock_file_write:
mock_file_write.return_value = True
with patch("os.rename", side_effect=FileNotFoundError) as mock_os_rename:
with pytest.raises(FileNotFoundError):
result = mock_gen3_auth._write_to_file("some_file", "content")


def test_auth_init_outside_workspace():
"""
Test that a Gen3Auth instance can be initialized when the
Expand Down
70 changes: 66 additions & 4 deletions tests/test_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,12 @@ def test_open_project(sub):


def test_submit_record(sub):
with patch("gen3.submission.requests") as mock_request:
mock_request.status_code = 200
mock_request.json.return_value = '{ "key": "value" }'
"""
Make sure that you can submit a record
"""
with patch("gen3.submission.requests.put") as mock_request:
mock_request().status_code = 200
mock_request().json.return_value = '{ "key": "value" }'
rec = sub.submit_record(
"prog1",
"proj1",
Expand All @@ -112,7 +115,66 @@ def test_submit_record(sub):
"type": "experiment",
},
)
assert rec
assert rec == mock_request().json.return_value


def test_submit_record_include_refresh_token(sub):
"""
Make sure that you can submit a record and include a refresh token
"""
sub._auth_provider._refresh_token = {"api_key": "123"}

with patch("gen3.submission.requests.put") as mock_request:
mock_request().status_code = 200
mock_request().json.return_value = '{ "key": "value" }'
rec = sub.submit_record(
"prog1",
"proj1",
{
"projects": [{"code": "proj1"}],
"submitter_id": "mjmartinson",
"type": "experiment",
},
)
assert rec == mock_request().json.return_value


def test_submit_record_include_refresh_token_missing_api_key(sub):
"""
Check that there's a KeyError when submitting a record while missing an api key
"""
sub._auth_provider._refresh_token = {"missing_api_key": "123"}
with patch("gen3.submission.requests.put", side_effect=KeyError) as mock_request:
with pytest.raises(KeyError):
rec = sub.submit_record(
"prog1",
"proj1",
{
"projects": [{"code": "proj1"}],
"submitter_id": "mjmartinson",
"type": "experiment",
},
)


def test_submit_record_include_refresh_token_wrong_api_key(sub):
"""
Check that there's an Exception when submitting a record with the wrong api key
"""
sub._auth_provider._refresh_token = {"api_key": "wrong_api_key"}
with patch(
"gen3.submission.requests.put", side_effect=Exception("invalid jwt token")
) as mock_request:
with pytest.raises(Exception):
rec = sub.submit_record(
"prog1",
"proj1",
{
"projects": [{"code": "proj1"}],
"submitter_id": "mjmartinson",
"type": "experiment",
},
)


def test_export_record(sub):
Expand Down

0 comments on commit dd55435

Please sign in to comment.