Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sqs notification plugin #115

Closed
wants to merge 14 commits into from
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ RUN apt-get -q update && apt-get -q install -y \
# upgrade because of this issue https://github.com/chanzuckerberg/miniwdl/issues/607 in miniwdl
RUN pip3 install importlib-metadata==4.13.0
RUN pip3 install miniwdl==${MINIWDL_VERSION}
RUN pip3 install urllib3==1.26.16
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed for tests to pass


RUN curl -Ls https://github.com/chanzuckerberg/s3parcp/releases/download/v1.0.1/s3parcp_1.0.1_linux_amd64.tar.gz | tar -C /usr/bin -xz s3parcp

Expand All @@ -62,6 +63,7 @@ ADD miniwdl-plugins miniwdl-plugins
RUN pip install miniwdl-plugins/s3upload
RUN pip install miniwdl-plugins/sfn_wdl
RUN pip install miniwdl-plugins/s3parcp_download
RUN pip install miniwdl-plugins/sqs_notification

RUN cd /usr/bin; curl -O https://amazon-ecr-credential-helper-releases.s3.amazonaws.com/0.4.0/linux-amd64/docker-credential-ecr-login
RUN chmod +x /usr/bin/docker-credential-ecr-login
Expand Down
1 change: 1 addition & 0 deletions miniwdl-plugins/sqs_notification/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# sqs_notifications
29 changes: 29 additions & 0 deletions miniwdl-plugins/sqs_notification/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/usr/bin/env python3
"""Packaging script for the sqs_notification miniwdl plugin."""
from setuptools import setup
from os import path

this_directory = path.abspath(path.dirname(__file__))
# The plugin README doubles as the PyPI long description.
# Fix: use the already-computed this_directory (it was previously unused)
# and read the README with an explicit encoding.
with open(path.join(this_directory, "README.md"), encoding="utf-8") as f:
    long_description = f.read()

setup(
    name="sqs_notification",
    version="0.0.1",
    description="miniwdl plugin for notification of task completion to Amazon SQS",
    url="https://github.com/chanzuckerberg/miniwdl-s3upload",
    project_urls={},
    long_description=long_description,
    long_description_content_type="text/markdown",
    author="",
    py_modules=["sqs_notification"],
    python_requires=">=3.6",
    setup_requires=["reentry"],
    install_requires=["boto3"],
    reentry_register=True,
    entry_points={
        "miniwdl.plugin.task": ["sqs_notification_task = sqs_notification:task"],
        "miniwdl.plugin.workflow": [
            "sqs_notification_workflow = sqs_notification:workflow"
        ],
    },
)
75 changes: 75 additions & 0 deletions miniwdl-plugins/sqs_notification/sqs_notification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
miniwdl plugin that sends a message to an Amazon SQS queue when each WDL task
completes, carrying the task's outputs as a JSON body.

Configuration is taken from the environment:
  AWS_ENDPOINT_URL             - optional SQS endpoint override (e.g. moto in tests)
  AWS_STEP_NOTIFICATION_PLUGIN - URL of the destination SQS queue
