Skip to content

Commit

Permalink
Call Freeable.free() if a Freeable message reaches the end of the Cha…
Browse files Browse the repository at this point in the history
…nnelPipeline to guard against resource leakage
  • Loading branch information
Norman Maurer committed Jan 7, 2013
1 parent cf2fbf7 commit 2659547
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import io.netty.buffer.Buf;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Freeable;
import io.netty.buffer.MessageBuf;
import io.netty.buffer.Unpooled;
import io.netty.logging.InternalLogger;
Expand Down Expand Up @@ -48,6 +49,8 @@ final class DefaultChannelPipeline implements ChannelPipeline {

final DefaultChannelHandlerContext head;
private volatile DefaultChannelHandlerContext tail;
private final DefaultChannelHandlerContext tailCtx;

private final Map<String, DefaultChannelHandlerContext> name2ctx =
new HashMap<String, DefaultChannelHandlerContext>(4);
private boolean firedChannelActive;
Expand All @@ -56,16 +59,21 @@ final class DefaultChannelPipeline implements ChannelPipeline {
final Map<EventExecutorGroup, EventExecutor> childExecutors =
new IdentityHashMap<EventExecutorGroup, EventExecutor>();

private static final TailHandler TAIL_HANDLER = new TailHandler();

public DefaultChannelPipeline(Channel channel) {
if (channel == null) {
throw new NullPointerException("channel");
}
this.channel = channel;

HeadHandler headHandler = new HeadHandler();
tailCtx = new DefaultChannelHandlerContext(
this, null, null, null, generateName(TAIL_HANDLER), TAIL_HANDLER);
head = new DefaultChannelHandlerContext(
this, null, null, null, generateName(headHandler), headHandler);
tail = head;
this, null, null, tailCtx, generateName(headHandler), headHandler);
tailCtx.prev = head;
tail = tailCtx;

unsafe = channel.unsafe();
}
Expand Down Expand Up @@ -119,10 +127,12 @@ private void addFirst0(
if (nextCtx != null) {
nextCtx.prev = newCtx;
}
head.next = newCtx;
if (tail == head) {
if (head.next == tailCtx) {
tail = newCtx;
newCtx.next = tailCtx;
tailCtx.prev = newCtx;
}
head.next = newCtx;

name2ctx.put(name, newCtx);

Expand All @@ -143,8 +153,7 @@ public ChannelPipeline addLast(EventExecutorGroup group, final String name, Chan
checkDuplicateName(name);

oldTail = tail;
newTail = new DefaultChannelHandlerContext(this, group, oldTail, null, name, handler);

newTail = new DefaultChannelHandlerContext(this, group, null, null, name, handler);
if (!newTail.channel().isRegistered() || newTail.executor().inEventLoop()) {
addLast0(name, oldTail, newTail);
return this;
Expand All @@ -171,7 +180,21 @@ private void addLast0(
final String name, DefaultChannelHandlerContext oldTail, DefaultChannelHandlerContext newTail) {
callBeforeAdd(newTail);

oldTail.next = newTail;
DefaultChannelHandlerContext prev = oldTail.prev;
if (oldTail == tailCtx) {
// This is the first handler added
tailCtx.prev = newTail;
newTail.next = tailCtx;
prev.next = newTail;
newTail.prev = prev;
} else {
oldTail.next = newTail;
newTail.prev = oldTail;

prev.next = oldTail;
oldTail.prev = prev;
}

tail = newTail;
name2ctx.put(name, newTail);

Expand Down Expand Up @@ -361,12 +384,15 @@ private DefaultChannelHandlerContext remove(final DefaultChannelHandlerContext c
Future<?> future;

synchronized (this) {
if (ctx == tailCtx) {
throw new NoSuchElementException();
}
if (head == tail) {
return null;
} else if (ctx == head) {
throw new Error(); // Should never happen.
} else if (ctx == tail) {
if (head == tail) {
if (tail == tailCtx) {
throw new NoSuchElementException();
}

Expand Down Expand Up @@ -425,7 +451,7 @@ private void remove0(DefaultChannelHandlerContext ctx) {

@Override
public ChannelHandler removeFirst() {
if (head == tail) {
if (head.next == tailCtx) {
throw new NoSuchElementException();
}
return remove(head.next).handler();
Expand All @@ -436,7 +462,7 @@ public ChannelHandler removeLast() {
final DefaultChannelHandlerContext oldTail;

synchronized (this) {
if (head == tail) {
if (tail == tailCtx) {
throw new NoSuchElementException();
}
oldTail = tail;
Expand Down Expand Up @@ -464,7 +490,9 @@ public void run() {
private void removeLast0(DefaultChannelHandlerContext oldTail) {
callBeforeRemove(oldTail);

oldTail.prev.next = null;
tailCtx.prev = oldTail.prev;
oldTail.prev.next = tailCtx;

tail = oldTail.prev;
name2ctx.remove(oldTail.name());

Expand Down Expand Up @@ -493,10 +521,13 @@ private ChannelHandler replace(
final DefaultChannelHandlerContext ctx, final String newName, ChannelHandler newHandler) {
Future<?> future;
synchronized (this) {
if (ctx == tailCtx) {
throw new NoSuchElementException();
}
if (ctx == head) {
throw new IllegalArgumentException();
} else if (ctx == tail) {
if (head == tail) {
if (tail == tailCtx) {
throw new NoSuchElementException();
}
final DefaultChannelHandlerContext oldTail = tail;
Expand Down Expand Up @@ -688,7 +719,7 @@ public ChannelHandlerContext firstContext() {
@Override
public ChannelHandler last() {
DefaultChannelHandlerContext last = tail;
if (last == head || last == null) {
if (last == tailCtx || last == null) {
return null;
}
return last.handler();
Expand Down Expand Up @@ -743,6 +774,7 @@ public ChannelHandlerContext context(ChannelHandler handler) {

DefaultChannelHandlerContext ctx = head.next;
for (;;) {

if (ctx == null) {
return null;
}
Expand Down Expand Up @@ -791,7 +823,7 @@ public Map<String, ChannelHandler> toMap() {
Map<String, ChannelHandler> map = new LinkedHashMap<String, ChannelHandler>();
DefaultChannelHandlerContext ctx = head.next;
for (;;) {
if (ctx == null) {
if (ctx == null || ctx == tailCtx) {
return map;
}
map.put(ctx.name(), ctx.handler());
Expand Down Expand Up @@ -1331,7 +1363,6 @@ ChannelFuture write(DefaultChannelHandlerContext ctx, final Object message, fina

ctx = ctx.prev;
}

if (executor.inEventLoop()) {
write0(ctx, message, promise, msgBuf);
return promise;
Expand Down Expand Up @@ -1483,6 +1514,21 @@ private DefaultChannelHandlerContext getContextOrDie(Class<? extends ChannelHand
}
}

private static final class TailHandler extends ChannelInboundMessageHandlerAdapter<Freeable> {
public TailHandler() {
super(Freeable.class);
}

@Override
protected void messageReceived(ChannelHandlerContext ctx, Freeable msg) throws Exception {
if (logger.isWarnEnabled()) {
logger.warn("Freeable reached end-of-pipeline, call " + msg + ".free() to" +
" guard against resource leakage!");
}
msg.free();
}
}

private final class HeadHandler implements ChannelOutboundHandler {
@Override
public Buf newOutboundBuffer(ChannelHandlerContext ctx) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,89 @@
*/
package io.netty.channel;


import io.netty.buffer.Freeable;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalEventLoopGroup;
import org.junit.Test;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import static org.junit.Assert.*;

public class DefaultChannelPipelineTest {
@Test
public void testFreeCalled() throws InterruptedException{
final CountDownLatch free = new CountDownLatch(1);

Freeable holder = new Freeable() {
@Override
public void free() {
free.countDown();
}

@Override
public boolean isFreed() {
return free.getCount() == 0;
}
};
LocalChannel channel = new LocalChannel();
LocalEventLoopGroup group = new LocalEventLoopGroup();
group.register(channel).awaitUninterruptibly();
DefaultChannelPipeline pipeline = new DefaultChannelPipeline(channel);

StringInboundHandler handler = new StringInboundHandler();
pipeline.addLast(handler);
pipeline.fireChannelActive();
pipeline.inboundMessageBuffer().add(holder);
pipeline.fireInboundBufferUpdated();

assertTrue(free.await(10, TimeUnit.SECONDS));
assertTrue(handler.called);
}

private static final class StringInboundHandler extends ChannelInboundMessageHandlerAdapter<String> {
boolean called;

public StringInboundHandler() {
super(String.class);
}

@Override
public boolean isSupported(Object msg) throws Exception {
called = true;
return super.isSupported(msg);
}

@Override
protected void messageReceived(ChannelHandlerContext ctx, String msg) throws Exception {
fail();
}
}


@Test
public void testRemoveChannelHandler() {
DefaultChannelPipeline pipeline = new DefaultChannelPipeline(new LocalChannel());

ChannelHandler handler1 = newHandler();
ChannelHandler handler2 = newHandler();
ChannelHandler handler3 = newHandler();

pipeline.addLast("handler1", handler1);
pipeline.addLast("handler2", handler2);
pipeline.addLast("handler3", handler3);
assertSame(pipeline.get("handler1"), handler1);
assertSame(pipeline.get("handler2"), handler2);
assertSame(pipeline.get("handler3"), handler3);

pipeline.remove(handler1);
pipeline.remove(handler2);
pipeline.remove(handler3);
}

@Test
public void testReplaceChannelHandler() {
DefaultChannelPipeline pipeline = new DefaultChannelPipeline(new LocalChannel());
Expand Down Expand Up @@ -107,8 +183,11 @@ public void testChannelHandlerContextOrder() {
while (ctx != null) {
int i = toInt(ctx.name());
int j = next(ctx);

assertTrue(i < j);
if (j != -1) {
assertTrue(i < j);
} else {
assertNull(ctx.next.next);
}
ctx = ctx.next;
}

Expand All @@ -125,7 +204,11 @@ private static int next(DefaultChannelHandlerContext ctx) {
}

private static int toInt(String name) {
return Integer.parseInt(name);
try {
return Integer.parseInt(name);
} catch (NumberFormatException e) {
return -1;
}
}

private static void verifyContextNumber(DefaultChannelPipeline pipeline, int expectedNumber) {
Expand Down

0 comments on commit 2659547

Please sign in to comment.