diff --git a/Sources/SwiftProtobuf/BinaryEncodingSizeVisitor.swift b/Sources/SwiftProtobuf/BinaryEncodingSizeVisitor.swift index 685cd22a2..369c33d6f 100644 --- a/Sources/SwiftProtobuf/BinaryEncodingSizeVisitor.swift +++ b/Sources/SwiftProtobuf/BinaryEncodingSizeVisitor.swift @@ -19,11 +19,15 @@ import Foundation /// properly sized `Data` or `UInt8` array can be pre-allocated before /// serialization. internal struct BinaryEncodingSizeVisitor: Visitor { + + var sizeCache: BinarySizeCache /// Accumulates the required size of the message during traversal. var serializedSize: Int = 0 - init() {} + init(sizeCache: BinarySizeCache) { + self.sizeCache = sizeCache + } mutating func visitUnknown(bytes: Data) throws { serializedSize += bytes.count @@ -246,7 +250,7 @@ internal struct BinaryEncodingSizeVisitor: Visitor { fieldNumber: Int) throws { let tagSize = FieldTag(fieldNumber: fieldNumber, wireFormat: .lengthDelimited).encodedSize - let messageSize = try value.serializedDataSize() + let messageSize = try sizeCache.getSerializedDataSize(value) serializedSize += tagSize + Varint.encodedSize(of: UInt64(messageSize)) + messageSize } @@ -257,7 +261,7 @@ internal struct BinaryEncodingSizeVisitor: Visitor { wireFormat: .lengthDelimited).encodedSize serializedSize += value.count * tagSize for v in value { - let messageSize = try v.serializedDataSize() + let messageSize = try sizeCache.getSerializedDataSize(v) serializedSize += Varint.encodedSize(of: UInt64(messageSize)) + messageSize } @@ -290,7 +294,7 @@ internal struct BinaryEncodingSizeVisitor: Visitor { let tagSize = FieldTag(fieldNumber: fieldNumber, wireFormat: .lengthDelimited).encodedSize for (k,v) in value { - var sizer = BinaryEncodingSizeVisitor() + var sizer = BinaryEncodingSizeVisitor(sizeCache: sizeCache) try KeyType.visitSingular(value: k, fieldNumber: 1, with: &sizer) try ValueType.visitSingular(value: v, fieldNumber: 2, with: &sizer) let entrySize = sizer.serializedSize @@ -307,7 +311,7 @@ internal struct BinaryEncodingSizeVisitor: Visitor { let tagSize = FieldTag(fieldNumber: fieldNumber, wireFormat: .lengthDelimited).encodedSize for (k,v) in value { - var sizer = BinaryEncodingSizeVisitor() + var sizer = BinaryEncodingSizeVisitor(sizeCache: sizeCache) try KeyType.visitSingular(value: k, fieldNumber: 1, with: &sizer) try sizer.visitSingularEnumField(value: v, fieldNumber: 2) let entrySize = sizer.serializedSize @@ -324,7 +328,7 @@ internal struct BinaryEncodingSizeVisitor: Visitor { let tagSize = FieldTag(fieldNumber: fieldNumber, wireFormat: .lengthDelimited).encodedSize for (k,v) in value { - var sizer = BinaryEncodingSizeVisitor() + var sizer = BinaryEncodingSizeVisitor(sizeCache: sizeCache) try KeyType.visitSingular(value: k, fieldNumber: 1, with: &sizer) try sizer.visitSingularMessageField(value: v, fieldNumber: 2) let entrySize = sizer.serializedSize @@ -338,7 +342,7 @@ internal struct BinaryEncodingSizeVisitor: Visitor { start: Int, end: Int ) throws { - var sizer = BinaryEncodingMessageSetSizeVisitor() + var sizer = BinaryEncodingMessageSetSizeVisitor(sizeCache: sizeCache) try fields.traverse(visitor: &sizer, start: start, end: end) serializedSize += sizer.serializedSize } @@ -348,16 +352,20 @@ internal extension BinaryEncodingSizeVisitor { // Helper Visitor to compute the sizes when writing out the extensions as MessageSets. internal struct BinaryEncodingMessageSetSizeVisitor: SelectiveVisitor { + let sizeCache: BinarySizeCache + var serializedSize: Int = 0 - init() {} + init(sizeCache: BinarySizeCache) { + self.sizeCache = sizeCache + } mutating func visitSingularMessageField(value: M, fieldNumber: Int) throws { var groupSize = WireFormat.MessageSet.itemTagsEncodedSize groupSize += Varint.encodedSize(of: Int32(fieldNumber)) - let messageSize = try value.serializedDataSize() + let messageSize = try sizeCache.getSerializedDataSize(value) groupSize += Varint.encodedSize(of: UInt64(messageSize)) + messageSize serializedSize += groupSize diff --git a/Sources/SwiftProtobuf/BinaryEncodingVisitor.swift b/Sources/SwiftProtobuf/BinaryEncodingVisitor.swift index 03a7b81f8..44499d181 100644 --- a/Sources/SwiftProtobuf/BinaryEncodingVisitor.swift +++ b/Sources/SwiftProtobuf/BinaryEncodingVisitor.swift @@ -19,6 +19,8 @@ import Foundation internal struct BinaryEncodingVisitor: Visitor { var encoder: BinaryEncoder + + let sizeCache: BinarySizeCache /// Creates a new visitor that writes the binary-coded message into the memory /// at the given pointer. @@ -26,12 +28,14 @@ internal struct BinaryEncodingVisitor: Visitor { /// - Precondition: `pointer` must point to an allocated block of memory that /// is large enough to hold the entire encoded message. For performance /// reasons, the encoder does not make any attempts to verify this. - init(forWritingInto pointer: UnsafeMutablePointer) { + init(forWritingInto pointer: UnsafeMutablePointer, sizeCache: BinarySizeCache) { encoder = BinaryEncoder(forWritingInto: pointer) + self.sizeCache = sizeCache } - init(encoder: BinaryEncoder) { + init(encoder: BinaryEncoder, sizeCache: BinarySizeCache) { self.encoder = encoder + self.sizeCache = sizeCache } mutating func visitUnknown(bytes: Data) throws { @@ -106,7 +110,7 @@ internal struct BinaryEncodingVisitor: Visitor { mutating func visitSingularMessageField(value: M, fieldNumber: Int) throws { encoder.startField(fieldNumber: fieldNumber, wireFormat: .lengthDelimited) - let length = try value.serializedDataSize() + let length = try sizeCache.getSerializedDataSize(value) encoder.putVarInt(value: length) try value.traverse(visitor: &self) } @@ -269,7 +273,7 @@ internal struct BinaryEncodingVisitor: Visitor { ) throws { for (k,v) in value { encoder.startField(fieldNumber: fieldNumber, wireFormat: .lengthDelimited) - var sizer = BinaryEncodingSizeVisitor() + var sizer = BinaryEncodingSizeVisitor(sizeCache: sizeCache) try KeyType.visitSingular(value: k, fieldNumber: 1, with: &sizer) try ValueType.visitSingular(value: v, fieldNumber: 2, with: &sizer) let entrySize = sizer.serializedSize @@ -286,7 +290,7 @@ internal struct BinaryEncodingVisitor: Visitor { ) throws where ValueType.RawValue == Int { for (k,v) in value { encoder.startField(fieldNumber: fieldNumber, wireFormat: .lengthDelimited) - var sizer = BinaryEncodingSizeVisitor() + var sizer = BinaryEncodingSizeVisitor(sizeCache: sizeCache) try KeyType.visitSingular(value: k, fieldNumber: 1, with: &sizer) try sizer.visitSingularEnumField(value: v, fieldNumber: 2) let entrySize = sizer.serializedSize @@ -303,7 +307,7 @@ internal struct BinaryEncodingVisitor: Visitor { ) throws { for (k,v) in value { encoder.startField(fieldNumber: fieldNumber, wireFormat: .lengthDelimited) - var sizer = BinaryEncodingSizeVisitor() + var sizer = BinaryEncodingSizeVisitor(sizeCache: sizeCache) try KeyType.visitSingular(value: k, fieldNumber: 1, with: &sizer) try sizer.visitSingularMessageField(value: v, fieldNumber: 2) let entrySize = sizer.serializedSize @@ -318,7 +322,7 @@ internal struct BinaryEncodingVisitor: Visitor { start: Int, end: Int ) throws { - var subVisitor = BinaryEncodingMessageSetVisitor(encoder: encoder) + var subVisitor = BinaryEncodingMessageSetVisitor(encoder: encoder, sizeCache: sizeCache) try fields.traverse(visitor: &subVisitor, start: start, end: end) encoder = subVisitor.encoder } @@ -329,9 +333,11 @@ internal extension BinaryEncodingVisitor { // Helper Visitor to when writing out the extensions as MessageSets. internal struct BinaryEncodingMessageSetVisitor: SelectiveVisitor { var encoder: BinaryEncoder + var sizeCache: BinarySizeCache - init(encoder: BinaryEncoder) { + init(encoder: BinaryEncoder, sizeCache: BinarySizeCache) { self.encoder = encoder + self.sizeCache = sizeCache } mutating func visitSingularMessageField(value: M, fieldNumber: Int) throws { @@ -344,10 +350,10 @@ internal extension BinaryEncodingVisitor { // Use a normal BinaryEncodingVisitor so any message fields end up in the // normal wire format (instead of MessageSet format). - let length = try value.serializedDataSize() + let length = try sizeCache.getSerializedDataSize(value) encoder.putVarInt(value: length) // Create the sub encoder after writing the length. - var subVisitor = BinaryEncodingVisitor(encoder: encoder) + var subVisitor = BinaryEncodingVisitor(encoder: encoder, sizeCache: sizeCache) try value.traverse(visitor: &subVisitor) encoder = subVisitor.encoder diff --git a/Sources/SwiftProtobuf/Message+BinaryAdditions.swift b/Sources/SwiftProtobuf/Message+BinaryAdditions.swift index a4ba3190a..e4d190958 100644 --- a/Sources/SwiftProtobuf/Message+BinaryAdditions.swift +++ b/Sources/SwiftProtobuf/Message+BinaryAdditions.swift @@ -14,6 +14,24 @@ import Foundation +final class BinarySizeCache { + private var cachedSizes: [Int: Int] = [:] + + func getSerializedDataSize(_ message: M) throws -> Int { + if let cacheKey = (message._messageSizeCacheKey.map { Int(bitPattern: $0) }) { + if let cachedSize = cachedSizes[cacheKey] { + return cachedSize + } else { + let computedSize = try message.serializedDataSize(sizeCache: self) + cachedSizes[cacheKey] = computedSize + return computedSize + } + } else { + return try message.serializedDataSize(sizeCache: self) + } + } +} + /// Binary encoding and decoding methods for messages. public extension Message { /// Returns a `Data` value containing the Protocol Buffer binary format @@ -31,10 +49,11 @@ public extension Message { if !partial && !isInitialized { throw BinaryEncodingError.missingRequiredFields } - let requiredSize = try serializedDataSize() + let sizeCache = BinarySizeCache() + let requiredSize = try sizeCache.getSerializedDataSize(self) var data = Data(count: requiredSize) try data.withUnsafeMutableBytes { (pointer: UnsafeMutablePointer) in - var visitor = BinaryEncodingVisitor(forWritingInto: pointer) + var visitor = BinaryEncodingVisitor(forWritingInto: pointer, sizeCache: sizeCache) try traverse(visitor: &visitor) // Currently not exposing this from the api because it really would be // an internal error in the library and should never happen. @@ -46,11 +65,11 @@ public extension Message { /// Returns the size in bytes required to encode the message in binary format. /// This is used by `serializedData()` to precalculate the size of the buffer /// so that encoding can proceed without bounds checks or reallocation. - internal func serializedDataSize() throws -> Int { + fileprivate func serializedDataSize(sizeCache: BinarySizeCache) throws -> Int { // Note: since this api is internal, it doesn't currently worry about // needing a partial argument to handle proto2 syntax required fields. // If this become public, it will need that added. - var visitor = BinaryEncodingSizeVisitor() + var visitor = BinaryEncodingSizeVisitor(sizeCache: sizeCache) try traverse(visitor: &visitor) return visitor.serializedSize } diff --git a/Sources/SwiftProtobuf/Message.swift b/Sources/SwiftProtobuf/Message.swift index 0a721e706..d7510f0a9 100644 --- a/Sources/SwiftProtobuf/Message.swift +++ b/Sources/SwiftProtobuf/Message.swift @@ -117,6 +117,8 @@ public protocol Message: CustomDebugStringConvertible { /// normal `Equatable`. `Equatable` is provided with specific generated /// types. func isEqualTo(message: Message) -> Bool + + var _messageSizeCacheKey: UnsafeMutableRawPointer? { get } } public extension Message { diff --git a/Sources/protoc-gen-swift/MessageGenerator.swift b/Sources/protoc-gen-swift/MessageGenerator.swift index c485922dd..d23ef99fc 100644 --- a/Sources/protoc-gen-swift/MessageGenerator.swift +++ b/Sources/protoc-gen-swift/MessageGenerator.swift @@ -210,6 +210,11 @@ class MessageGenerator { storage.generateTypeDeclaration(printer: &p) p.print("\n") storage.generateUniqueStorage(printer: &p) + p.print("\n") + storage.generateMessageSizeCacheKey(printer: &p) + } else { + p.print("\n") + generateMessageSizeCacheKey(printer: &p) } p.print("\n") generateMessageImplementationBase(printer: &p) @@ -346,6 +351,14 @@ class MessageGenerator { p.print("}\n") } + private func generateMessageSizeCacheKey(printer p: inout CodePrinter) { + p.print("public var _messageSizeCacheKey: UnsafeMutableRawPointer? {\n") + p.indent() + p.print("return nil\n") + p.outdent() + p.print("}\n") + } + private func generateMessageImplementationBase(printer p: inout CodePrinter) { p.print("\(visibility)func _protobuf_generated_isEqualTo(other: \(swiftFullName)) -> Bool {\n") p.indent() diff --git a/Sources/protoc-gen-swift/MessageStorageClassGenerator.swift b/Sources/protoc-gen-swift/MessageStorageClassGenerator.swift index 9e87165da..513afd0e0 100644 --- a/Sources/protoc-gen-swift/MessageStorageClassGenerator.swift +++ b/Sources/protoc-gen-swift/MessageStorageClassGenerator.swift @@ -75,6 +75,14 @@ class MessageStorageClassGenerator { p.outdent() p.print("}\n") } + + func generateMessageSizeCacheKey(printer p: inout CodePrinter) { + p.print("public var _messageSizeCacheKey: UnsafeMutableRawPointer? {\n") + p.indent() + p.print("return Unmanaged.passUnretained(_storage).toOpaque()\n") + p.outdent() + p.print("}\n") + } func generatePreTraverse(printer p: inout CodePrinter) { // Nothing diff --git a/SwiftProtobuf.xcodeproj/project.pbxproj b/SwiftProtobuf.xcodeproj/project.pbxproj index bc2c5d2c8..aabc50694 100644 --- a/SwiftProtobuf.xcodeproj/project.pbxproj +++ b/SwiftProtobuf.xcodeproj/project.pbxproj @@ -911,7 +911,10 @@ "_____Configs_" /* Resources */, "____Products_" /* Products */, ); + indentWidth = 2; sourceTree = ""; + tabWidth = 2; + usesTabs = 0; }; "____Products_" /* Products */ = { isa = PBXGroup;