Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Capturing the AWS Account ID during SQS Subscribe #24

Merged
merged 1 commit into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import com.typesafe.config.ConfigFactory
import io.vertx.ext.web.RoutingContext
import org.apache.logging.log4j.LogManager
import org.apache.logging.log4j.Logger
import java.net.URI
import java.net.URL
import java.util.*
import java.util.regex.Pattern
Expand Down Expand Up @@ -60,17 +61,7 @@ val subscribeRoute: (RoutingContext) -> Unit = route@{ ctx: RoutingContext ->
return@route
}

val subscriptionEndpoint = if (protocol == "sqs" && (endpoint?.startsWith("http") == true || endpoint?.startsWith("https") == true)) {
val url = URL(endpoint)
val endpointProtocol = url.protocol
val endpointHost = url.host
val endpointPort = url.port
val endpointPath = url.path
val queueName = endpointPath.split("/").last()
buildSqsEndpoint(queueName, endpointProtocol, endpointHost, endpointPort)
} else {
endpoint
}
val subscriptionEndpoint = buildSubscriptionEndpointData(protocol, endpoint)

val subscriptions = getSubscriptionsMap(vertx)!!
val owner = getAwsAccountId(config = ConfigFactory.load())
Expand Down Expand Up @@ -103,17 +94,53 @@ val subscribeRoute: (RoutingContext) -> Unit = route@{ ctx: RoutingContext ->
)
}

private fun buildSubscriptionEndpointData(
protocol: String?,
endpoint: String?
): String {
return if (protocol == "sqs" && (endpoint?.startsWith("http") == true || endpoint?.startsWith("https") == true)) {
val url = URI(endpoint)
val endpointProtocol = url.scheme
val endpointHost = url.host
val endpointPort = url.port
val endpointPath = url.path
val pathParts = endpointPath.split("/")
val queueName = pathParts.last()
val accountId = pathParts[pathParts.size - 2]
val httpSqsEndpoint = buildSqsEndpoint(queueName, endpointProtocol, endpointHost, endpointPort, accountId)
httpSqsEndpoint
} else {
endpoint!!
}
}

private fun buildSqsEndpoint(
queueName: String,
endpointProtocol: String?,
endpointHost: String?,
endpointPort: Int
endpointPort: Int,
accountId: String?
): String {
val hostAndPort = if (endpointPort > -1) {
"$endpointProtocol://$endpointHost:$endpointPort"
} else {
"$endpointProtocol://$endpointHost"
}
val queryParams =
"?accessKey=xxx&secretKey=xxx&region=us-east-1&trustAllCertificates=true&overrideEndpoint=true&uriEndpointOverride="
mapOf(
"accessKey" to "xxx",
"secretKey" to "xxx",
"region" to "us-east-1",
"trustAllCertificates" to "true",
"overrideEndpoint" to "true",
"queueOwnerAWSAccountId" to (accountId ?: getAwsAccountId(ConfigFactory.load())),
"uriEndpointOverride" to hostAndPort,
)
.map { (key, value) -> "$key=$value" }
.joinToString("&", "?")
return if (endpointPort > -1) {
"aws2-sqs://$queueName$queryParams$endpointProtocol://$endpointHost:$endpointPort"
"aws2-sqs://$queueName$queryParams"
} else {
"aws2-sqs://$queueName$queryParams$endpointProtocol://$endpointHost"
"aws2-sqs://$queueName$queryParams"
}
}
9 changes: 7 additions & 2 deletions src/test/kotlin/com/jameskbride/localsns/BaseTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,13 @@ open class BaseTest {
return "aws2-sqs://$name?accessKey=xxx&secretKey=xxx&region=us-east-1&trustAllCertificates=true&overrideEndpoint=true&uriEndpointOverride=http://localhost:9324/000000000000/$name&messageAttributeNames=first,second"
}

