Commit

Merge pull request getmoto#2376 from rwestergren/event_source_mappings
Add event source mapping endpoints and SQS trigger support
mikegrima authored Aug 22, 2019
2 parents 7fa46e9 + 1efd9ee commit 3a5d857
Showing 7 changed files with 565 additions and 14 deletions.
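A minimal sketch, not part of the commit, of how the new mappings could be exercised end to end with boto3 under moto's decorators. The queue name, function name, role ARN, and handler source are illustrative placeholders, and the BatchSize assertion assumes the SQS service default of 10 introduced in `batch_size_map` below.

```python
import io
import zipfile

import boto3
from moto import mock_lambda, mock_sqs


def make_zip():
    # moto only needs a syntactically valid deployment package.
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, 'w') as zf:
        zf.writestr('handler.py', 'def handler(event, context):\n    return event\n')
    return buf.getvalue()


@mock_sqs
@mock_lambda
def test_create_sqs_event_source_mapping():
    queue = boto3.resource('sqs', region_name='us-east-1').create_queue(QueueName='trigger-queue')

    lambda_client = boto3.client('lambda', region_name='us-east-1')
    lambda_client.create_function(
        FunctionName='sqs-consumer',
        Runtime='python3.7',
        Role='arn:aws:iam::123456789012:role/placeholder',  # illustrative role ARN
        Handler='handler.handler',
        Code={'ZipFile': make_zip()},
    )

    mapping = lambda_client.create_event_source_mapping(
        EventSourceArn=queue.attributes['QueueArn'],
        FunctionName='sqs-consumer',
    )

    # SQS mappings should fall back to the service default batch size of 10
    # defined in batch_size_map in moto/awslambda/models.py.
    assert mapping['BatchSize'] == 10
    assert mapping['State'] == 'Enabled'
```

The queue has to exist in the mocked SQS backend before the mapping is created; otherwise `create_event_source_mapping` rejects the EventSourceArn with a ResourceNotFoundException, as the backend code below shows.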
154 changes: 147 additions & 7 deletions moto/awslambda/models.py
@@ -1,6 +1,7 @@
from __future__ import unicode_literals

import base64
import time
from collections import defaultdict
import copy
import datetime
@@ -31,6 +32,7 @@
from moto.s3.exceptions import MissingBucket, MissingKey
from moto import settings
from .utils import make_function_arn, make_function_ver_arn
from moto.sqs import sqs_backends

logger = logging.getLogger(__name__)

@@ -429,24 +431,59 @@ def _create_zipfile_from_plaintext_code(code):
class EventSourceMapping(BaseModel):
def __init__(self, spec):
# required
self.function_name = spec['FunctionName']
self.function_arn = spec['FunctionArn']
self.event_source_arn = spec['EventSourceArn']
self.starting_position = spec['StartingPosition']
self.uuid = str(uuid.uuid4())
self.last_modified = time.mktime(datetime.datetime.utcnow().timetuple())

# BatchSize service default/max mapping
batch_size_map = {
'kinesis': (100, 10000),
'dynamodb': (100, 1000),
'sqs': (10, 10),
}
source_type = self.event_source_arn.split(":")[2].lower()
batch_size_entry = batch_size_map.get(source_type)
if batch_size_entry:
# Use service default if not provided
batch_size = int(spec.get('BatchSize', batch_size_entry[0]))
if batch_size > batch_size_entry[1]:
raise ValueError("InvalidParameterValueException",
"BatchSize {} exceeds the max of {}".format(batch_size, batch_size_entry[1]))
else:
self.batch_size = batch_size
else:
raise ValueError("InvalidParameterValueException",
"Unsupported event source type")

# optional
self.batch_size = spec.get('BatchSize', 100)
self.starting_position = spec.get('StartingPosition', 'TRIM_HORIZON')
self.enabled = spec.get('Enabled', True)
self.starting_position_timestamp = spec.get('StartingPositionTimestamp',
None)

def get_configuration(self):
return {
'UUID': self.uuid,
'BatchSize': self.batch_size,
'EventSourceArn': self.event_source_arn,
'FunctionArn': self.function_arn,
'LastModified': self.last_modified,
'LastProcessingResult': '',
'State': 'Enabled' if self.enabled else 'Disabled',
'StateTransitionReason': 'User initiated'
}

@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json,
region_name):
properties = cloudformation_json['Properties']
func = lambda_backends[region_name].get_function(properties['FunctionName'])
spec = {
'FunctionName': properties['FunctionName'],
'FunctionArn': func.function_arn,
'EventSourceArn': properties['EventSourceArn'],
'StartingPosition': properties['StartingPosition']
'StartingPosition': properties['StartingPosition'],
'BatchSize': properties.get('BatchSize', 100)
}
optional_properties = 'BatchSize Enabled StartingPositionTimestamp'.split()
for prop in optional_properties:
Expand All @@ -466,8 +503,10 @@ def __repr__(self):
def create_from_cloudformation_json(cls, resource_name, cloudformation_json,
region_name):
properties = cloudformation_json['Properties']
function_name = properties['FunctionName']
func = lambda_backends[region_name].publish_function(function_name)
spec = {
'Version': properties.get('Version')
'Version': func.version
}
return LambdaVersion(spec)

@@ -515,6 +554,9 @@ def list_versions_by_function(self, name):
def get_arn(self, arn):
return self._arns.get(arn, None)

def get_function_by_name_or_arn(self, input):
return self.get_function(input) or self.get_arn(input)

def put_function(self, fn):
"""
:param fn: Function
@@ -596,6 +638,7 @@ def all(self):
class LambdaBackend(BaseBackend):
def __init__(self, region_name):
self._lambdas = LambdaStorage()
self._event_source_mappings = {}
self.region_name = region_name

def reset(self):
@@ -617,6 +660,40 @@ def create_function(self, spec):
fn.version = ver.version
return fn

def create_event_source_mapping(self, spec):
required = [
'EventSourceArn',
'FunctionName',
]
for param in required:
if not spec.get(param):
raise RESTError('InvalidParameterValueException', 'Missing {}'.format(param))

# Validate function name
func = self._lambdas.get_function_by_name_or_arn(spec.pop('FunctionName', ''))
if not func:
raise RESTError('ResourceNotFoundException', 'Invalid FunctionName')

# Validate queue
for queue in sqs_backends[self.region_name].queues.values():
if queue.queue_arn == spec['EventSourceArn']:
if queue.lambda_event_source_mappings.get(func.function_arn):
# TODO: Correct exception?
raise RESTError('ResourceConflictException', 'The resource already exists.')
if queue.fifo_queue:
raise RESTError('InvalidParameterValueException',
'{} is FIFO'.format(queue.queue_arn))
else:
spec.update({'FunctionArn': func.function_arn})
esm = EventSourceMapping(spec)
self._event_source_mappings[esm.uuid] = esm

# Set backend function on queue
queue.lambda_event_source_mappings[esm.function_arn] = esm

return esm
raise RESTError('ResourceNotFoundException', 'Invalid EventSourceArn')

def publish_function(self, function_name):
return self._lambdas.publish_function(function_name)

@@ -626,6 +703,33 @@ def get_function(self, function_name, qualifier=None):
def list_versions_by_function(self, function_name):
return self._lambdas.list_versions_by_function(function_name)

def get_event_source_mapping(self, uuid):
return self._event_source_mappings.get(uuid)

def delete_event_source_mapping(self, uuid):
return self._event_source_mappings.pop(uuid, None)

def update_event_source_mapping(self, uuid, spec):
esm = self.get_event_source_mapping(uuid)
if esm:
if spec.get('FunctionName'):
func = self._lambdas.get_function_by_name_or_arn(spec.get('FunctionName'))
esm.function_arn = func.function_arn
if 'BatchSize' in spec:
esm.batch_size = spec['BatchSize']
if 'Enabled' in spec:
esm.enabled = spec['Enabled']
return esm
return False

def list_event_source_mappings(self, event_source_arn, function_name):
esms = list(self._event_source_mappings.values())
if event_source_arn:
esms = list(filter(lambda x: x.event_source_arn == event_source_arn, esms))
if function_name:
esms = list(filter(lambda x: x.function_name == function_name, esms))
return esms

def get_function_by_arn(self, function_arn):
return self._lambdas.get_arn(function_arn)

@@ -635,7 +739,43 @@ def delete_function(self, function_name, qualifier=None):
def list_functions(self):
return self._lambdas.all()

def send_message(self, function_name, message, subject=None, qualifier=None):
def send_sqs_batch(self, function_arn, messages, queue_arn):
success = True
for message in messages:
func = self.get_function_by_arn(function_arn)
result = self._send_sqs_message(func, message, queue_arn)
if not result:
success = False
return success

def _send_sqs_message(self, func, message, queue_arn):
event = {
"Records": [
{
"messageId": message.id,
"receiptHandle": message.receipt_handle,
"body": message.body,
"attributes": {
"ApproximateReceiveCount": "1",
"SentTimestamp": "1545082649183",
"SenderId": "AIDAIENQZJOLO23YVJ4VO",
"ApproximateFirstReceiveTimestamp": "1545082649185"
},
"messageAttributes": {},
"md5OfBody": "098f6bcd4621d373cade4e832627b4f6",
"eventSource": "aws:sqs",
"eventSourceARN": queue_arn,
"awsRegion": self.region_name
}
]
}

request_headers = {}
response_headers = {}
func.invoke(json.dumps(event), request_headers, response_headers)
return 'x-amz-function-error' not in response_headers

def send_sns_message(self, function_name, message, subject=None, qualifier=None):
event = {
"Records": [
{
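The remaining backend methods round-trip through the matching boto3 calls. A hedged sketch of that lifecycle follows; `create_queue_and_function` is a hypothetical helper that repeats the setup from the first sketch and returns the Lambda client plus the queue's ARN.

```python
from moto import mock_lambda, mock_sqs

# `create_queue_and_function` is a hypothetical helper mirroring the setup in
# the first sketch; it returns (lambda_client, queue_arn).


@mock_sqs
@mock_lambda
def test_event_source_mapping_lifecycle():
    lambda_client, queue_arn = create_queue_and_function()

    uuid = lambda_client.create_event_source_mapping(
        EventSourceArn=queue_arn, FunctionName='sqs-consumer')['UUID']

    # GetEventSourceMapping / ListEventSourceMappings
    assert lambda_client.get_event_source_mapping(UUID=uuid)['EventSourceArn'] == queue_arn
    listed = lambda_client.list_event_source_mappings(EventSourceArn=queue_arn)
    assert len(listed['EventSourceMappings']) == 1

    # UpdateEventSourceMapping: disable the trigger and shrink the batch size
    updated = lambda_client.update_event_source_mapping(
        UUID=uuid, Enabled=False, BatchSize=5)
    assert updated['State'] == 'Disabled'
    assert updated['BatchSize'] == 5

    # DeleteEventSourceMapping returns the mapping in a 'Deleting' state
    assert lambda_client.delete_event_source_mapping(UUID=uuid)['State'] == 'Deleting'
```

If the UUID is unknown, the corresponding handlers in responses.py (next file) return a bare 404.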
64 changes: 64 additions & 0 deletions moto/awslambda/responses.py
@@ -39,6 +39,31 @@ def root(self, request, full_url, headers):
else:
raise ValueError("Cannot handle request")

def event_source_mappings(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'GET':
querystring = self.querystring
event_source_arn = querystring.get('EventSourceArn', [None])[0]
function_name = querystring.get('FunctionName', [None])[0]
return self._list_event_source_mappings(event_source_arn, function_name)
elif request.method == 'POST':
return self._create_event_source_mapping(request, full_url, headers)
else:
raise ValueError("Cannot handle request")

def event_source_mapping(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
path = request.path if hasattr(request, 'path') else path_url(request.url)
uuid = path.split('/')[-1]
if request.method == 'GET':
return self._get_event_source_mapping(uuid)
elif request.method == 'PUT':
return self._update_event_source_mapping(uuid)
elif request.method == 'DELETE':
return self._delete_event_source_mapping(uuid)
else:
raise ValueError("Cannot handle request")

def function(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'GET':
@@ -177,6 +202,45 @@ def _create_function(self, request, full_url, headers):
config = fn.get_configuration()
return 201, {}, json.dumps(config)

def _create_event_source_mapping(self, request, full_url, headers):
try:
fn = self.lambda_backend.create_event_source_mapping(self.json_body)
except ValueError as e:
return 400, {}, json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}})
else:
config = fn.get_configuration()
return 201, {}, json.dumps(config)

def _list_event_source_mappings(self, event_source_arn, function_name):
esms = self.lambda_backend.list_event_source_mappings(event_source_arn, function_name)
result = {
'EventSourceMappings': [esm.get_configuration() for esm in esms]
}
return 200, {}, json.dumps(result)

def _get_event_source_mapping(self, uuid):
result = self.lambda_backend.get_event_source_mapping(uuid)
if result:
return 200, {}, json.dumps(result.get_configuration())
else:
return 404, {}, "{}"

def _update_event_source_mapping(self, uuid):
result = self.lambda_backend.update_event_source_mapping(uuid, self.json_body)
if result:
return 202, {}, json.dumps(result.get_configuration())
else:
return 404, {}, "{}"

def _delete_event_source_mapping(self, uuid):
esm = self.lambda_backend.delete_event_source_mapping(uuid)
if esm:
json_result = esm.get_configuration()
json_result.update({'State': 'Deleting'})
return 202, {}, json.dumps(json_result)
else:
return 404, {}, "{}"

def _publish_function(self, request, full_url, headers):
function_name = self.path.rsplit('/', 2)[-2]

2 changes: 2 additions & 0 deletions moto/awslambda/urls.py
@@ -11,6 +11,8 @@
'{0}/(?P<api_version>[^/]+)/functions/?$': response.root,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/?$': response.function,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$': response.versions,
r'{0}/(?P<api_version>[^/]+)/event-source-mappings/?$': response.event_source_mappings,
r'{0}/(?P<api_version>[^/]+)/event-source-mappings/(?P<UUID>[\w_-]+)/?$': response.event_source_mapping,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$': response.invoke,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$': response.invoke_async,
r'{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)': response.tag,
2 changes: 1 addition & 1 deletion moto/sns/models.py
@@ -119,7 +119,7 @@ def publish(self, message, message_id, subject=None,
else:
assert False

lambda_backends[region].send_message(function_name, message, subject=subject, qualifier=qualifier)
lambda_backends[region].send_sns_message(function_name, message, subject=subject, qualifier=qualifier)

def _matches_filter_policy(self, message_attributes):
# TODO: support Anything-but matching, prefix matching and
29 changes: 29 additions & 0 deletions moto/sqs/models.py
@@ -189,6 +189,8 @@ def __init__(self, name, region, **kwargs):
self.name)
self.dead_letter_queue = None

self.lambda_event_source_mappings = {}

# default settings for a non fifo queue
defaults = {
'ContentBasedDeduplication': 'false',
@@ -360,6 +362,33 @@ def messages(self):

def add_message(self, message):
self._messages.append(message)
from moto.awslambda import lambda_backends
for arn, esm in self.lambda_event_source_mappings.items():
backend = sqs_backends[self.region]

"""
Lambda polls the queue and invokes your function synchronously with an event
that contains queue messages. Lambda reads messages in batches and invokes
your function once for each batch. When your function successfully processes
a batch, Lambda deletes its messages from the queue.
"""
messages = backend.receive_messages(
self.name,
esm.batch_size,
self.receive_message_wait_time_seconds,
self.visibility_timeout,
)

result = lambda_backends[self.region].send_sqs_batch(
arn,
messages,
self.queue_arn,
)

if result:
[backend.delete_message(self.name, m.receipt_handle) for m in messages]
else:
[backend.change_message_visibility(self.name, m.receipt_handle, 0) for m in messages]

def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
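For completeness, a sketch (not part of the commit) of what a consumer handler for this trigger might look like, given the aws:sqs Records payload assembled by `_send_sqs_message`; the JSON-encoded message body is an assumption of the example.

```python
import json


def handler(event, context):
    """Consume the batch that moto's _send_sqs_message delivers."""
    for record in event['Records']:
        assert record['eventSource'] == 'aws:sqs'
        payload = json.loads(record['body'])  # assumes JSON-encoded bodies
        print('processing message', record['messageId'], payload)

    # Returning normally marks the batch as processed, so add_message deletes
    # the messages from the queue; raising (a function error) makes it reset
    # their visibility timeout to 0 so they can be received again.
    return {'processed': len(event['Records'])}
```

Note that exercising this path end to end (send_message through to the function) also needs Docker available, since moto executes Lambda code in a container.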