diff --git a/Sources/NIOCore/ChannelPipeline.swift b/Sources/NIOCore/ChannelPipeline.swift index aa0ab19275..b3c0ef580f 100644 --- a/Sources/NIOCore/ChannelPipeline.swift +++ b/Sources/NIOCore/ChannelPipeline.swift @@ -1161,6 +1161,17 @@ extension ChannelPipeline { return promise.futureResult } + /// Remove a `ChannelHandler` from the `ChannelPipeline`. + /// + /// - parameters: + /// - context: the `ChannelHandlerContext` that belongs to `ChannelHandler` that should be removed. + /// - returns: the `EventLoopFuture` which will be notified once the `ChannelHandler` was removed. + public func removeHandler(context: ChannelHandlerContext) -> EventLoopFuture { + let promise = self.eventLoop.makePromise(of: Void.self) + self._pipeline.removeHandler(context: context, promise: promise) + return promise.futureResult + } + /// Returns the `ChannelHandlerContext` for the given handler instance if it is in /// the `ChannelPipeline`, if it exists. /// diff --git a/Tests/NIOPosixTests/ChannelPipelineTest.swift b/Tests/NIOPosixTests/ChannelPipelineTest.swift index bf7ca89930..b84baf2e07 100644 --- a/Tests/NIOPosixTests/ChannelPipelineTest.swift +++ b/Tests/NIOPosixTests/ChannelPipelineTest.swift @@ -1119,6 +1119,51 @@ class ChannelPipelineTest: XCTestCase { } } + func testRemovingByContexSync() throws { + class Handler: ChannelInboundHandler, RemovableChannelHandler { + typealias InboundIn = Never + + var removeHandlerCalled = false + var withinRemoveHandler = false + + func removeHandler(context: ChannelHandlerContext, removalToken: ChannelHandlerContext.RemovalToken) { + self.removeHandlerCalled = true + self.withinRemoveHandler = true + defer { + self.withinRemoveHandler = false + } + context.leavePipeline(removalToken: removalToken) + } + + func handlerRemoved(context: ChannelHandlerContext) { + XCTAssertTrue(self.removeHandlerCalled) + XCTAssertTrue(self.withinRemoveHandler) + } + } + + let channel = EmbeddedChannel() + defer { + XCTAssertNoThrow(XCTAssertTrue(try channel.finish().isClean)) + } + let allHandlers = [Handler(), Handler(), Handler()] + XCTAssertNoThrow(try channel.pipeline.addHandler(allHandlers[0], name: "the first one to remove").wait()) + XCTAssertNoThrow(try channel.pipeline.addHandler(allHandlers[1], name: "the second one to remove").wait()) + XCTAssertNoThrow(try channel.pipeline.addHandler(allHandlers[2], name: "the last one to remove").wait()) + + let firstContext = try! channel.pipeline.syncOperations.context(name: "the first one to remove") + let secondContext = try! channel.pipeline.syncOperations.context(name: "the second one to remove") + let lastContext = try! channel.pipeline.syncOperations.context(name: "the last one to remove") + + XCTAssertNoThrow(try channel.pipeline.syncOperations.removeHandler(context: firstContext).wait()) + XCTAssertNoThrow(try channel.pipeline.syncOperations.removeHandler(context: secondContext).wait()) + XCTAssertNoThrow(try channel.pipeline.syncOperations.removeHandler(context: lastContext).wait()) + + for handler in allHandlers { + XCTAssertTrue(handler.removeHandlerCalled) + XCTAssertFalse(handler.withinRemoveHandler) + } + } + func testNonRemovableChannelHandlerIsNotRemovable() { class NonRemovableHandler: ChannelInboundHandler { typealias InboundIn = Never