Skip to content

Commit

Permalink
Add support for using suspending methods with @MqttSubscribe (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
cmdjulian authored Apr 14, 2024
1 parent fbcaff6 commit a9accc7
Show file tree
Hide file tree
Showing 8 changed files with 346 additions and 30 deletions.
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ class MqttTimeoutConfigurer : MqttClientConfigurer {

### Annotation based

The `MqttSubscribe` annotation is scanned on application start and receives messages on the given topic.
The `MqttSubscribe` annotation is scanned on application start and receives messages on the given topic.
It additionally supports kotlin suspend functions. Those functions are run inside the mqtt client thread pool.

```kotlin
import com.hivemq.client.mqtt.datatypes.MqttQos.AT_LEAST_ONCE
Expand Down Expand Up @@ -109,6 +110,12 @@ class TestConsumer {
fun subscribe() {
println("Something happened")
}

// Suspending function
@MqttSubscribe(topic = "/home/ping", qos = AT_LEAST_ONCE)
suspend fun suspending() {
println("Something happened suspending")
}
}
```

Expand Down
2 changes: 2 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ repositories {
dependencies {
implementation("org.jetbrains.kotlin:kotlin-reflect")

implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core")

implementation(platform("org.springframework.boot:spring-boot-dependencies:3.2.3"))
implementation("org.springframework.boot:spring-boot")
implementation("org.springframework.boot:spring-boot-autoconfigure")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import com.hivemq.client.mqtt.mqtt5.Mqtt5ClientBuilder
* Interface to enable more advanced configuration for the [Mqtt3ClientBuilder] than what is possible with the
* properties.
*/
interface Mqtt3ClientConfigurer {
fun interface Mqtt3ClientConfigurer {

/**
* To be implemented by consumers. Can perform any configuration on the given [builder] in place.
Expand All @@ -20,7 +20,7 @@ interface Mqtt3ClientConfigurer {
* Interface to enable more advanced configuration for the [Mqtt5ClientBuilder] than what is possible with the
* properties.
*/
interface Mqtt5ClientConfigurer {
fun interface Mqtt5ClientConfigurer {

/**
* To be implemented by consumers. Can perform any configuration on the given [builder] in place.
Expand Down
77 changes: 50 additions & 27 deletions src/main/kotlin/de/smartsquare/starter/mqtt/MqttHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@ package de.smartsquare.starter.mqtt
import com.fasterxml.jackson.core.JacksonException
import com.fasterxml.jackson.databind.JsonMappingException
import com.hivemq.client.mqtt.datatypes.MqttTopic
import de.smartsquare.starter.mqtt.MqttSubscriberCollector.ResolvedMqttSubscriber
import de.smartsquare.starter.mqtt.MqttHandler.AnnotatedMethodDelegate
import kotlinx.coroutines.runBlocking
import org.slf4j.LoggerFactory
import java.lang.reflect.InvocationTargetException
import java.lang.invoke.MethodHandles
import java.util.concurrent.ConcurrentHashMap
import kotlin.reflect.full.callSuspend
import kotlin.reflect.full.valueParameters
import kotlin.reflect.jvm.jvmErasure
import kotlin.reflect.jvm.kotlinFunction

/**
* Class for consuming and forwarding messages to the correct subscriber.
Expand All @@ -18,44 +23,34 @@ class MqttHandler(
) {

private val logger = LoggerFactory.getLogger(this::class.java)
private val subscriberCache = ConcurrentHashMap<MqttTopic, MqttSubscriberReference>(collector.subscribers.size)

private val subscriberCache = ConcurrentHashMap<MqttTopic, ResolvedMqttSubscriber>(collector.subscribers.size)
private data class MqttSubscriberReference(
val subscriber: AnnotatedMethodDelegate,
val parameterTypes: List<Class<*>>,
)

private fun interface AnnotatedMethodDelegate {
fun invoke(vararg args: Any)
}

/**
* Handles a single [message]. The topic of the message is used to determine the correct subscriber which is then
* invoked with parameters produced by the [MqttMessageAdapter].
*/
fun handle(message: MqttPublishContainer) {
val topic = message.topic

val (topic, payload) = message
if (logger.isTraceEnabled) {
logger.trace("Received mqtt message on topic [$topic] with payload ${message.payload}")
logger.trace("Received mqtt message on topic [$topic] with payload $payload")
}

val subscriber = subscriberCache
.getOrPut(topic) { collector.subscribers.find { it.topic.matches(topic) } }
?: error("No subscriber found for topic $topic")

val (subscriber, parameterTypes) = getSubscriber(topic)
try {
val parameters = subscriber.method.parameterTypes
.map { adapter.adapt(message, it) }
.toTypedArray()

subscriber.method.invoke(subscriber.bean, *parameters)
} catch (e: InvocationTargetException) {
messageErrorHandler.handle(
MqttMessageException(
topic,
message.payload,
"Error while handling mqtt message on topic [$topic]",
e,
),
)
subscriber.invoke(*Array(parameterTypes.size) { adapter.adapt(message, parameterTypes[it]) })
} catch (e: JsonMappingException) {
messageErrorHandler.handle(
MqttMessageException(
topic,
message.payload,
payload,
"Error while handling mqtt message on topic [$topic]: Failed to map payload to target class",
e,
),
Expand All @@ -64,11 +59,39 @@ class MqttHandler(
messageErrorHandler.handle(
MqttMessageException(
topic,
message.payload,
payload,
"Error while handling mqtt message on topic [$topic]: Failed to parse payload",
e,
),
)
} catch (e: Exception) {
messageErrorHandler.handle(
MqttMessageException(topic, payload, "Error while handling mqtt message on topic [$topic]", e),
)
}
}

/**
* Returns the subscriber for the given [topic].
* If no subscriber is found, an error is thrown.
* The subscriber is cached for performance reasons.
*
* If the function is a suspend function, it is wrapped in a suspend call. For normal functions, a method handle is
* created and cached.
*/
private fun getSubscriber(topic: MqttTopic): MqttSubscriberReference = subscriberCache.getOrPut(topic) {
val subscriber = collector.subscribers.find { it.topic.matches(topic) }
?: error("No subscriber found for topic $topic")
val kFunction = subscriber.method.kotlinFunction
val parameterTypes = kFunction?.valueParameters?.map { it.type.jvmErasure.java }
?: subscriber.method.parameterTypes.toList()
val delegate = if (kFunction?.isSuspend == true) {
AnnotatedMethodDelegate { args -> runBlocking { kFunction.callSuspend(subscriber.bean, *args) } }
} else {
val handle = MethodHandles.publicLookup().unreflect(subscriber.method)
AnnotatedMethodDelegate { args -> handle.invokeWithArguments(subscriber.bean, *args) }
}

MqttSubscriberReference(delegate, parameterTypes)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,16 @@ import com.hivemq.client.mqtt.mqtt5.message.publish.Mqtt5Publish
*/
sealed interface MqttPublishContainer {
val topic: MqttTopic

val payload: ByteArray

val value: Any

@JvmSynthetic
operator fun component1(): MqttTopic = topic

@JvmSynthetic
operator fun component2(): ByteArray = payload
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import de.smartsquare.starter.mqtt.Mqtt3AutoConfigurationTest.MultipleSubscriber
import de.smartsquare.starter.mqtt.Mqtt3AutoConfigurationTest.ObjectSubscriber
import de.smartsquare.starter.mqtt.Mqtt3AutoConfigurationTest.PublishSubscriber
import de.smartsquare.starter.mqtt.Mqtt3AutoConfigurationTest.StringSubscriber
import de.smartsquare.starter.mqtt.Mqtt3AutoConfigurationTest.SuspendSubscriber
import org.amshove.kluent.shouldBeEqualTo
import org.amshove.kluent.shouldBeTrue
import org.amshove.kluent.shouldNotBeNull
Expand All @@ -28,6 +29,7 @@ import org.springframework.stereotype.Component
StringSubscriber::class,
ObjectSubscriber::class,
PublishSubscriber::class,
SuspendSubscriber::class,
EmptySubscriber::class,
MultipleSubscriber::class,
],
Expand All @@ -52,6 +54,9 @@ class Mqtt3AutoConfigurationTest {
@Autowired
private lateinit var publishSubscriber: PublishSubscriber

@Autowired
private lateinit var suspendSubscriber: SuspendSubscriber

@Autowired
private lateinit var emptySubscriber: EmptySubscriber

Expand Down Expand Up @@ -102,6 +107,20 @@ class Mqtt3AutoConfigurationTest {
}
}

@Test
fun `receives publish message from suspend function`() {
val publish = Mqtt3Publish.builder()
.topic("suspend")
.payload("test".toByteArray())
.qos(MqttQos.EXACTLY_ONCE).build()

client.toBlocking().publish(publish)

await untilAssertedKluent {
suspendSubscriber.receivedPayload shouldBeEqualTo publish
}
}

@Test
fun `receives object message`() {
// language=json
Expand Down Expand Up @@ -237,6 +256,18 @@ class Mqtt3AutoConfigurationTest {
}
}

@Component
class SuspendSubscriber {

val receivedPayload get() = _receivedPayload
private var _receivedPayload: Mqtt3Publish? = null

@MqttSubscribe(topic = "suspend", qos = MqttQos.EXACTLY_ONCE)
suspend fun onMessage(payload: Mqtt3Publish) {
_receivedPayload = payload
}
}

@Component
class ObjectSubscriber {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.hivemq.client.mqtt.mqtt5.message.publish.Mqtt5Publish
import de.smartsquare.starter.mqtt.Mqtt5AutoConfigurationTest.ErrorSubscriber
import de.smartsquare.starter.mqtt.Mqtt5AutoConfigurationTest.IntSubscriber
import de.smartsquare.starter.mqtt.Mqtt5AutoConfigurationTest.PublishSubscriber
import de.smartsquare.starter.mqtt.Mqtt5AutoConfigurationTest.SuspendSubscriber
import org.amshove.kluent.shouldBeEqualTo
import org.awaitility.kotlin.await
import org.junit.jupiter.api.Test
Expand All @@ -21,6 +22,7 @@ import org.springframework.test.context.TestPropertySource
MqttAutoConfiguration::class,
IntSubscriber::class,
PublishSubscriber::class,
SuspendSubscriber::class,
ErrorSubscriber::class,
],
)
Expand All @@ -39,6 +41,9 @@ class Mqtt5AutoConfigurationTest {
@Autowired
private lateinit var publishSubscriber: PublishSubscriber

@Autowired
private lateinit var suspendSubscriber: SuspendSubscriber

@Autowired
private lateinit var errorSubscriber: ErrorSubscriber

Expand Down Expand Up @@ -71,6 +76,20 @@ class Mqtt5AutoConfigurationTest {
}
}

@Test
fun `receives publish message from suspend function`() {
val publish = Mqtt5Publish.builder()
.topic("suspend")
.payload("test".toByteArray())
.qos(MqttQos.EXACTLY_ONCE).build()

client.toBlocking().publish(publish)

await untilAssertedKluent {
suspendSubscriber.receivedPayload shouldBeEqualTo publish
}
}

@Test
fun `does not crash completely when subscriber throws exception`() {
client.toBlocking()
Expand Down Expand Up @@ -127,6 +146,18 @@ class Mqtt5AutoConfigurationTest {
}
}

@Component
class SuspendSubscriber {

val receivedPayload get() = _receivedPayload
private var _receivedPayload: Mqtt5Publish? = null

@MqttSubscribe(topic = "suspend", qos = MqttQos.EXACTLY_ONCE)
suspend fun onMessage(payload: Mqtt5Publish) {
_receivedPayload = payload
}
}

@Component
class ErrorSubscriber {

Expand Down
Loading

0 comments on commit a9accc7

Please sign in to comment.