Skip to content

Commit e9deaf3

Browse files
Amit Moryaron2artursouzadapr-bot
authored
Snssqs subscription policy (#1259)
* bugfix for sns topic deletion upon termination * removed upstream github workflow files * gitignore * restrict SQS send message policy * linting mostly of unwrapped errors * refactoring * pr changes * Update .gitignore * Update dapr-bot-schedule.yml * Update dapr-bot-schedule.yml Co-authored-by: Yaron Schneider <yaronsc@microsoft.com> Co-authored-by: Artur Souza <artursouza.ms@outlook.com> Co-authored-by: Dapr Bot <56698301+dapr-bot@users.noreply.github.com>
1 parent f6a64f7 commit e9deaf3

File tree

1 file changed

+86
-68
lines changed

1 file changed

+86
-68
lines changed

pubsub/aws/snssqs/snssqs.go

Lines changed: 86 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ import (
1818
)
1919

2020
type snsSqs struct {
21-
// key is the topic name, value is the ARN of the topic
21+
// key is the topic name, value is the ARN of the topic.
2222
topics map[string]string
23-
// key is the sanitized topic name, value is the actual topic name
23+
// key is the sanitized topic name, value is the actual topic name.
2424
topicSanitized map[string]string
25-
// key is the topic name, value holds the ARN of the queue and its url
25+
// key is the topic name, value holds the ARN of the queue and its url.
2626
queues map[string]*sqsQueueInfo
2727
snsClient *sns.SNS
2828
sqsClient *sqs.SQS
@@ -37,31 +37,31 @@ type sqsQueueInfo struct {
3737
}
3838

3939
type snsSqsMetadata struct {
40-
// name of the queue for this application. The is provided by the runtime as "consumerID"
40+
// name of the queue for this application. The is provided by the runtime as "consumerID".
4141
sqsQueueName string
42-
// name of the dead letter queue for this application
42+
// name of the dead letter queue for this application.
4343
sqsDeadLettersQueueName string
4444
// aws endpoint for the component to use.
4545
Endpoint string
46-
// access key to use for accessing sqs/sns
46+
// access key to use for accessing sqs/sns.
4747
AccessKey string
48-
// secret key to use for accessing sqs/sns
48+
// secret key to use for accessing sqs/sns.
4949
SecretKey string
5050
// aws session token to use.
5151
SessionToken string
52-
// aws region in which SNS/SQS should create resources
52+
// aws region in which SNS/SQS should create resources.
5353
Region string
5454

55-
// amount of time in seconds that a message is hidden from receive requests after it is sent to a subscriber. Default: 10
55+
// amount of time in seconds that a message is hidden from receive requests after it is sent to a subscriber. Default: 10.
5656
messageVisibilityTimeout int64
57-
// number of times to resend a message after processing of that message fails before removing that message from the queue. Default: 10
57+
// number of times to resend a message after processing of that message fails before removing that message from the queue. Default: 10.
5858
messageRetryLimit int64
5959
// if sqsDeadLettersQueueName is set to a value, then the messageReceiveLimit defines the number of times a message is received
60-
// before it is moved to the dead-letters queue. This value must be smaller than messageRetryLimit
60+
// before it is moved to the dead-letters queue. This value must be smaller than messageRetryLimit.
6161
messageReceiveLimit int64
62-
// amount of time to await receipt of a message before making another request. Default: 1
62+
// amount of time to await receipt of a message before making another request. Default: 1.
6363
messageWaitTimeSeconds int64
64-
// maximum number of messages to receive from the queue at a time. Default: 10, Maximum: 10
64+
// maximum number of messages to receive from the queue at a time. Default: 10, Maximum: 10.
6565
messageMaxNumber int64
6666
}
6767

@@ -194,7 +194,7 @@ func (s *snsSqs) getSnsSqsMetatdata(metadata pubsub.Metadata) (*snsSqsMetadata,
194194
md.messageReceiveLimit = messageReceiveLimit
195195
}
196196

197-
// XOR on having either a valid messageReceiveLimit and invalid sqsDeadLettersQueueName, and vice versa
197+
// XOR on having either a valid messageReceiveLimit and invalid sqsDeadLettersQueueName, and vice versa.
198198
if (md.messageReceiveLimit > 0 || len(md.sqsDeadLettersQueueName) > 0) && !(md.messageReceiveLimit > 0 && len(md.sqsDeadLettersQueueName) > 0) {
199199
return nil, errors.New("to use SQS dead letters queue, messageReceiveLimit and sqsDeadLettersQueueName must both be set to a value")
200200
}
@@ -243,13 +243,13 @@ func (s *snsSqs) Init(metadata pubsub.Metadata) error {
243243
s.metadata = md
244244

245245
// both Publish and Subscribe need reference the topic ARN
246-
// track these ARNs in this map
246+
// track these ARNs in this map.
247247
s.topics = make(map[string]string)
248248
s.topicSanitized = make(map[string]string)
249249
s.queues = make(map[string]*sqsQueueInfo)
250250
sess, err := aws_auth.GetClient(md.AccessKey, md.SecretKey, md.SessionToken, md.Region, md.Endpoint)
251251
if err != nil {
252-
return err
252+
return fmt.Errorf("error creating an AWS client: %w", err)
253253
}
254254
s.snsClient = sns.New(sess)
255255
s.sqsClient = sqs.New(sess)
@@ -264,7 +264,7 @@ func (s *snsSqs) createTopic(topic string) (string, string, error) {
264264
Tags: []*sns.Tag{{Key: aws.String(awsSnsTopicNameKey), Value: aws.String(topic)}},
265265
})
266266
if err != nil {
267-
return "", "", err
267+
return "", "", fmt.Errorf("error while creating an SNS topic: %w", err)
268268
}
269269

270270
return *(createTopicResponse.TopicArn), sanitizedName, nil
@@ -276,12 +276,12 @@ func (s *snsSqs) getOrCreateTopic(topic string) (string, error) {
276276
topicArn, ok := s.topics[topic]
277277

278278
if ok {
279-
s.logger.Debugf("Found existing topic ARN for topic %s: %s", topic, topicArn)
279+
s.logger.Debugf("found existing topic ARN for topic %s: %s", topic, topicArn)
280280

281281
return topicArn, nil
282282
}
283283

284-
s.logger.Debugf("No topic ARN found for %s\n Creating topic instead.", topic)
284+
s.logger.Debugf("no topic ARN found for %s\n Creating topic instead.", topic)
285285

286286
topicArn, sanitizedName, err := s.createTopic(topic)
287287
if err != nil {
@@ -290,7 +290,7 @@ func (s *snsSqs) getOrCreateTopic(topic string) (string, error) {
290290
return "", err
291291
}
292292

293-
// record topic ARN
293+
// record topic ARN.
294294
s.topics[topic] = topicArn
295295
s.topicSanitized[sanitizedName] = topic
296296

@@ -303,7 +303,7 @@ func (s *snsSqs) createQueue(queueName string) (*sqsQueueInfo, error) {
303303
Tags: map[string]*string{awsSqsQueueNameKey: aws.String(queueName)},
304304
})
305305
if err != nil {
306-
return nil, err
306+
return nil, fmt.Errorf("error creaing an SQS queue: %w", err)
307307
}
308308

309309
queueAttributesResponse, err := s.sqsClient.GetQueueAttributes(&sqs.GetQueueAttributesInput{
@@ -314,25 +314,6 @@ func (s *snsSqs) createQueue(queueName string) (*sqsQueueInfo, error) {
314314
s.logger.Errorf("error fetching queue attributes for %s: %v", queueName, err)
315315
}
316316

317-
// add permissions to allow SNS to send messages to this queue
318-
_, err = s.sqsClient.SetQueueAttributes(&(sqs.SetQueueAttributesInput{
319-
Attributes: map[string]*string{
320-
"Policy": aws.String(fmt.Sprintf(`{
321-
"Statement": [{
322-
"Effect":"Allow",
323-
"Principal":"*",
324-
"Action":"sqs:SendMessage",
325-
"Resource":"%s"
326-
}]
327-
}`, *(queueAttributesResponse.Attributes["QueueArn"]))),
328-
},
329-
QueueUrl: createQueueResponse.QueueUrl,
330-
}))
331-
332-
if err != nil {
333-
return nil, err
334-
}
335-
336317
return &sqsQueueInfo{
337318
arn: *(queueAttributesResponse.Attributes["QueueArn"]),
338319
url: *(createQueueResponse.QueueUrl),
@@ -347,7 +328,7 @@ func (s *snsSqs) getOrCreateQueue(queueName string) (*sqsQueueInfo, error) {
347328

348329
return queueArn, nil
349330
}
350-
// creating queues is idempotent, the names serve as unique keys among a given region
331+
// creating queues is idempotent, the names serve as unique keys among a given region.
351332
s.logger.Debugf("No queue arn found for %s\nCreating queue", queueName)
352333

353334
queueInfo, err := s.createQueue(queueName)
@@ -375,9 +356,10 @@ func (s *snsSqs) Publish(req *pubsub.PublishRequest) error {
375356
})
376357

377358
if err != nil {
378-
s.logger.Errorf("error publishing topic %s with topic ARN %s: %v", req.Topic, topicArn, err)
359+
wrappedErr := fmt.Errorf("error publishing to topic: %s with topic ARN %s: %v", req.Topic, topicArn, err)
360+
s.logger.Error(wrappedErr)
379361

380-
return err
362+
return wrappedErr
381363
}
382364

383365
return nil
@@ -398,11 +380,11 @@ func (s *snsSqs) acknowledgeMessage(queueURL string, receiptHandle *string) erro
398380
ReceiptHandle: receiptHandle,
399381
})
400382

401-
return err
383+
return fmt.Errorf("error deleting SQS message: %w", err)
402384
}
403385

404386
func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueueInfo *sqsQueueInfo, handler pubsub.Handler) error {
405-
// if this message has been received > x times, delete from queue, it's borked
387+
// if this message has been received > x times, delete from queue, it's borked.
406388
recvCount, ok := message.Attributes[sqs.MessageSystemAttributeNameApproximateReceiveCount]
407389

408390
if !ok {
@@ -425,7 +407,7 @@ func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueue
425407
"message received greater than %v times, deleting this message without further processing", s.metadata.messageRetryLimit)
426408
}
427409
// ... else, there is no need to actively do something if we reached the limit defined in messageReceiveLimit as the message had
428-
// already been moved to the dead-letters queue by SQS
410+
// already been moved to the dead-letters queue by SQS.
429411
if deadLettersQueueInfo != nil && recvCountInt >= s.metadata.messageReceiveLimit {
430412
s.logger.Warnf(
431413
"message received greater than %v times, moving this message without further processing to dead-letters queue: %v", s.metadata.messageReceiveLimit, s.metadata.sqsDeadLettersQueueName)
@@ -450,15 +432,15 @@ func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueue
450432
return fmt.Errorf("error handling message: %w", err)
451433
}
452434

453-
// otherwise, there was no error, acknowledge the message
435+
// otherwise, there was no error, acknowledge the message.
454436
return s.acknowledgeMessage(queueInfo.url, message.ReceiptHandle)
455437
}
456438

