diff --git a/Sources/GRPC/ClientCalls/ClientCall.swift b/Sources/GRPC/ClientCalls/ClientCall.swift index 099aff76d..3e66d0b36 100644 --- a/Sources/GRPC/ClientCalls/ClientCall.swift +++ b/Sources/GRPC/ClientCalls/ClientCall.swift @@ -62,9 +62,8 @@ public protocol StreamingRequestClientCall: ClientCall { /// /// - Parameters: /// - message: The message to - /// - flush: Whether the buffer should be flushed after writing the message. /// - Returns: A future which will be fullfilled when the message has been sent. - func sendMessage(_ message: RequestMessage, flush: Bool) -> EventLoopFuture + func sendMessage(_ message: RequestMessage) -> EventLoopFuture /// Sends a message to the service. /// @@ -73,8 +72,24 @@ public protocol StreamingRequestClientCall: ClientCall { /// - Parameters: /// - message: The message to send. /// - promise: A promise to be fulfilled when the message has been sent. - /// - flush: Whether the buffer should be flushed after writing the message. - func sendMessage(_ message: RequestMessage, promise: EventLoopPromise?, flush: Bool) + func sendMessage(_ message: RequestMessage, promise: EventLoopPromise?) + + /// Sends a sequence of messages to the service. + /// + /// - Important: Callers must terminate the stream of messages by calling `sendEnd()` or `sendEnd(promise:)`. + /// + /// - Parameters: + /// - messages: The sequence of messages to send. + func sendMessages(_ messages: S) -> EventLoopFuture where S.Element == RequestMessage + + /// Sends a sequence of messages to the service. + /// + /// - Important: Callers must terminate the stream of messages by calling `sendEnd()` or `sendEnd(promise:)`. + /// + /// - Parameters: + /// - messages: The sequence of messages to send. + /// - promise: A promise to be fulfilled when all messages have been sent successfully. + func sendMessages(_ messages: S, promise: EventLoopPromise?) where S.Element == RequestMessage /// Returns a future which can be used as a message queue. /// @@ -112,20 +127,36 @@ public protocol UnaryResponseClientCall: ClientCall { } extension StreamingRequestClientCall { - public func sendMessage(_ message: RequestMessage, flush: Bool = true) -> EventLoopFuture { + public func sendMessage(_ message: RequestMessage) -> EventLoopFuture { return self.subchannel.flatMap { channel in - let writeFuture = channel.write(GRPCClientRequestPart.message(_Box(message))) - if flush { - channel.flush() - } - return writeFuture + return channel.writeAndFlush(GRPCClientRequestPart.message(_Box(message))) } } - public func sendMessage(_ message: RequestMessage, promise: EventLoopPromise?, flush: Bool = true) { + public func sendMessage(_ message: RequestMessage, promise: EventLoopPromise?) { self.subchannel.whenSuccess { channel in - channel.write(GRPCClientRequestPart.message(_Box(message)), promise: promise) - if flush { + channel.writeAndFlush(GRPCClientRequestPart.message(_Box(message)), promise: promise) + } + } + + public func sendMessages(_ messages: S) -> EventLoopFuture where S.Element == RequestMessage { + return self.subchannel.flatMap { channel -> EventLoopFuture in + let writeFutures = messages.map { message in + channel.write(GRPCClientRequestPart.message(_Box(message))) + } + channel.flush() + return EventLoopFuture.andAllSucceed(writeFutures, on: channel.eventLoop) + } + } + + public func sendMessages(_ messages: S, promise: EventLoopPromise?) where S.Element == RequestMessage { + if let promise = promise { + self.sendMessages(messages).cascade(to: promise) + } else { + self.subchannel.whenSuccess { channel in + for message in messages { + channel.write(GRPCClientRequestPart.message(_Box(message)), promise: nil) + } channel.flush() } } diff --git a/Tests/GRPCTests/StreamingRequestClientCallTests.swift b/Tests/GRPCTests/StreamingRequestClientCallTests.swift new file mode 100644 index 000000000..52c50416e --- /dev/null +++ b/Tests/GRPCTests/StreamingRequestClientCallTests.swift @@ -0,0 +1,57 @@ +/* + * Copyright 2019, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Foundation +import GRPC +import XCTest + +class StreamingRequestClientCallTests: EchoTestCaseBase { + class ResponseCounter { + var expectation: XCTestExpectation + + init(expectation: XCTestExpectation) { + self.expectation = expectation + } + + func increment() { + self.expectation.fulfill() + } + } + + func testSendMessages() throws { + let messagesReceived = self.expectation(description: "messages received") + let counter = ResponseCounter(expectation: messagesReceived) + + let update = self.client.update { _ in + counter.increment() + } + + // Send the first batch. + let requests = ["foo", "bar", "baz"].map { Echo_EchoRequest(text: $0) } + messagesReceived.expectedFulfillmentCount = requests.count + XCTAssertNoThrow(try update.sendMessages(requests).wait()) + + // Wait for the responses. + self.wait(for: [messagesReceived], timeout: 0.5) + + let statusReceived = self.expectation(description: "status received") + update.status.map { $0.code }.assertEqual(.ok, fulfill: statusReceived) + + // End the call. + update.sendEnd(promise: nil) + + self.wait(for: [statusReceived], timeout: 0.5) + } +} diff --git a/Tests/GRPCTests/XCTestManifests.swift b/Tests/GRPCTests/XCTestManifests.swift index cff2a1a15..d77b81265 100644 --- a/Tests/GRPCTests/XCTestManifests.swift +++ b/Tests/GRPCTests/XCTestManifests.swift @@ -433,6 +433,15 @@ extension ServerWebTests { ] } +extension StreamingRequestClientCallTests { + // DO NOT MODIFY: This is autogenerated, use: + // `swift test --generate-linuxmain` + // to regenerate. + static let __allTests__StreamingRequestClientCallTests = [ + ("testSendMessages", testSendMessages), + ] +} + public func __allTests() -> [XCTestCaseEntry] { return [ testCase(AnyServiceClientTests.__allTests__AnyServiceClientTests), @@ -464,6 +473,7 @@ public func __allTests() -> [XCTestCaseEntry] { testCase(ServerErrorTransformingTests.__allTests__ServerErrorTransformingTests), testCase(ServerThrowingTests.__allTests__ServerThrowingTests), testCase(ServerWebTests.__allTests__ServerWebTests), + testCase(StreamingRequestClientCallTests.__allTests__StreamingRequestClientCallTests), ] } #endif