Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new protocol for ChannelHandler to get buffered bytes in the channel handler #2918

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions Sources/NIOCore/ChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,19 @@ extension RemovableChannelHandler {
context.leavePipeline(removalToken: removalToken)
}
}

/// A `OutboundBufferedBytesAuditableChannelHandler` is a `ChannelHandler` that
/// audits and reports the number of bytes buffered for outbound direction.
public protocol OutboundBufferedBytesAuditableChannelHandler {
johnnzhou marked this conversation as resolved.
Show resolved Hide resolved
/// Returns the number of bytes buffered in the channel handler, which are queued to be sent to
/// the next outbound channel handler.
func auditOutboundBufferedBytes() -> Int
johnnzhou marked this conversation as resolved.
Show resolved Hide resolved
}

/// A `InboundBufferedBytesAuditableChannelHandler` is a `ChannelHandler` that
/// audits and reports the number of bytes buffered for inbound direction.
public protocol InboundBufferedBytesAuditableChannelHandler {
johnnzhou marked this conversation as resolved.
Show resolved Hide resolved
/// Returns the number of bytes buffered in the channel handler, which are queued to be sent to
/// the next inbound channel handler.
func auditInboundBufferedBytes() -> Int
}
170 changes: 170 additions & 0 deletions Sources/NIOCore/ChannelPipeline.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1492,6 +1492,8 @@ public enum ChannelPipelineError: Error {
case alreadyRemoved
/// `ChannelHandler` was not found.
case notFound
/// `ChannelHandler` is not auditable.
case notAuditable
johnnzhou marked this conversation as resolved.
Show resolved Hide resolved
}

/// Every `ChannelHandler` has -- when added to a `ChannelPipeline` -- a corresponding `ChannelHandlerContext` which is
Expand Down Expand Up @@ -2089,3 +2091,171 @@ extension ChannelPipeline: CustomDebugStringConvertible {
return handlers
}
}

extension ChannelPipeline {
FranzBusch marked this conversation as resolved.
Show resolved Hide resolved
private enum AuditDirection: Equatable {
case inbound
case outbound
}

/// Audit the total number of bytes buffered for outbound.
johnnzhou marked this conversation as resolved.
Show resolved Hide resolved
public func auditOutboundBufferedBytes() -> EventLoopFuture<Int> {
let future: EventLoopFuture<Int>

if self.eventLoop.inEventLoop {
future = self.eventLoop.makeSucceededFuture(auditAll(direction: .outbound))
} else {
future = self.eventLoop.submit {
self.auditAll(direction: .outbound)
}
}

return future
}

/// Audit the total number of bytes buffered for inbound.
public func auditInboundBufferedBytes() -> EventLoopFuture<Int> {
let future: EventLoopFuture<Int>

if self.eventLoop.inEventLoop {
future = self.eventLoop.makeSucceededFuture(auditAll(direction: .inbound))
} else {
future = self.eventLoop.submit {
self.auditAll(direction: .inbound)
}
}

return future
}

/// Audit the total numbers of outbound bytes buffered in the channel handler with the given `name`.
///
/// - parameters:
/// - name: the name of the channel handler whose outbound buffered bytes will be audited.
public func auditOutboundBufferedBytes(name: String) -> EventLoopFuture<Int> {
johnnzhou marked this conversation as resolved.
Show resolved Hide resolved
let future: EventLoopFuture<Int>

if self.eventLoop.inEventLoop {
future = self.eventLoop.makeCompletedFuture(audit0(name: name, direction: .outbound))
} else {
future = self.eventLoop.submit {
try self.audit0(name: name, direction: .outbound).get()
}
}

return future
}

/// Audit the total numbers of outbound bytes buffered in the channel handler with the given `handler`.
///
/// - parameters:
/// - handler: the channel handler object whose outbound buffered bytes will be audited.
public func auditOutboundBufferedBytes(handler: ChannelHandler) -> EventLoopFuture<Int> {
johnnzhou marked this conversation as resolved.
Show resolved Hide resolved
let future: EventLoopFuture<Int>

if self.eventLoop.inEventLoop {
future = self.eventLoop.makeCompletedFuture(audit0(handler: handler, direction: .outbound))
} else {
future = self.eventLoop.submit {
try self.audit0(handler: handler, direction: .outbound).get()
}
}

return future
}

/// Audit the total numbers of inbound bytes buffered in the channel handler with the given `name`.
///
/// - parameters:
/// - name: the name of the channel handler whose inbound buffered bytes will be audited.
public func auditInboundBufferedBytes(name: String) -> EventLoopFuture<Int> {
let future: EventLoopFuture<Int>

if self.eventLoop.inEventLoop {
future = self.eventLoop.makeCompletedFuture(audit0(name: name, direction: .inbound))
} else {
future = self.eventLoop.submit {
try self.audit0(name: name, direction: .inbound).get()
}
}

return future
}

/// Audit the total numbers of inbound bytes buffered in the channel handler with the given `handler`.
///
/// - parameters:
/// - handler: the channel handler object whose inbound buffered bytes will be audited.
public func auditInboundBufferedBytes(handler: ChannelHandler) -> EventLoopFuture<Int> {
let future: EventLoopFuture<Int>

if self.eventLoop.inEventLoop {
future = self.eventLoop.makeCompletedFuture(audit0(handler: handler, direction: .inbound))
} else {
future = self.eventLoop.submit {
try self.audit0(handler: handler, direction: .inbound).get()
}
}

return future
}

private func audit0(name: String, direction: AuditDirection) -> Result<Int, Error> {
let result = self.contextSync(name: name)
switch result {
case .success(let context):
return audit0(context: context, direction: direction)
case .failure(let error):
return .failure(error)
}
}

private func audit0(handler: ChannelHandler, direction: AuditDirection) -> Result<Int, Error> {
let result = self.contextSync(handler: handler)
switch result {
case .success(let context):
return audit0(context: context, direction: direction)
case .failure(let error):
return .failure(error)
}
}

private func audit0(context: ChannelHandlerContext, direction: AuditDirection) -> Result<Int, Error> {
johnnzhou marked this conversation as resolved.
Show resolved Hide resolved
switch direction {
case .inbound:
guard let handler = context.handler as? InboundBufferedBytesAuditableChannelHandler else {
return .failure(ChannelPipelineError.notAuditable)
}
return .success(handler.auditInboundBufferedBytes())
case .outbound:
guard let handler = context.handler as? OutboundBufferedBytesAuditableChannelHandler else {
return .failure(ChannelPipelineError.notAuditable)
}
return .success(handler.auditOutboundBufferedBytes())
}

}

private func auditAll(direction: AuditDirection) -> Int {
var total = 0
var current = self.head?.next
switch direction {
case .inbound:
while let c = current, c !== self.tail {
if let inboundHandler = c.handler as? InboundBufferedBytesAuditableChannelHandler {
total += inboundHandler.auditInboundBufferedBytes()
}
current = current?.next
}
case .outbound:
while let c = current, c !== self.tail {
if let outboundHandler = c.handler as? OutboundBufferedBytesAuditableChannelHandler {
total += outboundHandler.auditOutboundBufferedBytes()
}
current = current?.next
}
}

return total
}
}