Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Raw Message Delivery #5

Merged
merged 12 commits into from
Nov 21, 2023
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
[![Docker Pulls](https://img.shields.io/docker/pulls/jameskbride/local-sns.svg?maxAge=2592000)](https://hub.docker.com/r/jameskbride/localsns/)

# Local SNS
Fake Amazon Simple Notification Service (SNS) for local development. Supports:
- Create/List/Delete topics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.jameskbride.localsns.models

import com.fasterxml.jackson.annotation.JsonIgnore
import java.io.Serializable
import java.net.URLDecoder

data class Subscription(
val arn: String,
Expand All @@ -15,4 +16,12 @@ data class Subscription(
val namePattern = """([\w+_-]{1,256})"""
val arnPattern = """([\w+_:-]{1,512})"""
}

@JsonIgnore fun decodedEndpointUrl():String {
return URLDecoder.decode(endpoint, "UTF-8")
}

@JsonIgnore fun isRawMessageDelivery(): Boolean {
return subscriptionAttributes.getOrDefault("RawMessageDelivery", "false") == "true"
}
}
136 changes: 116 additions & 20 deletions src/main/kotlin/com/jameskbride/localsns/routes/topics/publishRoute.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import org.apache.camel.ProducerTemplate
import org.apache.camel.impl.DefaultCamelContext
import org.apache.logging.log4j.LogManager
import org.apache.logging.log4j.Logger
import java.net.URLDecoder
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter
import java.util.*
Expand Down Expand Up @@ -114,7 +113,7 @@ private fun publishJsonStructure(
messages["default"]
}
logger.info("Messages to publish: $messageToPublish")
publishMessage(producerTemplate, subscription, messageToPublish as String, logger, messageAttributes)
publishMessage(subscription, messageToPublish as String, messageAttributes, producerTemplate, logger)
} catch (e: Exception) {
logger.error("An error occurred when publishing to: ${subscription.endpoint}", e)
}
Expand All @@ -137,7 +136,7 @@ private fun publishBasicMessage(
subscriptions.forEach { subscription ->
try {
logger.info("Message to publish: $message")
publishMessage(producerTemplate, subscription, message, logger, messageAttributes)
publishMessage(subscription, message, messageAttributes, producerTemplate, logger)
} catch (e: Exception) {
logger.error("An error occurred when publishing to: ${subscription.endpoint}", e)
}
Expand All @@ -149,43 +148,140 @@ fun getTopicArn(topicArn: String?, targetArn: String?): String? {
}

private fun publishMessage(
producer: ProducerTemplate,
subscription: Subscription,
message: String,
logger: Logger,
messageAttributes: Map<String, String>
messageAttributes: Map<String, String>,
producer: ProducerTemplate,
logger: Logger
) {
val decodedUrl = URLDecoder.decode(subscription.endpoint, "UTF-8")
val headers = messageAttributes.map { it.key to it.value }.toMap() +
mapOf(
"x-amz-sns-message-type" to "Notification",
"x-amz-sns-message-id" to UUID.randomUUID().toString(),
"x-amz-sns-subscription-arn" to subscription.arn,
"x-amz-sns-topic-arn" to subscription.topicArn
)
val timestamp = LocalDateTime.now()
val snsMessage = createSnsMessage(timestamp, message, subscription)
val gson = Gson()
when (subscription.protocol) {
"lambda" -> {
val record = LambdaRecord("aws:sns", subscription.arn, 1.0, snsMessage)
val event = LambdaEvent(listOf(record))
val messageToPublish = gson.toJson(event)
producer.asyncRequestBodyAndHeaders(decodedUrl, messageToPublish, headers + mapOf("Content-Type" to "application/json"))
.exceptionally { logger.error("Error publishing message $message, to subscription: $subscription", it) }
publishToLambda(subscription, message, headers, producer, logger)
}
"http" -> {
publishToHttp(subscription, message, headers, producer, logger)
}
"https" -> {
publishToHttp(subscription, message, headers, producer, logger)
}
"slack" -> {
publishToSlack(subscription, message, headers, producer, logger)
}
"sqs" -> {
publishToSqs(subscription, message, headers, producer, logger)
}
else -> {
val messageToPublish = gson.toJson(snsMessage)
producer.asyncRequestBodyAndHeaders(decodedUrl, messageToPublish, headers)
.exceptionally { logger.error("Error publishing message $messageToPublish, to subscription: $subscription", it) }
publishAllowingRawMessage(subscription, message, headers, producer, logger)
}
}
}

private fun publishToSqs(
subscription: Subscription,
message: String,
headers: Map<String, String>,
producer: ProducerTemplate,
logger: Logger
) {
publishAllowingRawMessage(subscription, message, headers, producer, logger)
}

private fun publishAllowingRawMessage(
subscription: Subscription,
message: String,
headers: Map<String, String>,
producer: ProducerTemplate,
logger: Logger
) {
val messageToPublish = if (subscription.isRawMessageDelivery()) {
message
} else {
val timestamp = LocalDateTime.now()
val snsMessage = createSnsMessage(subscription, message, timestamp)
val gson = Gson()
gson.toJson(snsMessage)
}
publish(subscription, messageToPublish, headers, producer, logger)
}

private fun publishToLambda(
subscription: Subscription,
message: String,
headers: Map<String, String>,
producer: ProducerTemplate,
logger: Logger
) {
val timestamp = LocalDateTime.now()
val snsMessage = createSnsMessage(subscription, message, timestamp)
val gson = Gson()
val record = LambdaRecord("aws:sns", subscription.arn, 1.0, snsMessage)
val event = LambdaEvent(listOf(record))
val messageToPublish = gson.toJson(event)
producer.asyncRequestBodyAndHeaders(
subscription.decodedEndpointUrl(),
messageToPublish,
headers + mapOf("Content-Type" to "application/json")
)
.exceptionally { logger.error("Error publishing message $message, to subscription: $subscription", it) }
}

private fun publishToHttp(
subscription: Subscription,
message: String,
headers: Map<String, String>,
producer: ProducerTemplate,
logger: Logger
) {
val timestamp = LocalDateTime.now()
val snsMessage = createSnsMessage(subscription, message, timestamp)
val gson = Gson()
val httpHeaders = if (subscription.isRawMessageDelivery()) {
headers + mapOf("x-amz-sns-rawdelivery" to "true")
} else {
headers
}

val messageToPublish = if (subscription.isRawMessageDelivery()) {
message
} else {
gson.toJson(snsMessage)
}

publish(subscription, messageToPublish, httpHeaders, producer, logger)
}

private fun publishToSlack(
subscription: Subscription,
message: String,
headers: Map<String, String>,
producer: ProducerTemplate,
logger: Logger
) {
publish(subscription, message, headers, producer, logger)
}

private fun publish(
subscription: Subscription,
message: String,
headers: Map<String, String>,
producer: ProducerTemplate,
logger: Logger
) {
producer.asyncRequestBodyAndHeaders(subscription.decodedEndpointUrl(), message, headers)
.exceptionally { logger.error("Error publishing message $message, to subscription: $subscription", it) }
}

private fun createSnsMessage(
timestamp: LocalDateTime?,
subscription: Subscription,
message: String,
subscription: Subscription
timestamp: LocalDateTime?
): SnsMessage {
val formattedTimestamp = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'").format(timestamp)
return SnsMessage(
Expand Down
2 changes: 1 addition & 1 deletion src/test/kotlin/com/jameskbride/localsns/BaseTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ open class BaseTest {
protected fun createValidArn(resourceName: String) =
"arn:aws:sns:us-east-1:123456789012:${resourceName}"

fun createEndpoint(name: String): String {
fun createSqsEndpoint(name: String): String {
return "aws2-sqs://$name?accessKey=xxx&secretKey=xxx&region=us-east-1&trustAllCertificates=true&overrideEndpoint=true&uriEndpointOverride=http://localhost:9324/000000000000/$name&messageAttributeNames=first,second"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class DatabaseVerticleTest: BaseTest() {
arn = createValidArn("subscription1"),
owner="owner",
protocol="sqs",
endpoint=createEndpoint("queue1")
endpoint=createSqsEndpoint("queue1")
)

val topics = getTopicsMap(vertx)!!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class GetSubscriptionAttributesRouteTest: BaseTest() {
@Test
fun `it returns an error when SubscriptionArn is invalid`(testContext: VertxTestContext) {
val topic = createTopicModel("topic1")
subscribe(topicArn = topic.arn, createEndpoint("queue1"), "sqs")
subscribe(topicArn = topic.arn, createSqsEndpoint("queue1"), "sqs")
val response = getSubscriptionAttributes("bad arn")

assertEquals(400, response.statusCode)
Expand All @@ -48,7 +48,7 @@ class GetSubscriptionAttributesRouteTest: BaseTest() {
@Test
fun `it can return default attribute values`(testContext: VertxTestContext) {
val topic = createTopicModel("topic1")
val endpoint = createEndpoint("queue1")
val endpoint = createSqsEndpoint("queue1")
val subscriptionArn = getSubscriptionArnFromResponse((subscribe(topicArn = topic.arn, endpoint, "sqs")))

val response = getSubscriptionAttributes(subscriptionArn)
Expand All @@ -73,7 +73,7 @@ class GetSubscriptionAttributesRouteTest: BaseTest() {
@Test
fun `it can return overridden attributes`(testContext: VertxTestContext) {
val topic = createTopicModel("topic1")
val endpoint = createEndpoint("queue1")
val endpoint = createSqsEndpoint("queue1")
val subscriptionArn = getSubscriptionArnFromResponse((subscribe(topicArn = topic.arn, endpoint, "sqs")))

setSubscriptionAttributes(subscriptionArn, "RawMessageDelivery", "true")
Expand All @@ -90,7 +90,7 @@ class GetSubscriptionAttributesRouteTest: BaseTest() {
@Test
fun `it can return arbitrary attributes`(testContext: VertxTestContext) {
val topic = createTopicModel("topic1")
val endpoint = createEndpoint("queue1")
val endpoint = createSqsEndpoint("queue1")
val subscriptionArn = getSubscriptionArnFromResponse((subscribe(topicArn = topic.arn, endpoint, "sqs")))

setSubscriptionAttributes(subscriptionArn, "status", "sending")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ class ListSubscriptionsByTopicRouteTest: BaseTest() {
@Test
fun `it returns success when there are subscriptions for a topic`(testContext: VertxTestContext) {
val topic = createTopicModel("topic1")
val endpoint1 = createEndpoint("queue1")
val endpoint1 = createSqsEndpoint("queue1")
val subscribeResponse1 = subscribe(topic.arn, endpoint1, "sqs")
val subscription1Arn = getSubscriptionArnFromResponse(subscribeResponse1)

val topic2 = createTopicModel("topic2")
val endpoint2 = createEndpoint("queue2")
val endpoint2 = createSqsEndpoint("queue2")
val subscribeResponse2 = subscribe(topic2.arn, endpoint2, "lambda")
val subscription2Arn = getSubscriptionArnFromResponse(subscribeResponse2)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ class ListSubscriptionsRouteTest: BaseTest() {
@Test
fun `it returns subscriptions when they exist `(testContext: VertxTestContext) {
val topic = createTopicModel("topic1")
val endpoint1 = createEndpoint("queue1")
val endpoint1 = createSqsEndpoint("queue1")
val subscribeResponse1 = subscribe(topic.arn, endpoint1, "sqs")
val subscription1Arn = getSubscriptionArnFromResponse(subscribeResponse1)

val topic2 = createTopicModel("topic2")
val endpoint2 = createEndpoint("queue2")
val endpoint2 = createSqsEndpoint("queue2")
val subscribeResponse2 = subscribe(topic2.arn, endpoint2, "lambda")
val subscription2Arn = getSubscriptionArnFromResponse(subscribeResponse2)

Expand Down
Loading
Loading