@@ -18,11 +18,11 @@ import (
1818)
1919
2020type snsSqs struct {
21- // key is the topic name, value is the ARN of the topic
21+ // key is the topic name, value is the ARN of the topic.
2222 topics map [string ]string
23- // key is the sanitized topic name, value is the actual topic name
23+ // key is the sanitized topic name, value is the actual topic name.
2424 topicSanitized map [string ]string
25- // key is the topic name, value holds the ARN of the queue and its url
25+ // key is the topic name, value holds the ARN of the queue and its url.
2626 queues map [string ]* sqsQueueInfo
2727 snsClient * sns.SNS
2828 sqsClient * sqs.SQS
@@ -37,31 +37,31 @@ type sqsQueueInfo struct {
3737}
3838
3939type snsSqsMetadata struct {
40- // name of the queue for this application. The is provided by the runtime as "consumerID"
40+ // name of the queue for this application. The is provided by the runtime as "consumerID".
4141 sqsQueueName string
42- // name of the dead letter queue for this application
42+ // name of the dead letter queue for this application.
4343 sqsDeadLettersQueueName string
4444 // aws endpoint for the component to use.
4545 Endpoint string
46- // access key to use for accessing sqs/sns
46+ // access key to use for accessing sqs/sns.
4747 AccessKey string
48- // secret key to use for accessing sqs/sns
48+ // secret key to use for accessing sqs/sns.
4949 SecretKey string
5050 // aws session token to use.
5151 SessionToken string
52- // aws region in which SNS/SQS should create resources
52+ // aws region in which SNS/SQS should create resources.
5353 Region string
5454
55- // amount of time in seconds that a message is hidden from receive requests after it is sent to a subscriber. Default: 10
55+ // amount of time in seconds that a message is hidden from receive requests after it is sent to a subscriber. Default: 10.
5656 messageVisibilityTimeout int64
57- // number of times to resend a message after processing of that message fails before removing that message from the queue. Default: 10
57+ // number of times to resend a message after processing of that message fails before removing that message from the queue. Default: 10.
5858 messageRetryLimit int64
5959 // if sqsDeadLettersQueueName is set to a value, then the messageReceiveLimit defines the number of times a message is received
60- // before it is moved to the dead-letters queue. This value must be smaller than messageRetryLimit
60+ // before it is moved to the dead-letters queue. This value must be smaller than messageRetryLimit.
6161 messageReceiveLimit int64
62- // amount of time to await receipt of a message before making another request. Default: 1
62+ // amount of time to await receipt of a message before making another request. Default: 1.
6363 messageWaitTimeSeconds int64
64- // maximum number of messages to receive from the queue at a time. Default: 10, Maximum: 10
64+ // maximum number of messages to receive from the queue at a time. Default: 10, Maximum: 10.
6565 messageMaxNumber int64
6666}
6767
@@ -194,7 +194,7 @@ func (s *snsSqs) getSnsSqsMetatdata(metadata pubsub.Metadata) (*snsSqsMetadata,
194194 md .messageReceiveLimit = messageReceiveLimit
195195 }
196196
197- // XOR on having either a valid messageReceiveLimit and invalid sqsDeadLettersQueueName, and vice versa
197+ // XOR on having either a valid messageReceiveLimit and invalid sqsDeadLettersQueueName, and vice versa.
198198 if (md .messageReceiveLimit > 0 || len (md .sqsDeadLettersQueueName ) > 0 ) && ! (md .messageReceiveLimit > 0 && len (md .sqsDeadLettersQueueName ) > 0 ) {
199199 return nil , errors .New ("to use SQS dead letters queue, messageReceiveLimit and sqsDeadLettersQueueName must both be set to a value" )
200200 }
@@ -243,13 +243,13 @@ func (s *snsSqs) Init(metadata pubsub.Metadata) error {
243243 s .metadata = md
244244
245245 // both Publish and Subscribe need reference the topic ARN
246- // track these ARNs in this map
246+ // track these ARNs in this map.
247247 s .topics = make (map [string ]string )
248248 s .topicSanitized = make (map [string ]string )
249249 s .queues = make (map [string ]* sqsQueueInfo )
250250 sess , err := aws_auth .GetClient (md .AccessKey , md .SecretKey , md .SessionToken , md .Region , md .Endpoint )
251251 if err != nil {
252- return err
252+ return fmt . Errorf ( "error creating an AWS client: %w" , err )
253253 }
254254 s .snsClient = sns .New (sess )
255255 s .sqsClient = sqs .New (sess )
@@ -264,7 +264,7 @@ func (s *snsSqs) createTopic(topic string) (string, string, error) {
264264 Tags : []* sns.Tag {{Key : aws .String (awsSnsTopicNameKey ), Value : aws .String (topic )}},
265265 })
266266 if err != nil {
267- return "" , "" , err
267+ return "" , "" , fmt . Errorf ( "error while creating an SNS topic: %w" , err )
268268 }
269269
270270 return * (createTopicResponse .TopicArn ), sanitizedName , nil
@@ -276,12 +276,12 @@ func (s *snsSqs) getOrCreateTopic(topic string) (string, error) {
276276 topicArn , ok := s .topics [topic ]
277277
278278 if ok {
279- s .logger .Debugf ("Found existing topic ARN for topic %s: %s" , topic , topicArn )
279+ s .logger .Debugf ("found existing topic ARN for topic %s: %s" , topic , topicArn )
280280
281281 return topicArn , nil
282282 }
283283
284- s .logger .Debugf ("No topic ARN found for %s\n Creating topic instead." , topic )
284+ s .logger .Debugf ("no topic ARN found for %s\n Creating topic instead." , topic )
285285
286286 topicArn , sanitizedName , err := s .createTopic (topic )
287287 if err != nil {
@@ -290,7 +290,7 @@ func (s *snsSqs) getOrCreateTopic(topic string) (string, error) {
290290 return "" , err
291291 }
292292
293- // record topic ARN
293+ // record topic ARN.
294294 s .topics [topic ] = topicArn
295295 s .topicSanitized [sanitizedName ] = topic
296296
@@ -303,7 +303,7 @@ func (s *snsSqs) createQueue(queueName string) (*sqsQueueInfo, error) {
303303 Tags : map [string ]* string {awsSqsQueueNameKey : aws .String (queueName )},
304304 })
305305 if err != nil {
306- return nil , err
306+ return nil , fmt . Errorf ( "error creaing an SQS queue: %w" , err )
307307 }
308308
309309 queueAttributesResponse , err := s .sqsClient .GetQueueAttributes (& sqs.GetQueueAttributesInput {
@@ -314,25 +314,6 @@ func (s *snsSqs) createQueue(queueName string) (*sqsQueueInfo, error) {
314314 s .logger .Errorf ("error fetching queue attributes for %s: %v" , queueName , err )
315315 }
316316
317- // add permissions to allow SNS to send messages to this queue
318- _ , err = s .sqsClient .SetQueueAttributes (& (sqs.SetQueueAttributesInput {
319- Attributes : map [string ]* string {
320- "Policy" : aws .String (fmt .Sprintf (`{
321- "Statement": [{
322- "Effect":"Allow",
323- "Principal":"*",
324- "Action":"sqs:SendMessage",
325- "Resource":"%s"
326- }]
327- }` , * (queueAttributesResponse .Attributes ["QueueArn" ]))),
328- },
329- QueueUrl : createQueueResponse .QueueUrl ,
330- }))
331-
332- if err != nil {
333- return nil , err
334- }
335-
336317 return & sqsQueueInfo {
337318 arn : * (queueAttributesResponse .Attributes ["QueueArn" ]),
338319 url : * (createQueueResponse .QueueUrl ),
@@ -347,7 +328,7 @@ func (s *snsSqs) getOrCreateQueue(queueName string) (*sqsQueueInfo, error) {
347328
348329 return queueArn , nil
349330 }
350- // creating queues is idempotent, the names serve as unique keys among a given region
331+ // creating queues is idempotent, the names serve as unique keys among a given region.
351332 s .logger .Debugf ("No queue arn found for %s\n Creating queue" , queueName )
352333
353334 queueInfo , err := s .createQueue (queueName )
@@ -375,9 +356,10 @@ func (s *snsSqs) Publish(req *pubsub.PublishRequest) error {
375356 })
376357
377358 if err != nil {
378- s .logger .Errorf ("error publishing topic %s with topic ARN %s: %v" , req .Topic , topicArn , err )
359+ wrappedErr := fmt .Errorf ("error publishing to topic: %s with topic ARN %s: %v" , req .Topic , topicArn , err )
360+ s .logger .Error (wrappedErr )
379361
380- return err
362+ return wrappedErr
381363 }
382364
383365 return nil
@@ -398,11 +380,11 @@ func (s *snsSqs) acknowledgeMessage(queueURL string, receiptHandle *string) erro
398380 ReceiptHandle : receiptHandle ,
399381 })
400382
401- return err
383+ return fmt . Errorf ( "error deleting SQS message: %w" , err )
402384}
403385
404386func (s * snsSqs ) handleMessage (message * sqs.Message , queueInfo , deadLettersQueueInfo * sqsQueueInfo , handler pubsub.Handler ) error {
405- // if this message has been received > x times, delete from queue, it's borked
387+ // if this message has been received > x times, delete from queue, it's borked.
406388 recvCount , ok := message .Attributes [sqs .MessageSystemAttributeNameApproximateReceiveCount ]
407389
408390 if ! ok {
@@ -425,7 +407,7 @@ func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueue
425407 "message received greater than %v times, deleting this message without further processing" , s .metadata .messageRetryLimit )
426408 }
427409 // ... else, there is no need to actively do something if we reached the limit defined in messageReceiveLimit as the message had
428- // already been moved to the dead-letters queue by SQS
410+ // already been moved to the dead-letters queue by SQS.
429411 if deadLettersQueueInfo != nil && recvCountInt >= s .metadata .messageReceiveLimit {
430412 s .logger .Warnf (
431413 "message received greater than %v times, moving this message without further processing to dead-letters queue: %v" , s .metadata .messageReceiveLimit , s .metadata .sqsDeadLettersQueueName )
@@ -450,15 +432,15 @@ func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueue
450432 return fmt .Errorf ("error handling message: %w" , err )
451433 }
452434
453- // otherwise, there was no error, acknowledge the message
435+ // otherwise, there was no error, acknowledge the message.
454436 return s .acknowledgeMessage (queueInfo .url , message .ReceiptHandle )
455437}
456438
457439func (s * snsSqs ) consumeSubscription (queueInfo , deadLettersQueueInfo * sqsQueueInfo , handler pubsub.Handler ) {
458440 go func () {
459441 for {
460442 messageResponse , err := s .sqsClient .ReceiveMessage (& sqs.ReceiveMessageInput {
461- // use this property to decide when a message should be discarded
443+ // use this property to decide when a message should be discarded.
462444 AttributeNames : []* string {
463445 aws .String (sqs .MessageSystemAttributeNameApproximateReceiveCount ),
464446 },
@@ -473,7 +455,7 @@ func (s *snsSqs) consumeSubscription(queueInfo, deadLettersQueueInfo *sqsQueueIn
473455 continue
474456 }
475457
476- // retry receiving messages
458+ // retry receiving messages.
477459 if len (messageResponse .Messages ) < 1 {
478460 s .logger .Debug ("No messages received, requesting again" )
479461
@@ -495,7 +477,7 @@ func (s *snsSqs) createDeadLettersQueue() (*sqsQueueInfo, error) {
495477 var deadLettersQueueInfo * sqsQueueInfo
496478 deadLettersQueueInfo , err := s .getOrCreateQueue (s .metadata .sqsDeadLettersQueueName )
497479 if err != nil {
498- s .logger .Errorf ("error retrieving SQS dead-letter queue: %v " , err )
480+ s .logger .Errorf ("error retrieving SQS dead-letter queue: %w " , err )
499481
500482 return nil , err
501483 }
@@ -511,9 +493,10 @@ func (s *snsSqs) createQueueAttributesWithDeadLetters(queueInfo, deadLettersQueu
511493
512494 b , err := json .Marshal (policy )
513495 if err != nil {
514- s .logger .Errorf ("error marshalling dead-letters queue policy: %v" , err )
496+ wrappedErr := fmt .Errorf ("error marshalling dead-letters queue policy: %w" , err )
497+ s .logger .Error (wrappedErr )
515498
516- return nil , err
499+ return nil , wrappedErr
517500 }
518501
519502 sqsSetQueueAttributesInput := & sqs.SetQueueAttributesInput {
@@ -526,55 +509,89 @@ func (s *snsSqs) createQueueAttributesWithDeadLetters(queueInfo, deadLettersQueu
526509 return sqsSetQueueAttributesInput , nil
527510}
528511
512+ func (s * snsSqs ) restrictQueuePublishPolicyToOnlySNS (sqsQueueInfo * sqsQueueInfo , snsARN string ) error {
513+ // only permit SNS to send messages to SQS using the created subscription.
514+ if _ , err := s .sqsClient .SetQueueAttributes (& (sqs.SetQueueAttributesInput {
515+ Attributes : map [string ]* string {
516+ "Policy" : aws .String (fmt .Sprintf (`{
517+ "Version": "2012-10-17",
518+ "Statement": [{
519+ "Effect":"Allow",
520+ "Principal":{"Service": "sns.amazonaws.com"},
521+ "Action":"sqs:SendMessage",
522+ "Resource":"%s",
523+ "Condition": {
524+ "ArnEquals":{
525+ "aws:SourceArn":"%s"
526+ }
527+ }
528+ }]
529+ }` , sqsQueueInfo .arn , snsARN )),
530+ },
531+ QueueUrl : & sqsQueueInfo .url ,
532+ })); err != nil {
533+ return fmt .Errorf ("error setting queue subscription policy: %w" , err )
534+ }
535+
536+ return nil
537+ }
538+
529539func (s * snsSqs ) Subscribe (req pubsub.SubscribeRequest , handler pubsub.Handler ) error {
530- // subscribers declare a topic ARN
531- // and declare a SQS queue to use
532- // these should be idempotent
533- // queues should not be created if they exist
540+ // subscribers declare a topic ARN and declare a SQS queue to use
541+ // these should be idempotent - queues should not be created if they exist.
534542 topicArn , err := s .getOrCreateTopic (req .Topic )
535543 if err != nil {
536- s .logger .Errorf ("error getting topic ARN for %s: %v " , req .Topic , err )
544+ s .logger .Errorf ("error getting topic ARN for %s: %w " , req .Topic , err )
537545
538546 return err
539547 }
540548
541- // this is the ID of the application, it is supplied via runtime as "consumerID"
549+ // this is the ID of the application, it is supplied via runtime as "consumerID".
542550 var queueInfo * sqsQueueInfo
543551 queueInfo , err = s .getOrCreateQueue (s .metadata .sqsQueueName )
544552 if err != nil {
545- s .logger .Errorf ("error retrieving SQS queue: %v " , err )
553+ s .logger .Errorf ("error retrieving SQS queue: %w " , err )
546554
547555 return err
548556 }
549557
558+ // only after a SQS queue and SNS topic had been setup, we restrict the SendMessage action to SNS as sole source
559+ // to prevent anyone but SNS to publish message to SQS.
560+ err = s .restrictQueuePublishPolicyToOnlySNS (queueInfo , topicArn )
561+ if err != nil {
562+ s .logger .Errorf ("error setting sns-sqs subscription policy: %w" , err )
563+
564+ return err
565+ }
566+
567+ // apply the dead letters queue attributes to the current queue.
550568 var deadLettersQueueInfo * sqsQueueInfo
551569 if len (s .metadata .sqsDeadLettersQueueName ) > 0 {
552570 var derr error
553571 deadLettersQueueInfo , derr = s .createDeadLettersQueue ()
554572 if derr != nil {
555- s .logger .Errorf ("error creating dead-letter queue: %v " , derr )
573+ s .logger .Errorf ("error creating dead-letter queue: %w " , derr )
556574
557575 return derr
558576 }
559577
560578 var sqsSetQueueAttributesInput * sqs.SetQueueAttributesInput
561579 sqsSetQueueAttributesInput , derr = s .createQueueAttributesWithDeadLetters (queueInfo , deadLettersQueueInfo )
562580 if derr != nil {
563- s .logger .Errorf ("error creatubg queue attributes for dead-letter queue: %v " , derr )
581+ s .logger .Errorf ("error creatubg queue attributes for dead-letter queue: %w " , derr )
564582
565583 return derr
566584 }
567585 _ , derr = s .sqsClient .SetQueueAttributes (sqsSetQueueAttributesInput )
568586 if derr != nil {
569- s .logger .Errorf ("error updating queue attributes with dead-letter queue: %v" , derr )
587+ wrappedErr := fmt .Errorf ("error updating queue attributes with dead-letter queue: %w" , derr )
588+ s .logger .Error (wrappedErr )
570589
571- return derr
590+ return wrappedErr
572591 }
573592 }
574593
575- // apply the dead letters queue attributes to the current queue
576-
577- // subscription creation is idempotent. Subscriptions are unique by topic/queue
594+ // subscription creation is idempotent. Subscriptions are unique by topic/queue.
578595 subscribeOutput , err := s .snsClient .Subscribe (& sns.SubscribeInput {
579596 Attributes : nil ,
580597 Endpoint : & queueInfo .arn , // create SQS queue per subscription
@@ -583,9 +600,10 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
583600 TopicArn : & topicArn ,
584601 })
585602 if err != nil {
586- s .logger .Errorf ("error subscribing to topic %s: %v" , req .Topic , err )
603+ wrappedErr := fmt .Errorf ("error subscribing to topic %s: %w" , req .Topic , err )
604+ s .logger .Error (wrappedErr )
587605
588- return err
606+ return wrappedErr
589607 }
590608
591609 s .subscriptions = append (s .subscriptions , subscribeOutput .SubscriptionArn )
0 commit comments