diff --git a/Sources/LiveKit/DataStream/Outgoing/OutgoingStreamManager.swift b/Sources/LiveKit/DataStream/Outgoing/OutgoingStreamManager.swift index c0a80cec6..4798c08cc 100644 --- a/Sources/LiveKit/DataStream/Outgoing/OutgoingStreamManager.swift +++ b/Sources/LiveKit/DataStream/Outgoing/OutgoingStreamManager.swift @@ -271,6 +271,7 @@ extension Livekit_DataStream.Header { $0.totalLength = UInt64(totalLength) } $0.attributes = streamInfo.attributes + $0.encryptionType = streamInfo.encryptionType.toPBType() $0.contentHeader = Livekit_DataStream.Header.OneOf_ContentHeader(streamInfo) } } diff --git a/Sources/LiveKit/Types/ParticipantPermissions.swift b/Sources/LiveKit/Types/ParticipantPermissions.swift index 29d2e04ba..bc4c1d8ae 100644 --- a/Sources/LiveKit/Types/ParticipantPermissions.swift +++ b/Sources/LiveKit/Types/ParticipantPermissions.swift @@ -42,12 +42,22 @@ public class ParticipantPermissions: NSObject, @unchecked Sendable { @objc public let recorder: Bool + /// Indicates participant can update own metadata and attributes + @objc + public let canUpdateMetadata: Bool + + /// Indicates participant can subscribe to metrics + @objc + public let canSubscribeMetrics: Bool + init(canSubscribe: Bool = false, canPublish: Bool = false, canPublishData: Bool = false, canPublishSources: Set = [], hidden: Bool = false, - recorder: Bool = false) + recorder: Bool = false, + canUpdateMetadata: Bool = false, + canSubscribeMetrics: Bool = false) { self.canSubscribe = canSubscribe self.canPublish = canPublish @@ -55,6 +65,8 @@ public class ParticipantPermissions: NSObject, @unchecked Sendable { self.canPublishSources = Set(canPublishSources.map(\.rawValue)) self.hidden = hidden self.recorder = recorder + self.canUpdateMetadata = canUpdateMetadata + self.canSubscribeMetrics = canSubscribeMetrics } // MARK: - Equal @@ -66,7 +78,9 @@ public class ParticipantPermissions: NSObject, @unchecked Sendable { canPublishData == other.canPublishData && canPublishSources == other.canPublishSources && hidden == other.hidden && - recorder == other.recorder + recorder == other.recorder && + canUpdateMetadata == other.canUpdateMetadata && + canSubscribeMetrics == other.canSubscribeMetrics } override public var hash: Int { @@ -77,6 +91,8 @@ public class ParticipantPermissions: NSObject, @unchecked Sendable { hasher.combine(canPublishSources) hasher.combine(hidden) hasher.combine(recorder) + hasher.combine(canUpdateMetadata) + hasher.combine(canSubscribeMetrics) return hasher.finalize() } } @@ -88,6 +104,8 @@ extension Livekit_ParticipantPermission { canPublishData: canPublishData, canPublishSources: Set(canPublishSources.map { $0.toLKType() }), hidden: hidden, - recorder: recorder) + recorder: recorder, + canUpdateMetadata: canUpdateMetadata, + canSubscribeMetrics: canSubscribeMetrics) } } diff --git a/Tests/LiveKitCoreTests/DataStream/ByteStreamInfoTests.swift b/Tests/LiveKitCoreTests/DataStream/ByteStreamInfoTests.swift index 4450f0d6b..995d7251c 100644 --- a/Tests/LiveKitCoreTests/DataStream/ByteStreamInfoTests.swift +++ b/Tests/LiveKitCoreTests/DataStream/ByteStreamInfoTests.swift @@ -25,7 +25,7 @@ class ByteStreamInfoTests: LKTestCase { timestamp: Date(timeIntervalSince1970: 100), totalLength: 128, attributes: ["key": "value"], - encryptionType: .none, + encryptionType: .gcm, mimeType: "image/jpeg", name: "filename.bin" ) @@ -36,15 +36,17 @@ class ByteStreamInfoTests: LKTestCase { XCTAssertEqual(header.timestamp, Int64(info.timestamp.timeIntervalSince1970 * TimeInterval(1000))) XCTAssertEqual(header.totalLength, UInt64(info.totalLength ?? -1)) XCTAssertEqual(header.attributes, info.attributes) + XCTAssertEqual(header.encryptionType.rawValue, info.encryptionType.rawValue) XCTAssertEqual(header.byteHeader.name, info.name) - let newInfo = ByteStreamInfo(header, header.byteHeader, .none) + let newInfo = ByteStreamInfo(header, header.byteHeader, .gcm) XCTAssertEqual(newInfo.id, info.id) XCTAssertEqual(newInfo.mimeType, info.mimeType) XCTAssertEqual(newInfo.topic, info.topic) XCTAssertEqual(newInfo.timestamp, info.timestamp) XCTAssertEqual(newInfo.totalLength, info.totalLength) XCTAssertEqual(newInfo.attributes, info.attributes) + XCTAssertEqual(newInfo.encryptionType, info.encryptionType) XCTAssertEqual(newInfo.name, info.name) } } diff --git a/Tests/LiveKitCoreTests/DataStream/TextStreamInfoTests.swift b/Tests/LiveKitCoreTests/DataStream/TextStreamInfoTests.swift index d0a09b292..466b40437 100644 --- a/Tests/LiveKitCoreTests/DataStream/TextStreamInfoTests.swift +++ b/Tests/LiveKitCoreTests/DataStream/TextStreamInfoTests.swift @@ -25,7 +25,7 @@ class TextStreamInfoTests: LKTestCase { timestamp: Date(timeIntervalSince1970: 100), totalLength: 128, attributes: ["key": "value"], - encryptionType: .none, + encryptionType: .gcm, operationType: .reaction, version: 10, replyToStreamID: "replyID", @@ -38,18 +38,20 @@ class TextStreamInfoTests: LKTestCase { XCTAssertEqual(header.timestamp, Int64(info.timestamp.timeIntervalSince1970 * TimeInterval(1000))) XCTAssertEqual(header.totalLength, UInt64(info.totalLength ?? -1)) XCTAssertEqual(header.attributes, info.attributes) + XCTAssertEqual(header.encryptionType.rawValue, info.encryptionType.rawValue) XCTAssertEqual(header.textHeader.operationType.rawValue, info.operationType.rawValue) XCTAssertEqual(header.textHeader.version, Int32(info.version)) XCTAssertEqual(header.textHeader.replyToStreamID, info.replyToStreamID) XCTAssertEqual(header.textHeader.attachedStreamIds, info.attachedStreamIDs) XCTAssertEqual(header.textHeader.generated, info.generated) - let newInfo = TextStreamInfo(header, header.textHeader, .none) + let newInfo = TextStreamInfo(header, header.textHeader, .gcm) XCTAssertEqual(newInfo.id, info.id) XCTAssertEqual(newInfo.topic, info.topic) XCTAssertEqual(newInfo.timestamp, info.timestamp) XCTAssertEqual(newInfo.totalLength, info.totalLength) XCTAssertEqual(newInfo.attributes, info.attributes) + XCTAssertEqual(newInfo.encryptionType, info.encryptionType) XCTAssertEqual(newInfo.operationType, info.operationType) XCTAssertEqual(newInfo.version, info.version) XCTAssertEqual(newInfo.replyToStreamID, info.replyToStreamID) diff --git a/Tests/LiveKitCoreTests/Proto/ProtoConverterTests.swift b/Tests/LiveKitCoreTests/Proto/ProtoConverterTests.swift new file mode 100644 index 000000000..6d88c1365 --- /dev/null +++ b/Tests/LiveKitCoreTests/Proto/ProtoConverterTests.swift @@ -0,0 +1,151 @@ +/* + * Copyright 2025 LiveKit + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@testable import LiveKit +#if canImport(LiveKitTestSupport) +import LiveKitTestSupport +#endif + +class ProtoConverterTests: LKTestCase { + func testParticipantPermissions() { + let errors = Comparator.compareStructures( + proto: Livekit_ParticipantPermission(), + sdk: ParticipantPermissions(), + excludedFields: ["agent"], // deprecated + allowedTypeMismatches: ["canPublishSources"] // Array vs Set + ) + + XCTAssert(errors.isEmpty, errors.description) + } +} + +enum Comparator { + enum ComparisonError: Error, CustomStringConvertible { + case missingField(String) + case extraField(String) + case typeMismatch(field: String, proto: String, sdk: String) + + var description: String { + switch self { + case let .missingField(field): + "Missing field: '\(field)'" + case let .extraField(field): + "Extra field: '\(field)'" + case let .typeMismatch(field, proto, sdk): + "Type mismatch for '\(field)': proto has \(proto), sdk has \(sdk)" + } + } + } + + struct FieldInfo { + let name: String + let type: String + let nonOptionalType: String + } + + static func extractFields(from instance: some Any, excludedFields: Set = []) -> [FieldInfo] { + let mirror = Mirror(reflecting: instance) + var fields: [FieldInfo] = [] + var backingFields: Set = [] + + // Collect all backing fields + for child in mirror.children { + guard let label = child.label, label.hasPrefix("_") else { continue } + backingFields.insert(String(label.dropFirst())) // Remove the underscore + } + + for child in mirror.children { + guard let label = child.label else { continue } + + // Skip excluded/unknown fields + if excludedFields.contains(label) || label == "unknownFields" { + continue + } + + // Skip private backing fields (they have public computed properties) + if label.hasPrefix("_"), backingFields.contains(String(label.dropFirst())) { + // But add the public version instead + let publicName = String(label.dropFirst()) + let typeString = String(describing: type(of: child.value)) + let nonOptional = extractNonOptionalType(from: typeString) + + if !fields.contains(where: { $0.name == publicName }) { + fields.append(FieldInfo(name: publicName, type: typeString, nonOptionalType: nonOptional)) + } + continue + } + + // Skip other private fields + if label.hasPrefix("_") { + continue + } + + let typeString = String(describing: type(of: child.value)) + let nonOptional = extractNonOptionalType(from: typeString) + + fields.append(FieldInfo(name: label, type: typeString, nonOptionalType: nonOptional)) + } + + return fields.sorted { $0.name < $1.name } + } + + static func extractNonOptionalType(from typeString: String) -> String { + if typeString.hasPrefix("Optional<"), typeString.hasSuffix(">") { + let start = typeString.index(typeString.startIndex, offsetBy: 9) + let end = typeString.index(before: typeString.endIndex) + return String(typeString[start ..< end]) + } + return typeString + } + + static func compareStructures( + proto: some Any, + sdk: some Any, + excludedFields: Set = [], + allowedTypeMismatches: Set = [] + ) -> [ComparisonError] { + let protoFields = extractFields(from: proto, excludedFields: excludedFields) + let sdkFields = extractFields(from: sdk, excludedFields: excludedFields) + + var errors: [ComparisonError] = [] + + let protoFieldMap = Dictionary(uniqueKeysWithValues: protoFields.map { ($0.name, $0) }) + let sdkFieldMap = Dictionary(uniqueKeysWithValues: sdkFields.map { ($0.name, $0) }) + + for protoField in protoFields { + guard let sdkField = sdkFieldMap[protoField.name] else { + errors.append(.missingField(protoField.name)) + continue + } + + if protoField.nonOptionalType != sdkField.nonOptionalType, !allowedTypeMismatches.contains(protoField.name) { + errors.append(.typeMismatch( + field: protoField.name, + proto: protoField.type, + sdk: sdkField.type + )) + } + } + + for sdkField in sdkFields { + if protoFieldMap[sdkField.name] == nil { + errors.append(.extraField(sdkField.name)) + } + } + + return errors + } +}