Skip to content

Commit

Permalink
Update Errors in Mqtt Client (#268)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiazhvera authored Jul 18, 2024
1 parent e0de3a7 commit 3ae634e
Show file tree
Hide file tree
Showing 11 changed files with 150 additions and 108 deletions.
6 changes: 5 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,9 @@ packageTargets.append(contentsOf: [
exclude: awsCMqttExcludes,
cSettings: cSettings
),
.systemLibrary(
name: "LibNative"
),
.target(
name: "AwsCommonRuntimeKit",
dependencies: [ "AwsCAuth",
Expand All @@ -286,7 +289,8 @@ packageTargets.append(contentsOf: [
"AwsCCommon",
"AwsCChecksums",
"AwsCEventStream",
"AwsCMqtt"],
"AwsCMqtt",
"LibNative"],
path: "Source/AwsCommonRuntimeKit",
resources: [
.copy("PrivacyInfo.xcprivacy")
Expand Down
3 changes: 3 additions & 0 deletions Source/AwsCommonRuntimeKit/CommonRuntimeKit.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import AwsCEventStream
import AwsCAuth
import AwsCMqtt
import LibNative

/**
* Initializes the library.
Expand All @@ -14,6 +15,7 @@ public struct CommonRuntimeKit {
aws_auth_library_init(allocator.rawValue)
aws_event_stream_library_init(allocator.rawValue)
aws_mqtt_library_init(allocator.rawValue)
aws_register_error_info(&s_crt_swift_error_list)
}

/**
Expand All @@ -22,6 +24,7 @@ public struct CommonRuntimeKit {
* Warning: It will hang if you are still holding references to any CRT objects such as HostResolver.
*/
public static func cleanUp() {
aws_unregister_error_info(&s_crt_swift_error_list)
aws_mqtt_library_clean_up()
aws_event_stream_library_clean_up()
aws_auth_library_clean_up()
Expand Down
8 changes: 6 additions & 2 deletions Source/AwsCommonRuntimeKit/crt/CommonRuntimeError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@ public struct CRTError: Equatable {
public let message: String
public let name: String

public init<T: BinaryInteger>(code: T) {
public init<T: BinaryInteger>(code: T, context: String? = nil) {
if code > INT32_MAX || code <= 0 {
self.code = Int32(AWS_ERROR_UNKNOWN.rawValue)
} else {
self.code = Int32(code)
}
self.message = String(cString: aws_error_str(self.code))
var message = String(cString: aws_error_str(self.code))
if let context {
message += ": " + context
}
self.message = message
self.name = String(cString: aws_error_name(self.code))
}

Expand Down
33 changes: 1 addition & 32 deletions Source/AwsCommonRuntimeKit/crt/Utilities.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import struct Foundation.Date
import struct Foundation.Data
import struct Foundation.TimeInterval
import AwsCCal
import LibNative

/// This class is used to add reference counting to stuff that do not support it
/// like Structs, Closures, and Protocols etc by wrapping it in a Class.
Expand Down Expand Up @@ -223,38 +224,6 @@ extension aws_array_list {
}
}

/// Convert a native aws_byte_cursor pointer into a String?
func convertAwsByteCursorToOptionalString(_ awsByteCursor: UnsafePointer<aws_byte_cursor>?) -> String? {
guard let cursor = awsByteCursor?.pointee else {
return nil
}
return cursor.toString()
}

/// Convert a native uint16_t pointer into a Swift UInt16?
func convertOptionalUInt16(_ pointer: UnsafePointer<UInt16>?) -> UInt16? {
guard let validPointer = pointer else {
return nil
}
return validPointer.pointee
}

/// Convert a native uint32_t pointer into a Swift UInt32?
func convertOptionalUInt32(_ pointer: UnsafePointer<UInt32>?) -> UInt32? {
guard let validPointer = pointer else {
return nil
}
return validPointer.pointee
}

/// Convert a native bool pointer to an optional Swift Bool
func convertOptionalBool(_ pointer: UnsafePointer<Bool>?) -> Bool? {
guard let validPointer = pointer else {
return nil
}
return validPointer.pointee
}

extension Bool {
var uintValue: UInt32 {
return self ? 1 : 0
Expand Down
17 changes: 10 additions & 7 deletions Source/AwsCommonRuntimeKit/mqtt/Mqtt5Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import AwsCMqtt
import AwsCIo
import LibNative

// MARK: - Callback Data Classes

Expand Down Expand Up @@ -263,13 +264,12 @@ public class Mqtt5ClientCore {
try self.rwlock.read {
// Validate close() has not been called on client.
guard let rawValue = self.rawValue else {
// TODO add new error type for client closed
throw CommonRunTimeError.crtError(CRTError.makeFromLastError())
throw CommonRunTimeError.crtError(CRTError(code: AWS_CRT_SWIFT_MQTT_CLIENT_CLOSED.rawValue))
}
let errorCode = aws_mqtt5_client_start(rawValue)

if errorCode != AWS_OP_SUCCESS {
throw CommonRunTimeError.crtError(CRTError(code: errorCode))
throw CommonRunTimeError.crtError(CRTError.makeFromLastError())
}
}
}
Expand All @@ -286,7 +286,7 @@ public class Mqtt5ClientCore {
try self.rwlock.read {
// Validate close() has not been called on client.
guard let rawValue = self.rawValue else {
throw CommonRunTimeError.crtError(CRTError.makeFromLastError())
throw CommonRunTimeError.crtError(CRTError(code: AWS_CRT_SWIFT_MQTT_CLIENT_CLOSED.rawValue))
}

var errorCode: Int32 = 0
Expand Down Expand Up @@ -327,7 +327,8 @@ public class Mqtt5ClientCore {
// Validate close() has not been called on client.
guard let rawValue = self.rawValue else {
continuationCore.release()
return continuation.resume(throwing: CommonRunTimeError.crtError(CRTError.makeFromLastError()))
return continuation.resume(throwing: CommonRunTimeError.crtError(
CRTError(code: AWS_CRT_SWIFT_MQTT_CLIENT_CLOSED.rawValue)))
}
let result = aws_mqtt5_client_subscribe(rawValue, subscribePacketPointer, &callbackOptions)
guard result == AWS_OP_SUCCESS else {
Expand Down Expand Up @@ -364,7 +365,8 @@ public class Mqtt5ClientCore {
// Validate close() has not been called on client.
guard let rawValue = self.rawValue else {
continuationCore.release()
return continuation.resume(throwing: CommonRunTimeError.crtError(CRTError.makeFromLastError()))
return continuation.resume(throwing: CommonRunTimeError.crtError(
CRTError(code: AWS_ERROR_INVALID_ARGUMENT.rawValue, context: "Mqtt client is closed.")))
}

let result = aws_mqtt5_client_publish(rawValue, publishPacketPointer, &callbackOptions)
Expand Down Expand Up @@ -398,7 +400,8 @@ public class Mqtt5ClientCore {
// Validate close() has not been called on client.
guard let rawValue = self.rawValue else {
continuationCore.release()
return continuation.resume(throwing: CommonRunTimeError.crtError(CRTError.makeFromLastError()))
return continuation.resume(throwing: CommonRunTimeError.crtError(
CRTError(code: AWS_ERROR_INVALID_ARGUMENT.rawValue, context: "Mqtt client is closed.")))
}
let result = aws_mqtt5_client_unsubscribe(rawValue, unsubscribePacketPointer, &callbackOptions)
guard result == AWS_OP_SUCCESS else {
Expand Down
5 changes: 0 additions & 5 deletions Source/AwsCommonRuntimeKit/mqtt/Mqtt5Enums.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@

import AwsCMqtt

// TODO this is temporary. We will replace this with aws-crt-swift error codes.
enum MqttError: Error {
case validation(message: String)
}

/// MQTT message delivery quality of service.
/// Enum values match `MQTT5 spec <https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901234>`__ encoding values.
public enum QoS {
Expand Down
27 changes: 18 additions & 9 deletions Source/AwsCommonRuntimeKit/mqtt/Mqtt5Options.swift
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,23 @@ public class MqttConnectOptions: CStruct {
func validateConversionToNative() throws {
if let keepAliveInterval {
if keepAliveInterval < 0 || keepAliveInterval > Double(UInt16.max) {
throw MqttError.validation(message: "Invalid keepAliveInterval value")
throw CommonRunTimeError.crtError(CRTError(code: AWS_ERROR_INVALID_ARGUMENT.rawValue,
context: "Invalid keepAliveInterval value"))
}
}

do {
_ = try sessionExpiryInterval?.secondUInt32()
} catch {
throw MqttError.validation(message: "Invalid sessionExpiryInterval value")
throw CommonRunTimeError.crtError(CRTError(code: AWS_ERROR_INVALID_ARGUMENT.rawValue,
context: "Invalid sessionExpiryInterval value"))
}

do {
_ = try willDelayInterval?.secondUInt32()
} catch {
throw MqttError.validation(message: "Invalid willDelayInterval value")
throw CommonRunTimeError.crtError(CRTError(code: AWS_ERROR_INVALID_ARGUMENT.rawValue,
context: "Invalid willDelayInterval value"))
}
}

Expand Down Expand Up @@ -339,36 +342,42 @@ public class MqttClientOptions: CStructWithUserData {
do {
_ = try minReconnectDelay?.millisecondUInt64()
} catch {
throw MqttError.validation(message: "Invalid minReconnectDelay value")
throw CommonRunTimeError.crtError(CRTError(code: AWS_ERROR_INVALID_ARGUMENT.rawValue,
context: "Invalid minReconnectDelay value"))
}

do {
_ = try maxReconnectDelay?.millisecondUInt64()
} catch {
throw MqttError.validation(message: "Invalid maxReconnectDelay value")
throw CommonRunTimeError.crtError(CRTError(code: AWS_ERROR_INVALID_ARGUMENT.rawValue,
context: "Invalid maxReconnectDelay value"))
}

do {
_ = try minConnectedTimeToResetReconnectDelay?.millisecondUInt64()
} catch {
throw MqttError.validation(message: "Invalid minConnectedTimeToResetReconnectDelay value")
throw CommonRunTimeError.crtError(CRTError(code: AWS_ERROR_INVALID_ARGUMENT.rawValue,
context: "Invalid minConnectedTimeToResetReconnectDelay value"))
}

do {
_ = try pingTimeout?.millisecondUInt32()
} catch {
throw MqttError.validation(message: "Invalid pingTimeout value")
throw CommonRunTimeError.crtError(CRTError(code: AWS_ERROR_INVALID_ARGUMENT.rawValue,
context: "Invalid pingTimeout value"))
}

do {
_ = try connackTimeout?.millisecondUInt32()
} catch {
throw MqttError.validation(message: "Invalid connackTimeout value")
throw CommonRunTimeError.crtError(CRTError(code: AWS_ERROR_INVALID_ARGUMENT.rawValue,
context: "Invalid connackTimeout value"))
}

if let ackTimeout {
if ackTimeout < 0 || ackTimeout > Double(UInt32.max) {
throw MqttError.validation(message: "Invalid ackTimeout value")
throw CommonRunTimeError.crtError(CRTError(code: AWS_ERROR_INVALID_ARGUMENT.rawValue,
context: "Invalid ackTimeout value"))
}
}
}
Expand Down
44 changes: 23 additions & 21 deletions Source/AwsCommonRuntimeKit/mqtt/Mqtt5Packets.swift
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,9 @@ public class PublishPacket: CStruct {
self.retain = publishView.retain
self.payloadFormatIndicator = publishView.payload_format != nil ?
PayloadFormatIndicator(publishView.payload_format.pointee) : nil
self.messageExpiryInterval = convertOptionalUInt32(publishView.message_expiry_interval_seconds).map { TimeInterval($0) }
self.topicAlias = convertOptionalUInt16(publishView.topic_alias)
self.responseTopic = convertAwsByteCursorToOptionalString(publishView.response_topic)
self.messageExpiryInterval = (publishView.message_expiry_interval_seconds?.pointee).map { TimeInterval($0) }
self.topicAlias = publishView.topic_alias?.pointee
self.responseTopic = publishView.response_topic?.pointee.toString()
self.correlationData = publishView.correlation_data != nil ?
Data(bytes: publishView.correlation_data!.pointee.ptr, count: publishView.correlation_data!.pointee.len) : nil
var identifier: [UInt32]? = []
Expand All @@ -187,7 +187,7 @@ public class PublishPacket: CStruct {
identifier?.append(subscription_identifier)
}
self.subscriptionIdentifiers = identifier
self.contentType = convertAwsByteCursorToOptionalString(publishView.content_type)
self.contentType = publishView.content_type?.pointee.toString()
self.userProperties = convertOptionalUserProperties(
count: publishView.user_property_count,
userPropertiesPointer: publishView.user_properties)
Expand All @@ -204,7 +204,8 @@ public class PublishPacket: CStruct {
func validateConversionToNative() throws {
if let messageExpiryInterval {
if messageExpiryInterval < 0 || messageExpiryInterval > Double(UInt32.max) {
throw MqttError.validation(message: "Invalid sessionExpiryInterval value")
throw CommonRunTimeError.crtError(CRTError(code: AWS_ERROR_INVALID_ARGUMENT.rawValue,
context: "Invalid sessionExpiryInterval value"))
}
}
}
Expand Down Expand Up @@ -627,9 +628,9 @@ public class DisconnectPacket: CStruct {
let disconnectView = disconnect_view.pointee

self.reasonCode = DisconnectReasonCode(rawValue: Int(disconnectView.reason_code.rawValue))!
self.sessionExpiryInterval = convertOptionalUInt32(disconnectView.session_expiry_interval_seconds).map { TimeInterval($0) }
self.reasonString = convertAwsByteCursorToOptionalString(disconnectView.reason_string)
self.serverReference = convertAwsByteCursorToOptionalString(disconnectView.reason_string)
self.sessionExpiryInterval = (disconnectView.session_expiry_interval_seconds?.pointee).map { TimeInterval($0) }
self.reasonString = disconnectView.reason_string?.pointee.toString()
self.serverReference = disconnectView.reason_string?.pointee.toString()
self.userProperties = convertOptionalUserProperties(
count: disconnectView.user_property_count,
userPropertiesPointer: disconnectView.user_properties)
Expand All @@ -638,7 +639,8 @@ public class DisconnectPacket: CStruct {
func validateConversionToNative() throws {
if let sessionExpiryInterval {
if sessionExpiryInterval < 0 || sessionExpiryInterval > Double(UInt32.max) {
throw MqttError.validation(message: "Invalid sessionExpiryInterval value")
throw CommonRunTimeError.crtError(CRTError(code: AWS_ERROR_INVALID_ARGUMENT.rawValue,
context: "Invalid sessionExpiryInterval value"))
}
}
}
Expand Down Expand Up @@ -774,23 +776,23 @@ public class ConnackPacket {
self.sessionPresent = connackView.session_present
self.reasonCode = ConnectReasonCode(rawValue: Int(connackView.reason_code.rawValue))!
self.sessionExpiryInterval = (connackView.session_expiry_interval?.pointee).map { TimeInterval($0) }
self.receiveMaximum = convertOptionalUInt16(connackView.receive_maximum)
self.receiveMaximum = connackView.receive_maximum?.pointee
if let maximumQosValue = connackView.maximum_qos {
self.maximumQos = QoS(maximumQosValue.pointee)
} else {
self.maximumQos = nil
}
self.retainAvailable = convertOptionalBool(connackView.retain_available)
self.maximumPacketSize = convertOptionalUInt32(connackView.maximum_packet_size)
self.assignedClientIdentifier = convertAwsByteCursorToOptionalString(connackView.assigned_client_identifier)
self.topicAliasMaximum = convertOptionalUInt16(connackView.topic_alias_maximum)
self.reasonString = convertAwsByteCursorToOptionalString(connackView.reason_string)
self.wildcardSubscriptionsAvailable = convertOptionalBool(connackView.wildcard_subscriptions_available)
self.subscriptionIdentifiersAvailable = convertOptionalBool(connackView.subscription_identifiers_available)
self.sharedSubscriptionAvailable = convertOptionalBool(connackView.shared_subscriptions_available)
self.serverKeepAlive = convertOptionalUInt16(connackView.server_keep_alive).map { TimeInterval($0) }
self.responseInformation = convertAwsByteCursorToOptionalString(connackView.response_information)
self.serverReference = convertAwsByteCursorToOptionalString(connackView.server_reference)
self.retainAvailable = connackView.retain_available?.pointee
self.maximumPacketSize = connackView.maximum_packet_size?.pointee
self.assignedClientIdentifier = connackView.assigned_client_identifier?.pointee.toString()
self.topicAliasMaximum = connackView.topic_alias_maximum?.pointee
self.reasonString = connackView.reason_string?.pointee.toString()
self.wildcardSubscriptionsAvailable = connackView.wildcard_subscriptions_available?.pointee
self.subscriptionIdentifiersAvailable = connackView.subscription_identifiers_available?.pointee
self.sharedSubscriptionAvailable = connackView.shared_subscriptions_available?.pointee
self.serverKeepAlive = (connackView.server_keep_alive?.pointee).map { TimeInterval($0) }
self.responseInformation = connackView.response_information?.pointee.toString()
self.serverReference = connackView.server_reference?.pointee.toString()
self.userProperties = convertOptionalUserProperties(
count: connackView.user_property_count,
userPropertiesPointer: connackView.user_properties)
Expand Down
37 changes: 37 additions & 0 deletions Source/LibNative/CommonRuntimeError.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0.

#ifndef SWIFT_COMMON_RUNTIME_ERROR_H
#define SWIFT_COMMON_RUNTIME_ERROR_H
#include <aws/common/common.h>
#include <aws/common/error.h>

/**
* The file introduced the swift error spaces, defines the error code used for aws-crt-swift.
* We defined the error codes here because Swift error handling requires the use of enums, and Swift
* does not support extensible enums, which makes future extensions challenging. Therefore, we chose
* to add a C error space to ensure future-proofing.
*/

#define AWS_CRT_SWIFT_PACKAGE_ID 17

#define AWS_DEFINE_ERROR_INFO_CRT_SWIFT(CODE, STR) [(CODE)-0x4400] = AWS_DEFINE_ERROR_INFO(CODE, STR, "aws-crt-swift")

enum aws_swift_errors {
AWS_CRT_SWIFT_MQTT_CLIENT_CLOSED = AWS_ERROR_ENUM_BEGIN_RANGE(AWS_CRT_SWIFT_PACKAGE_ID),
AWS_CRT_SWIFT_ERROR_END_RANGE = AWS_ERROR_ENUM_END_RANGE(AWS_CRT_SWIFT_PACKAGE_ID),
};


static struct aws_error_info s_crt_swift_errors[] = {
AWS_DEFINE_ERROR_INFO_CRT_SWIFT(
AWS_CRT_SWIFT_MQTT_CLIENT_CLOSED,
"The Mqtt Client is closed.")
};

static struct aws_error_info_list s_crt_swift_error_list = {
.error_list = s_crt_swift_errors,
.count = AWS_ARRAY_SIZE(s_crt_swift_errors),
};

#endif /* SWIFT_COMMON_RUNTIME_ERROR_H */
6 changes: 6 additions & 0 deletions Source/LibNative/module.modulemap
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

module LibNative {
header "CommonRuntimeError.h"
export *
}

Loading

0 comments on commit 3ae634e

Please sign in to comment.