Skip to content

Commit

Permalink
Mqtt5 Websocket (#260)
Browse files Browse the repository at this point in the history
Co-authored-by: Steve Kim <sbstevek@amazon.com>
Co-authored-by: Steve Kim <86316075+sbSteveK@users.noreply.github.com>
  • Loading branch information
3 people authored May 3, 2024
1 parent a4dabb3 commit 0d66f15
Show file tree
Hide file tree
Showing 6 changed files with 378 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ extension CredentialsProvider.Source {
/// - Throws: CommonRuntimeError.crtError
public static func `defaultChain`(bootstrap: ClientBootstrap,
fileBasedConfiguration: FileBasedConfiguration,
tlsContext: TLSContext? = nil,
shutdownCallback: ShutdownCallback? = nil) -> Self {
Self {
let shutdownCallbackCore = ShutdownCallbackCore(shutdownCallback)
Expand All @@ -294,6 +295,7 @@ extension CredentialsProvider.Source {
chainDefaultOptions.bootstrap = bootstrap.rawValue
chainDefaultOptions.profile_collection_cached = fileBasedConfiguration.rawValue
chainDefaultOptions.shutdown_options = shutdownCallbackCore.getRetainedCredentialProviderShutdownOptions()
chainDefaultOptions.tls_ctx = tlsContext?.rawValue

guard let provider = aws_credentials_provider_new_chain_default(allocator.rawValue,
&chainDefaultOptions)
Expand Down
8 changes: 8 additions & 0 deletions Source/AwsCommonRuntimeKit/http/HTTPRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ public class HTTPRequest: HTTPRequestBase {
self.body = body
addHeaders(headers: headers)
}

/// Internal helper init function to acquire a reference of native http request
init (nativeHttpMessage: OpaquePointer) {
super.init(rawValue: nativeHttpMessage)
// Acquire a refcount to keep the message alive until this object dies.
aws_http_message_acquire(self.rawValue)
}

}

/// Represents a single client request to be sent on a HTTP2 connection
Expand Down
3 changes: 2 additions & 1 deletion Source/AwsCommonRuntimeKit/mqtt/Mqtt5Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ public class Mqtt5Client {
onLifecycleEventAttemptingConnect: options.onLifecycleEventAttemptingConnectFn,
onLifecycleEventConnectionSuccess: options.onLifecycleEventConnectionSuccessFn,
onLifecycleEventConnectionFailure: options.onLifecycleEventConnectionFailureFn,
onLifecycleEventDisconnection: options.onLifecycleEventDisconnectionFn)
onLifecycleEventDisconnection: options.onLifecycleEventDisconnectionFn,
onWebsocketInterceptor: options.onWebsocketTransform)

guard let rawValue = (options.withCPointer(
userData: self.callbackCore.callbackUserData()) { optionsPointer in
Expand Down
1 change: 0 additions & 1 deletion Source/AwsCommonRuntimeKit/mqtt/Mqtt5Packets.swift
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,6 @@ public class PublishPacket: CStruct {
count: publishView.user_property_count,
userPropertiesPointer: publishView.user_properties)


let publishPacket = PublishPacket(qos: qos,
topic: publishView.topic.toString(),
payload: payload,
Expand Down
51 changes: 48 additions & 3 deletions Source/AwsCommonRuntimeKit/mqtt/Mqtt5Types.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
/// SPDX-License-Identifier: Apache-2.0.
import Foundation
import AwsCMqtt
import AwsCHttp

// TODO this is temporary. We will replace this with aws-crt-swift error codes.

Check warning on line 7 in Source/AwsCommonRuntimeKit/mqtt/Mqtt5Types.swift

View workflow job for this annotation

GitHub Actions / lint

Todo Violation: TODOs should be resolved (this is temporary. We will rep...) (todo)
enum MqttError: Error {
Expand Down Expand Up @@ -688,6 +689,15 @@ public class LifecycleAttemptingConnectData { }
/// Defines signature of the Lifecycle Event Attempting Connect callback
public typealias OnLifecycleEventAttemptingConnect = (LifecycleAttemptingConnectData) -> Void

/// Callback for users to invoke upon completion of, presumably asynchronous, OnWebSocketHandshakeIntercept callback's initiated process.
public typealias OnWebSocketHandshakeInterceptComplete = (HTTPRequestBase, Int32) -> Void

/// Invoked during websocket handshake to give users opportunity to transform an http request for purposes
/// such as signing/authorization etc... Returning from this function does not continue the websocket
/// handshake since some work flows may be asynchronous. To accommodate that, onComplete must be invoked upon
/// completion of the signing process.
public typealias OnWebSocketHandshakeIntercept = (HTTPRequest, @escaping OnWebSocketHandshakeInterceptComplete) -> Void

/// Class containing results of a Connect Success Lifecycle Event.
public class LifecycleConnectionSuccessData {

Expand Down Expand Up @@ -1086,6 +1096,32 @@ private func MqttClientPublishRecievedEvents(
}
}

private func MqttClientWebsocketTransform(
_ rawHttpMessage: OpaquePointer?,
_ userData: UnsafeMutableRawPointer?,
_ completeFn: (@convention(c) (OpaquePointer?, Int32, UnsafeMutableRawPointer?) -> Void)?,
_ completeCtx: UnsafeMutableRawPointer?) {

let callbackCore = Unmanaged<MqttCallbackCore>.fromOpaque(userData!).takeUnretainedValue()

// validate the callback flag, if flag is false, return
callbackCore.rwlock.read {
if callbackCore.callbackFlag == false { return }

guard let rawHttpMessage else {
fatalError("Null HttpRequeset in websocket transform function.")
}
let httpRequest = HTTPRequest(nativeHttpMessage: rawHttpMessage)
@Sendable func signerTransform(request: HTTPRequestBase, errorCode: Int32) {
completeFn?(request.rawValue, errorCode, completeCtx)
}

if callbackCore.onWebsocketInterceptor != nil {
callbackCore.onWebsocketInterceptor!(httpRequest, signerTransform)
}
}
}

private func MqttClientTerminationCallback(_ userData: UnsafeMutableRawPointer?) {
// termination callback
print("[Mqtt5 Client Swift] TERMINATION CALLBACK")
Expand Down Expand Up @@ -1113,9 +1149,8 @@ public class MqttClientOptions: CStructWithUserData {
/// The (tunneling) HTTP proxy usage when establishing MQTT connections
public let httpProxyOptions: HTTPProxyOptions?

// TODO WebSocket implementation
/// This callback allows a custom transformation of the HTTP request that acts as the websocket handshake. Websockets will be used if this is set to a valid transformation callback. To use websockets but not perform a transformation, just set this as a trivial completion callback. If None, the connection will be made with direct MQTT.
/// public let websocketHandshakeTransform: Callable[[WebsocketHandshakeTransformArgs], None] = None
public let onWebsocketTransform: OnWebSocketHandshakeIntercept?

/// All configurable options with respect to the CONNECT packet sent by the client, including the will. These connect properties will be used for every connection attempt made by the client.
public let connectOptions: MqttConnectOptions?
Expand Down Expand Up @@ -1177,6 +1212,7 @@ public class MqttClientOptions: CStructWithUserData {
bootstrap: ClientBootstrap? = nil,
socketOptions: SocketOptions? = nil,
tlsCtx: TLSContext? = nil,
onWebsocketTransform: OnWebSocketHandshakeIntercept? = nil,
httpProxyOptions: HTTPProxyOptions? = nil,
connectOptions: MqttConnectOptions? = nil,
sessionBehavior: ClientSessionBehaviorType? = nil,
Expand Down Expand Up @@ -1216,6 +1252,7 @@ public class MqttClientOptions: CStructWithUserData {

self.socketOptions = socketOptions ?? SocketOptions()
self.tlsCtx = tlsCtx
self.onWebsocketTransform = onWebsocketTransform
self.httpProxyOptions = httpProxyOptions
self.connectOptions = connectOptions
self.sessionBehavior = sessionBehavior
Expand Down Expand Up @@ -1361,7 +1398,11 @@ public class MqttClientOptions: CStructWithUserData {
}
}

// TODO: SETUP lifecycle_event_handler and publish_received_handler
if self.onWebsocketTransform != nil {
raw_options.websocket_handshake_transform = MqttClientWebsocketTransform
raw_options.websocket_handshake_transform_user_data = _userData
}

raw_options.lifecycle_event_handler = MqttClientLifeycyleEvents
raw_options.lifecycle_event_handler_user_data = _userData
raw_options.publish_received_handler = MqttClientPublishRecievedEvents
Expand All @@ -1385,6 +1426,8 @@ class MqttCallbackCore {
let onLifecycleEventConnectionSuccess: OnLifecycleEventConnectionSuccess
let onLifecycleEventConnectionFailure: OnLifecycleEventConnectionFailure
let onLifecycleEventDisconnection: OnLifecycleEventDisconnection
// The websocket interceptor could be nil if the websocket is not in use
let onWebsocketInterceptor: OnWebSocketHandshakeIntercept?

let rwlock = ReadWriteLock()
var callbackFlag = true
Expand All @@ -1395,6 +1438,7 @@ class MqttCallbackCore {
onLifecycleEventConnectionSuccess: OnLifecycleEventConnectionSuccess? = nil,
onLifecycleEventConnectionFailure: OnLifecycleEventConnectionFailure? = nil,
onLifecycleEventDisconnection: OnLifecycleEventDisconnection? = nil,
onWebsocketInterceptor: OnWebSocketHandshakeIntercept? = nil,
data: AnyObject? = nil) {

self.onPublishReceivedCallback = onPublishReceivedCallback ?? { (_) in return }
Expand All @@ -1403,6 +1447,7 @@ class MqttCallbackCore {
self.onLifecycleEventConnectionSuccess = onLifecycleEventConnectionSuccess ?? { (_) in return}
self.onLifecycleEventConnectionFailure = onLifecycleEventConnectionFailure ?? { (_) in return}
self.onLifecycleEventDisconnection = onLifecycleEventDisconnection ?? { (_) in return}
self.onWebsocketInterceptor = onWebsocketInterceptor
}

/// Calling this function performs a manual retain on the MqttShutdownCallbackCore.
Expand Down
Loading

0 comments on commit 0d66f15

Please sign in to comment.