"""

import os
import json
from typing import Dict

from WDL import values_to_json

import boto3

# Module-level client and queue URL, shared by every notification sent from
# this process. Created at import time, so the env vars must be set before
# the plugin module is loaded.
sqs_client = boto3.client("sqs", endpoint_url=os.getenv("AWS_ENDPOINT_URL"))
queue_url = os.getenv("AWS_STEP_NOTIFICATION_PLUGIN")


def process_outputs(outputs: Dict):
    """Serialize a task-outputs dict into the string body sent to SQS."""
    # For now this is a plain JSON dump; richer processing may come later.
    serialized = json.dumps(outputs)
    return serialized


def send_message(attr, body):
    """Publish one message to the configured SQS queue and return the response.

    TODO: wrap in retry/try-except to deal with SQS throttling.
    """
    request = {
        "QueueUrl": queue_url,
        "DelaySeconds": 0,
        "MessageAttributes": attr,
        "MessageBody": body,
    }
    return sqs_client.send_message(**request)


def task(cfg, logger, run_id, run_dir, task, **recv):
    """
    miniwdl task-plugin coroutine: on completion of any task, send an SQS
    message whose attributes identify the workflow and task and whose body is
    the task's outputs serialized as JSON.

    (Fixed: docstring previously copied from the s3upload plugin, and the
    logger child was misnamed "s3_progressive_upload" — per review TODO.)
    """
    logger = logger.getChild("sqs_notification")

    # ignore inputs
    recv = yield recv

    # ignore command/runtime/container
    recv = yield recv

    # run_id is hierarchical: the first element is the workflow run id and
    # the last element is this task's run id.
    message_attributes = {
        "WorkflowName": {"DataType": "String", "StringValue": run_id[0]},
        "TaskName": {"DataType": "String", "StringValue": run_id[-1]},
        "ExecutionId": {
            "DataType": "String",
            # TODO: figure out how to obtain the real execution id
            "StringValue": "execution_id_to_be_passed_in",
        },
    }

    outputs = process_outputs(values_to_json(recv["outputs"]))
    message_body = outputs

    send_message(message_attributes, message_body)

    yield recv


def workflow(cfg, logger, run_id, run_dir, workflow, **recv):
    """
    miniwdl workflow-plugin coroutine. Currently a pass-through: it forwards
    the workflow's inputs and outputs unchanged and sends no notification
    (task-level notifications are sent by the task plugin).

    (Fixed: docstring previously copied from the s3upload plugin and described
    behavior this function does not have; logger child renamed per review TODO.)
    """
    logger = logger.getChild("sqs_notification")

    # ignore inputs
    recv = yield recv

    # TODO: decide whether to send a message when the workflow finishes
    yield recv
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ boto3
flake8
mypy
yq
miniwdl
miniwdl
30 changes: 30 additions & 0 deletions terraform/modules/swipe-sfn-batch-job/step_notifications.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
locals {
  # Notifications are enabled only when at least one SQS queue is configured.
  enable_notifications = length(var.sqs_queues) > 0
}

# Dead-letter queue, created only for queues whose options request one
# (dead_letter defaults to "true").
resource "aws_sqs_queue" "sfn_notifications_queue_dead_letter" {
  for_each = { for name, opts in var.sqs_queues : name => opts if lookup(opts, "dead_letter", "true") == "true" }

  name = "${var.app_name}-${each.key}-sfn-notifications-queue-dead-letter"

  tags = var.tags
}

# Primary notification queue, one per configured entry in var.sqs_queues.
resource "aws_sqs_queue" "step_notifications_queue" {
  for_each = var.sqs_queues

  name = "${var.app_name}-${each.key}-sfn-notifications-queue"

  // Upper-bound for handling any notification
  visibility_timeout_seconds = lookup(each.value, "visibility_timeout_seconds", "120")

  // Sent to dead-letter queue after maxReceiveCount tries
  redrive_policy = lookup(each.value, "dead_letter", "true") == "true" ? jsonencode({
    deadLetterTargetArn = aws_sqs_queue.sfn_notifications_queue_dead_letter[each.key].arn
    // TODO(review): consider raising maxReceiveCount — 3 is aggressive for
    // transient consumer failures.
    maxReceiveCount = 3
  }) : null

  tags = var.tags
}


5 changes: 5 additions & 0 deletions terraform/modules/swipe-sfn-batch-job/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,8 @@ variable "docker_network" {
type = string
default = ""
}

# Queue definitions consumed by step_notifications.tf; an empty map disables
# notifications entirely.
variable "sqs_queues" {
  description = "A dictionary of sqs queue names to a map of options: visibility_timeout_seconds (default: '120'), dead_letter ('true'/'false' default: 'true')"
  type        = map(map(string))
}
11 changes: 10 additions & 1 deletion terraform/modules/swipe-sfn/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,16 @@ module "batch_job" {
docker_network = var.docker_network
call_cache = var.call_cache
output_status_json_files = var.output_status_json_files
tags = var.tags
sqs_queues = {
"step" : {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pass this into SWIPE?

dead_letter : "false",
// We have different settings for dev below b/c multiple dev machines may view
// and ignore the messages, which drives up the receiveCount. Timeout is lower
// so that the intended machine may see it faster:
visibility_timeout_seconds : "10",
},
}
tags = var.tags
}

module "sfn_io_helper" {
Expand Down
2 changes: 2 additions & 0 deletions test/terraform/moto/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ module "swipetest" {
"AWS_ENDPOINT_URL" : "http://awsnet:5000",
"AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" : "container-credentials-relative-uri",
"S3PARCP_S3_URL" : "http://awsnet:5000",
"AWS_STEP_NOTIFICATION_PLUGIN" : "http://localhost:9000/123456789012/swipe-test-step-sfn-notifications-queue"
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should probably use a variable to set this


}

sqs_queues = {
Expand Down
51 changes: 45 additions & 6 deletions test/test_wdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,20 @@ def setUp(self) -> None:
for sfn in state_machines
if "stage-test" in sfn["name"]
][0]
self.state_change_queue_url = self.sqs.list_queues()["QueueUrls"][0]
self.state_change_queue_url = [
url
for url in self.sqs.list_queues()["QueueUrls"]
if "swipe-test-notifications" in url
][0]
self.step_change_queue_url = [
url
for url in self.sqs.list_queues()["QueueUrls"]
if "swipe-test-step" in url
][0]

# Empty the SQS queue before running tests.
_ = self.sqs.purge_queue(QueueUrl=self.state_change_queue_url)
_ = self.sqs.purge_queue(QueueUrl=self.step_change_queue_url)

def tearDown(self) -> None:
self.test_bucket.delete_objects(
Expand All @@ -291,23 +301,48 @@ def tearDown(self) -> None:
)
self.test_bucket.delete()

def retrieve_message(self, url) -> str:
    """Fetch a single SQS message from `url`, delete it from the queue, and
    return its body; return "" when the queue is empty."""
    # TODO(review): also request message attributes.
    resp = self.sqs.receive_message(
        QueueUrl=url,
        MaxNumberOfMessages=1,
    )
    messages = resp.get("Messages")
    if not messages:
        # Nothing available right now.
        return ""

    msg = messages[0]
    self.sqs.delete_message(
        QueueUrl=url,
        ReceiptHandle=msg["ReceiptHandle"],
    )
    return msg["Body"]

def _wait_sfn(
self,
sfn_input: Dict,
sfn_arn: str,
n_stages: int = 1,
expect_success: bool = True
) -> Tuple[str, Dict, List[Dict]]:
) -> Tuple[str, Dict, List[Dict], List[str]]:
execution_name = "swipe-test-{}".format(int(time.time()))
res = self.sfn.start_execution(
stateMachineArn=sfn_arn, name=execution_name, input=json.dumps(sfn_input)
)
arn = res["executionArn"]
start = time.time()
description = self.sfn.describe_execution(executionArn=arn)
step_notifications = []
while description["status"] == "RUNNING" and time.time() < start + 2 * 60:
time.sleep(10)
description = self.sfn.describe_execution(executionArn=arn)

while step_messages := self.retrieve_message(self.step_change_queue_url):
step_notifications.append(
step_messages
)

print("printing execution history", file=sys.stderr)

seen_events = set()
Expand Down Expand Up @@ -354,7 +389,7 @@ def _wait_sfn(
self.assertEqual(description["status"], "SUCCEEDED", description)
else:
self.assertEqual(description["status"], "FAILED", description)
return arn, description, messages
return arn, description, messages, step_notifications

def test_simple_sfn_wdl_workflow(self):
output_prefix = "out-1"
Expand All @@ -369,7 +404,7 @@ def test_simple_sfn_wdl_workflow(self):
},
}

arn, description, messages = self._wait_sfn(sfn_input, self.single_sfn_arn)
arn, description, messages, step_notifications = self._wait_sfn(sfn_input, self.single_sfn_arn)

output = json.loads(description["output"])
self.assertEqual(output["Result"], {
Expand All @@ -386,6 +421,10 @@ def test_simple_sfn_wdl_workflow(self):
self.assertEqual(
json.loads(messages[0]["Body"])["detail"]["lastCompletedStage"], "run"
)
self.assertEqual(
# TODO: bc of download caching this value can change, figure out if you want it to change or not
len(step_notifications), 3
)

def test_https_inputs(self):
output_prefix = "out-https-1"
Expand Down Expand Up @@ -415,7 +454,7 @@ def test_failing_wdl_workflow(self):
},
}

arn, description, messages = self._wait_sfn(sfn_input, self.single_sfn_arn, expect_success=False)
arn, description, messages, _ = self._wait_sfn(sfn_input, self.single_sfn_arn, expect_success=False)
errorType = (self.sfn.get_execution_history(executionArn=arn)["events"]
[-1]["executionFailedEventDetails"]["error"])
self.assertTrue(errorType in ["UncaughtError", "RunFailed"])
Expand Down Expand Up @@ -469,7 +508,7 @@ def test_staged_sfn_wdl_workflow(self):
},
}

_, _, messages = self._wait_sfn(sfn_input, self.stage_sfn_arn, 2)
_, _, messages, _ = self._wait_sfn(sfn_input, self.stage_sfn_arn, 2)

outputs_obj = self.test_bucket.Object(
f"{output_prefix}/test-1/happy_message.txt"
Expand Down
2 changes: 1 addition & 1 deletion version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
v1.4.6
v1.4.7
Loading