Skip to content

Commit

Permalink
Merge pull request #8 from broadinstitute/se-add-caas-option
Browse files Browse the repository at this point in the history
Add option to use Cromwell-as-a-service
  • Loading branch information
samanehsan authored Apr 19, 2018
2 parents f1682bf + 35b3a4e commit f614e34
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 20 deletions.
47 changes: 32 additions & 15 deletions cromwell_tools/cromwell_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import six
import re
from tenacity import retry, wait_exponential, stop_after_delay
from oauth2client.service_account import ServiceAccountCredentials


_failed_statuses = ['Failed', 'Aborted', 'Aborting']
Expand Down Expand Up @@ -45,24 +46,34 @@ def harmonize_credentials(secrets_file=None, cromwell_username=None, cromwell_pa
return cromwell_username, cromwell_password


def _get_auth_credentials(secrets_file=None, cromwell_user=None, cromwell_password=None, caas_key=None):
if caas_key:
headers = generate_auth_header_from_key_file(caas_key)
auth = None
else:
headers = None
cromwell_user, cromwell_password = harmonize_credentials(secrets_file, cromwell_user, cromwell_password)
auth = requests.auth.HTTPBasicAuth(cromwell_user, cromwell_password)
return auth, headers


def get_workflow_statuses(
ids, cromwell_url, cromwell_user=None, cromwell_password=None, secrets_file=None):
ids, cromwell_url, cromwell_user=None, cromwell_password=None, secrets_file=None, caas_key=None):
""" Given a list of workflow ids, query cromwell url for their statuses
:param list ids:
:param str cromwell_url:
:param str cromwell_user:
:param str cromwell_password:
:param str secrets_file:
:param str caas_key: service account JSON key for cromwell-as-a-service
:return list: list of workflow statuses
"""
cromwell_user, cromwell_password = harmonize_credentials(
secrets_file, cromwell_user, cromwell_password)
statuses = []
auth, headers = _get_auth_credentials(secrets_file, cromwell_user, cromwell_password, caas_key)
for id_ in ids:
full_url = cromwell_url + '/api/workflows/v1/{0}/status'.format(id_)
auth = requests.auth.HTTPBasicAuth(cromwell_user, cromwell_password)
response = requests.get(full_url, auth=auth)
response = requests.get(full_url, auth=auth, headers=headers)
if response.status_code != 200:
print('Could not get status for {0}. Cromwell at {1} returned status {2}'.format(
id_, cromwell_url, response.status_code))
Expand All @@ -78,7 +89,7 @@ def get_workflow_statuses(

def wait_until_workflow_completes(
cromwell_url, workflow_ids, timeout_minutes, poll_interval_seconds=30, cromwell_user=None,
cromwell_password=None, secrets_file=None):
cromwell_password=None, secrets_file=None, caas_key=None):
"""
Given a list of workflow ids, wait until cromwell returns successfully for each status, or
one of the workflows fails or is aborted.
Expand All @@ -91,6 +102,7 @@ def wait_until_workflow_completes(
:param str cromwell_user:
:param str cromwell_password:
:param str secrets_file:
:param str caas_key: service account JSON key for cromwell-as-a-service
:return:
"""
cromwell_user, cromwell_password = harmonize_credentials(
Expand All @@ -101,7 +113,7 @@ def wait_until_workflow_completes(
if datetime.now() - start > timeout:
msg = 'Unfinished workflows after {0} minutes.'
raise Exception(msg.format(timeout))
statuses = get_workflow_statuses(workflow_ids, cromwell_url, cromwell_user, cromwell_password)
statuses = get_workflow_statuses(workflow_ids, cromwell_url, cromwell_user, cromwell_password, caas_key)
all_succeeded = True
for i, status in enumerate(statuses):
if status in _failed_statuses:
Expand All @@ -118,7 +130,7 @@ def wait_until_workflow_completes(
@retry(reraise=True, wait=wait_exponential(multiplier=1, max=10), stop=stop_after_delay(20))
def start_workflow(
wdl_file, inputs_file, url, options_file=None, inputs_file2=None, zip_file=None, user=None,
password=None, label=None, validate_labels=True):
password=None, caas_key=None, collection_name=None, label=None, validate_labels=True):
"""Use HTTP POST to start workflow in Cromwell and retry with exponentially increasing wait times between requests
if there are any failures. View statistics about the retries with `start_workflow.retry.statistics`.
Expand All @@ -133,6 +145,8 @@ def start_workflow(
:param _io.BytesIO zip_file: (optional) zip file containing dependencies.
:param str user: (optional) cromwell username.
:param str password: (optional) cromwell password.
:param str caas_key: (optional) service account JSON key for cromwell-as-a-service.
:param str collection_name: (optional) collection in SAM that the workflow should belong to.
:param str|_io.BytesIO label: (optional) JSON file containing a collection of key/value pairs for workflow labels.
:param bool validate_labels: (optional) Whether to validate labels or not, using cromwell-tools' built-in
validators. It is set to True by default.
Expand All @@ -153,16 +167,13 @@ def start_workflow(
files['workflowDependencies'] = zip_file
if options_file is not None:
files['workflowOptions'] = options_file

if user and password:
auth = HTTPBasicAuth(user, password)
else:
auth = None

if label:
files['labels'] = label
if caas_key and collection_name:
files['collectionName'] = collection_name

response = requests.post(url, files=files, auth=auth)
auth, headers = _get_auth_credentials(cromwell_user=user, cromwell_password=password, caas_key=caas_key)
response = requests.post(url, files=files, auth=auth, headers=headers)
response.raise_for_status()
return response

Expand Down Expand Up @@ -310,3 +321,9 @@ def validate_cromwell_label(label_object):

if err_msg != '':
raise ValueError(err_msg)


def generate_auth_header_from_key_file(json_credentials):
scopes = ['https://www.googleapis.com/auth/userinfo.profile', 'https://www.googleapis.com/auth/userinfo.email']
credentials = ServiceAccountCredentials.from_json_keyfile_name(json_credentials, scopes=scopes)
return {"Authorization": "bearer " + credentials.get_access_token().access_token}
59 changes: 57 additions & 2 deletions cromwell_tools/tests/test_cromwell_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def setUp(self):
self.url = "https://fake_url"
self.user = "fake_user"
self.password = "fake_password"
self.caas_key = "path/fake_key.json"

@requests_mock.mock()
def test_start_workflow(self, mock_request):
Expand All @@ -62,7 +63,7 @@ def _request_callback(request, context):
mock_request.post(self.url, json=_request_callback)
result = cromwell_tools.start_workflow(
self.wdl_file, self.inputs_file, self.url, self.options_file, self.inputs_file2, self.zip_file, self.user,
self.password, self.label)
self.password, label=self.label)
self.assertEqual(result.status_code, 200)
self.assertEqual(result.headers.get('test'), 'header')

Expand All @@ -81,12 +82,66 @@ def _request_callback(request, context):
with self.assertRaises(requests.HTTPError):
result = cromwell_tools.start_workflow(
self.wdl_file, self.inputs_file, self.url, self.options_file, self.inputs_file2, self.zip_file,
self.user, self.password, self.label)
self.user, self.password, label=self.label)
self.assertNotEqual(mock_request.call_count, 1)

# Reset default retry value
cromwell_tools.start_workflow.retry.stop = stop_after_delay(20)

@requests_mock.mock()
@mock.patch('cromwell_tools.cromwell_tools.generate_auth_header_from_key_file')
def test_start_workflow_in_cromwell_as_a_service(self, mock_request, mock_header):
mock_header.return_value = {"Authorization": "bearer fake_token"}
def _request_callback(request, context):
context.status_code = 200
context.headers['test'] = 'header'
return {'request': {'body': "content"}}

# Check request actions
mock_request.post(self.url, json=_request_callback)
result = cromwell_tools.start_workflow(
self.wdl_file, self.inputs_file, self.url, self.options_file, self.inputs_file2, self.zip_file,
caas_key=self.caas_key, label=self.label)
self.assertEqual(result.status_code, 200)
self.assertEqual(result.headers.get('test'), 'header')

@requests_mock.mock()
def test_get_workflow_statuses(self, mock_request):
def _request_callback(request, context):
context.status_code = 200
context.headers['test'] = 'header'
return {'request': {'body': "content"}}

def _request_callback_status(request, context):
context.status_code = 200
context.headers['test'] = 'header'
return {'status': 'Succeeded'}

ids = ["01234"]
mock_request.post(self.url, json=_request_callback)
mock_request.get(self.url + '/api/workflows/v1/{}/status'.format(ids[0]), json=_request_callback_status)
result = cromwell_tools.get_workflow_statuses(ids, self.url, self.user, self.password)
self.assertIn('Succeeded', result)

@requests_mock.mock()
@mock.patch('cromwell_tools.cromwell_tools.generate_auth_header_from_key_file')
def test_get_workflow_statuses_in_cromwell_as_a_service(self, mock_request, mock_header):
def _request_callback(request, context):
context.status_code = 200
context.headers['test'] = 'header'
return {'request': {'body': "content"}}

def _request_callback_status(request, context):
context.status_code = 200
context.headers['test'] = 'header'
return {'status': 'Succeeded'}

ids = ["01234"]
mock_request.post(self.url, json=_request_callback)
mock_request.get(self.url + '/api/workflows/v1/{}/status'.format(ids[0]), json=_request_callback_status)
result = cromwell_tools.get_workflow_statuses(ids, self.url, caas_key=self.caas_key)
self.assertIn('Succeeded', result)

@requests_mock.mock()
def test_download_http_raises_error_on_bad_status_code(self, mock_request):

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
requests==2.18.4
six==1.11.0
tenacity==4.10.0
oauth2client==4.1.2
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
packages=['cromwell_tools'],
install_requires=[
'requests==2.18.4',
'six==1.11.0'
'six==1.11.0',
'oauth2client==4.1.2',
'tenacity==4.10.0'
],
scripts=['cromwell_tools/scripts/cromwell-tools'],
include_package_data=True
Expand Down
4 changes: 2 additions & 2 deletions test-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
mock
requests_mock
mock==2.0.0
requests_mock==1.4.0

0 comments on commit f614e34

Please sign in to comment.