457439
func (s *snsSqs) consumeSubscription(queueInfo, deadLettersQueueInfo *sqsQueueInfo, handler pubsub.Handler) {
458440
go func() {
459441
for {
460442
messageResponse, err := s.sqsClient.ReceiveMessage(&sqs.ReceiveMessageInput{
461-
// use this property to decide when a message should be discarded
443+
// use this property to decide when a message should be discarded.
462444
AttributeNames: []*string{
463445
aws.String(sqs.MessageSystemAttributeNameApproximateReceiveCount),
464446
},
@@ -473,7 +455,7 @@ func (s *snsSqs) consumeSubscription(queueInfo, deadLettersQueueInfo *sqsQueueIn
473455
continue
474456
}
475457

476-
// retry receiving messages
458+
// retry receiving messages.
477459
if len(messageResponse.Messages) < 1 {
478460
s.logger.Debug("No messages received, requesting again")
479461

@@ -495,7 +477,7 @@ func (s *snsSqs) createDeadLettersQueue() (*sqsQueueInfo, error) {
495477
var deadLettersQueueInfo *sqsQueueInfo
496478
deadLettersQueueInfo, err := s.getOrCreateQueue(s.metadata.sqsDeadLettersQueueName)
497479
if err != nil {
498-
s.logger.Errorf("error retrieving SQS dead-letter queue: %v", err)
480+
s.logger.Errorf("error retrieving SQS dead-letter queue: %w", err)
499481

500482
return nil, err
501483
}
@@ -511,9 +493,10 @@ func (s *snsSqs) createQueueAttributesWithDeadLetters(queueInfo, deadLettersQueu
511493

512494
b, err := json.Marshal(policy)
513495
if err != nil {
514-
s.logger.Errorf("error marshalling dead-letters queue policy: %v", err)
496+
wrappedErr := fmt.Errorf("error marshalling dead-letters queue policy: %w", err)
497+
s.logger.Error(wrappedErr)
515498

516-
return nil, err
499+
return nil, wrappedErr
517500
}
518501

519502
sqsSetQueueAttributesInput := &sqs.SetQueueAttributesInput{
@@ -526,55 +509,89 @@ func (s *snsSqs) createQueueAttributesWithDeadLetters(queueInfo, deadLettersQueu
526509
return sqsSetQueueAttributesInput, nil
527510
}
528511

512+
func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(sqsQueueInfo *sqsQueueInfo, snsARN string) error {
513+
// only permit SNS to send messages to SQS using the created subscription.
514+
if _, err := s.sqsClient.SetQueueAttributes(&(sqs.SetQueueAttributesInput{
515+
Attributes: map[string]*string{
516+
"Policy": aws.String(fmt.Sprintf(`{
517+
"Version": "2012-10-17",
518+
"Statement": [{
519+
"Effect":"Allow",
520+
"Principal":{"Service": "sns.amazonaws.com"},
521+
"Action":"sqs:SendMessage",
522+
"Resource":"%s",
523+
"Condition": {
524+
"ArnEquals":{
525+
"aws:SourceArn":"%s"
526+
}
527+
}
528+
}]
529+
}`, sqsQueueInfo.arn, snsARN)),
530+
},
531+
QueueUrl: &sqsQueueInfo.url,
532+
})); err != nil {
533+
return fmt.Errorf("error setting queue subscription policy: %w", err)
534+
}
535+
536+
return nil
537+
}
538+
529539
func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler) error {
530-
// subscribers declare a topic ARN
531-
// and declare a SQS queue to use
532-
// these should be idempotent
533-
// queues should not be created if they exist
540+
// subscribers declare a topic ARN and declare a SQS queue to use
541+
// these should be idempotent - queues should not be created if they exist.
534542
topicArn, err := s.getOrCreateTopic(req.Topic)
535543
if err != nil {
536-
s.logger.Errorf("error getting topic ARN for %s: %v", req.Topic, err)
544+
s.logger.Errorf("error getting topic ARN for %s: %w", req.Topic, err)
537545

538546
return err
539547
}
540548

541-
// this is the ID of the application, it is supplied via runtime as "consumerID"
549+
// this is the ID of the application, it is supplied via runtime as "consumerID".
542550
var queueInfo *sqsQueueInfo
543551
queueInfo, err = s.getOrCreateQueue(s.metadata.sqsQueueName)
544552
if err != nil {
545-
s.logger.Errorf("error retrieving SQS queue: %v", err)
553+
s.logger.Errorf("error retrieving SQS queue: %w", err)
546554

547555
return err
548556
}
549557

558+
// only after a SQS queue and SNS topic had been setup, we restrict the SendMessage action to SNS as sole source
559+
// to prevent anyone but SNS to publish message to SQS.
560+
err = s.restrictQueuePublishPolicyToOnlySNS(queueInfo, topicArn)
561+
if err != nil {
562+
s.logger.Errorf("error setting sns-sqs subscription policy: %w", err)
563+
564+
return err
565+
}
566+
567+
// apply the dead letters queue attributes to the current queue.
550568
var deadLettersQueueInfo *sqsQueueInfo
551569
if len(s.metadata.sqsDeadLettersQueueName) > 0 {
552570
var derr error
553571
deadLettersQueueInfo, derr = s.createDeadLettersQueue()
554572
if derr != nil {
555-
s.logger.Errorf("error creating dead-letter queue: %v", derr)
573+
s.logger.Errorf("error creating dead-letter queue: %w", derr)
556574

557575
return derr
558576
}
559577

560578
var sqsSetQueueAttributesInput *sqs.SetQueueAttributesInput
561579
sqsSetQueueAttributesInput, derr = s.createQueueAttributesWithDeadLetters(queueInfo, deadLettersQueueInfo)
562580
if derr != nil {
563-
s.logger.Errorf("error creatubg queue attributes for dead-letter queue: %v", derr)
581+
s.logger.Errorf("error creatubg queue attributes for dead-letter queue: %w", derr)
564582

565583
return derr
566584
}
567585
_, derr = s.sqsClient.SetQueueAttributes(sqsSetQueueAttributesInput)
568586
if derr != nil {
569-
s.logger.Errorf("error updating queue attributes with dead-letter queue: %v", derr)
587+
wrappedErr := fmt.Errorf("error updating queue attributes with dead-letter queue: %w", derr)
588+
s.logger.Error(wrappedErr)
570589

571-
return derr
590+
return wrappedErr
572591
}
573592
}
574593

575-
// apply the dead letters queue attributes to the current queue
576-
577-
// subscription creation is idempotent. Subscriptions are unique by topic/queue
594+
// subscription creation is idempotent. Subscriptions are unique by topic/queue.
578595
subscribeOutput, err := s.snsClient.Subscribe(&sns.SubscribeInput{
579596
Attributes: nil,
580597
Endpoint: &queueInfo.arn, // create SQS queue per subscription
@@ -583,9 +600,10 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
583600
TopicArn: &topicArn,
584601
})
585602
if err != nil {
586-
s.logger.Errorf("error subscribing to topic %s: %v", req.Topic, err)
603+
wrappedErr := fmt.Errorf("error subscribing to topic %s: %w", req.Topic, err)
604+
s.logger.Error(wrappedErr)
587605

588-
return err
606+
return wrappedErr
589607
}
590608

591609
s.subscriptions = append(s.subscriptions, subscribeOutput.SubscriptionArn)

0 commit comments

Comments
 (0)