fun createHttpSqsEndpoint(name: String): String {
return "http://localhost:9324/000000000000/$name"
fun createHttpSqsEndpoint(name: String, protocolAndHost: String, port: String? = null): String {
val portString = if (port != null) {
":$port"
} else {
""
}
return "$protocolAndHost$portString/000000000000/$name"
}

fun createCamelHttpEndpoint(uri: String, method: String = "POST"): String {
Expand Down
113 changes: 77 additions & 36 deletions src/test/kotlin/com/jameskbride/localsns/SubscribeRouteTest.kt
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package com.jameskbride.localsns

import com.jameskbride.localsns.models.Configuration
import com.jameskbride.localsns.models.Subscription
import com.jameskbride.localsns.models.Topic
import com.jameskbride.localsns.verticles.DatabaseVerticle
import com.jameskbride.localsns.verticles.MainVerticle
import com.typesafe.config.Config
import com.typesafe.config.ConfigFactory
import io.vertx.core.CompositeFuture
import io.vertx.core.Vertx
Expand All @@ -15,6 +18,7 @@ import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Tag
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
import java.net.URI

@ExtendWith(VertxExtension::class)
class SubscribeRouteTest : BaseTest() {
Expand Down Expand Up @@ -78,23 +82,14 @@ class SubscribeRouteTest : BaseTest() {
val camelSqsEndpoint = createCamelSqsEndpoint(queueName)
val config = ConfigFactory.load()
vertx.eventBus().consumer<String>("configChangeComplete") {
vertx.fileSystem()
.readFile(getDbOutputPath(config))
.onComplete { result ->
val configFile = result.result()
val jsonConfig = JsonObject(configFile)

val configuration = jsonConfig.mapTo(Configuration::class.java)
assertEquals(configuration.version, 1)
assertTrue(configuration.topics.contains(topic))
val foundSubscription = configuration.subscriptions
.find { it.topicArn == topic.arn && it.protocol == "sqs" && it.endpoint == camelSqsEndpoint }
if (foundSubscription == null) {
testContext.failNow(IllegalStateException("Subscription not found"))
}

testContext.completeNow()
}
waitOnSubscription(vertx, testContext, config, topic) {
it.topicArn == topic.arn && it.protocol == "sqs" && camelEndpointMatches(
it.endpoint,
queueName,
"000000000000",
"http://localhost:9324"
)
}
}

val response = subscribe(topic.arn, camelSqsEndpoint, "sqs")
Expand All @@ -112,35 +107,81 @@ class SubscribeRouteTest : BaseTest() {
val topic = createTopicModel("topic1")

val queueName = "queue1"
val response = subscribe(topic.arn, createHttpSqsEndpoint(queueName), "sqs")
val response = subscribe(topic.arn, createHttpSqsEndpoint(queueName, "http://localhost:9324"), "sqs")
val config = ConfigFactory.load()
vertx.eventBus().consumer<String>("configChangeComplete") {
vertx.fileSystem()
.readFile(getDbOutputPath(config))
.onComplete { result ->
val configFile = result.result()
val jsonConfig = JsonObject(configFile)

val configuration = jsonConfig.mapTo(Configuration::class.java)
assertEquals(configuration.version, 1)
assertTrue(configuration.topics.contains(topic))
val expectedEndpoint =
"aws2-sqs://$queueName?accessKey=xxx&secretKey=xxx&region=us-east-1&trustAllCertificates=true&overrideEndpoint=true&uriEndpointOverride=http://localhost:9324"
val foundSubscription = configuration.subscriptions
.find { it.topicArn == topic.arn && it.protocol == "sqs" && it.endpoint == expectedEndpoint }
if (foundSubscription == null) {
testContext.failNow(IllegalStateException("Subscription not found"))
}
waitOnSubscription(vertx, testContext, config, topic) {
it.topicArn == topic.arn && it.protocol == "sqs" && camelEndpointMatches(
it.endpoint,
queueName,
"000000000000",
"http://localhost:9324"
)
}
}

testContext.completeNow()
}
val subscriptionArn = getSubscriptionArnFromResponse(response)
assertEquals(200, response.statusCode)
assertTrue(subscriptionArn.isNotEmpty())
}

@Test
@Tag("skipForCI")
fun `it creates a camel-compliant sqs endpoint subscription when subscribing to an http sqs queue with no port`(vertx: Vertx, testContext: VertxTestContext) {
val topic = createTopicModel("topic1")

val queueName = "queue1"
val response = subscribe(topic.arn, createHttpSqsEndpoint(queueName, "https://sqs.us-east-1.amazonaws.com"), "sqs")
val config = ConfigFactory.load()
vertx.eventBus().consumer<String>("configChangeComplete") {
waitOnSubscription(vertx, testContext, config, topic) {
it.topicArn == topic.arn && it.protocol == "sqs" && camelEndpointMatches(
it.endpoint,
queueName,
"000000000000",
"https://sqs.us-east-1.amazonaws.com"
)
}
}

val subscriptionArn = getSubscriptionArnFromResponse(response)
assertEquals(200, response.statusCode)
assertTrue(subscriptionArn.isNotEmpty())
}

private fun waitOnSubscription(
vertx: Vertx,
testContext: VertxTestContext,
config: Config,
topic: Topic,
subscriptionPredicate: (Subscription) -> Boolean
) {
vertx.fileSystem()
.readFile(getDbOutputPath(config))
.onComplete { result ->
val configFile = result.result()
val jsonConfig = JsonObject(configFile)

val configuration = jsonConfig.mapTo(Configuration::class.java)
assertEquals(configuration.version, 1)
assertTrue(configuration.topics.contains(topic))
val foundSubscription = configuration.subscriptions.find(subscriptionPredicate)
if (foundSubscription == null) {
testContext.failNow(IllegalStateException("Subscription not found"))
}

testContext.completeNow()
}
}

private fun camelEndpointMatches(endpoint: String?, queueName: String, accountId: String, uriEndpointOverride: String): Boolean {
val uri = URI(endpoint)
return uri.scheme == "aws2-sqs"
&& uri.host == queueName
&& uri.query.contains("queueOwnerAWSAccountId=$accountId")
&& uri.query.contains("uriEndpointOverride=$uriEndpointOverride")
}

@Test
fun `it indicates a db change`(vertx: Vertx, testContext: VertxTestContext) {
val topic = createTopicModel("topic1")
Expand Down
Loading