Skip to content

Commit e12ecfa

Browse files
author
Marcelo Vanzin
committed
[SPARK-11617][NETWORK] Fix leak in TransportFrameDecoder.
The code was using the wrong API to add data to the internal composite buffer, causing buffers to leak in certain situations. Use the right API and enhance the tests to catch memory leaks. Also, avoid reusing the composite buffers when downstream handlers keep references to them; this seems to cause a few different issues even though the ref counting code seems to be correct, so instead pay the cost of copying a few bytes when that situation happens. Author: Marcelo Vanzin <vanzin@cloudera.com> Closes #9619 from vanzin/SPARK-11617. (cherry picked from commit 540bf58) Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>
1 parent 505ecee commit e12ecfa

File tree

2 files changed

+151
-41
lines changed

2 files changed

+151
-41
lines changed

network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -56,32 +56,43 @@ public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception
5656
buffer = in.alloc().compositeBuffer();
5757
}
5858

59-
buffer.writeBytes(in);
59+
buffer.addComponent(in).writerIndex(buffer.writerIndex() + in.readableBytes());
6060

6161
while (buffer.isReadable()) {
62-
feedInterceptor();
63-
if (interceptor != null) {
64-
continue;
65-
}
62+
discardReadBytes();
63+
if (!feedInterceptor()) {
64+
ByteBuf frame = decodeNext();
65+
if (frame == null) {
66+
break;
67+
}
6668

67-
ByteBuf frame = decodeNext();
68-
if (frame != null) {
6969
ctx.fireChannelRead(frame);
70-
} else {
71-
break;
7270
}
7371
}
7472

75-
// We can't discard read sub-buffers if there are other references to the buffer (e.g.
76-
// through slices used for framing). This assumes that code that retains references
77-
// will call retain() from the thread that called "fireChannelRead()" above, otherwise
78-
// ref counting will go awry.
79-
if (buffer != null && buffer.refCnt() == 1) {
73+
discardReadBytes();
74+
}
75+
76+
private void discardReadBytes() {
77+
// If the buffer's been retained by downstream code, then make a copy of the remaining
78+
// bytes into a new buffer. Otherwise, just discard stale components.
79+
if (buffer.refCnt() > 1) {
80+
CompositeByteBuf newBuffer = buffer.alloc().compositeBuffer();
81+
82+
if (buffer.readableBytes() > 0) {
83+
ByteBuf spillBuf = buffer.alloc().buffer(buffer.readableBytes());
84+
spillBuf.writeBytes(buffer);
85+
newBuffer.addComponent(spillBuf).writerIndex(spillBuf.readableBytes());
86+
}
87+
88+
buffer.release();
89+
buffer = newBuffer;
90+
} else {
8091
buffer.discardReadComponents();
8192
}
8293
}
8394

84-
protected ByteBuf decodeNext() throws Exception {
95+
private ByteBuf decodeNext() throws Exception {
8596
if (buffer.readableBytes() < LENGTH_SIZE) {
8697
return null;
8798
}
@@ -127,10 +138,14 @@ public void setInterceptor(Interceptor interceptor) {
127138
this.interceptor = interceptor;
128139
}
129140

130-
private void feedInterceptor() throws Exception {
141+
/**
142+
* @return Whether the interceptor is still active after processing the data.
143+
*/
144+
private boolean feedInterceptor() throws Exception {
131145
if (interceptor != null && !interceptor.handle(buffer)) {
132146
interceptor = null;
133147
}
148+
return interceptor != null;
134149
}
135150

136151
public static interface Interceptor {

network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java

Lines changed: 120 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,49 +18,44 @@
1818
package org.apache.spark.network.util;
1919

2020
import java.nio.ByteBuffer;
21+
import java.util.ArrayList;
22+
import java.util.List;
2123
import java.util.Random;
24+
import java.util.concurrent.atomic.AtomicInteger;
2225

2326
import io.netty.buffer.ByteBuf;
2427
import io.netty.buffer.Unpooled;
2528
import io.netty.channel.ChannelHandlerContext;
29+
import org.junit.AfterClass;
2630
import org.junit.Test;
31+
import org.mockito.invocation.InvocationOnMock;
32+
import org.mockito.stubbing.Answer;
2733
import static org.junit.Assert.*;
2834
import static org.mockito.Mockito.*;
2935

3036
public class TransportFrameDecoderSuite {
3137

38+
private static Random RND = new Random();
39+
40+
@AfterClass
41+
public static void cleanup() {
42+
RND = null;
43+
}
44+
3245
@Test
3346
public void testFrameDecoding() throws Exception {
34-
Random rnd = new Random();
3547
TransportFrameDecoder decoder = new TransportFrameDecoder();
36-
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
37-
38-
final int frameCount = 100;
39-
ByteBuf data = Unpooled.buffer();
40-
try {
41-
for (int i = 0; i < frameCount; i++) {
42-
byte[] frame = new byte[1024 * (rnd.nextInt(31) + 1)];
43-
data.writeLong(frame.length + 8);
44-
data.writeBytes(frame);
45-
}
46-
47-
while (data.isReadable()) {
48-
int size = rnd.nextInt(16 * 1024) + 256;
49-
decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)));
50-
}
51-
52-
verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class));
53-
} finally {
54-
data.release();
55-
}
48+
ChannelHandlerContext ctx = mockChannelHandlerContext();
49+
ByteBuf data = createAndFeedFrames(100, decoder, ctx);
50+
verifyAndCloseDecoder(decoder, ctx, data);
5651
}
5752

5853
@Test
5954
public void testInterception() throws Exception {
6055
final int interceptedReads = 3;
6156
TransportFrameDecoder decoder = new TransportFrameDecoder();
6257
TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads));
63-
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
58+
ChannelHandlerContext ctx = mockChannelHandlerContext();
6459

6560
byte[] data = new byte[8];
6661
ByteBuf len = Unpooled.copyLong(8 + data.length);
@@ -70,16 +65,56 @@ public void testInterception() throws Exception {
7065
decoder.setInterceptor(interceptor);
7166
for (int i = 0; i < interceptedReads; i++) {
7267
decoder.channelRead(ctx, dataBuf);
73-
dataBuf.release();
68+
assertEquals(0, dataBuf.refCnt());
7469
dataBuf = Unpooled.wrappedBuffer(data);
7570
}
7671
decoder.channelRead(ctx, len);
7772
decoder.channelRead(ctx, dataBuf);
7873
verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class));
7974
verify(ctx).fireChannelRead(any(ByteBuffer.class));
75+
assertEquals(0, len.refCnt());
76+
assertEquals(0, dataBuf.refCnt());
8077
} finally {
81-
len.release();
82-
dataBuf.release();
78+
release(len);
79+
release(dataBuf);
80+
}
81+
}
82+
83+
@Test
84+
public void testRetainedFrames() throws Exception {
85+
TransportFrameDecoder decoder = new TransportFrameDecoder();
86+
87+
final AtomicInteger count = new AtomicInteger();
88+
final List<ByteBuf> retained = new ArrayList<>();
89+
90+
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
91+
when(ctx.fireChannelRead(any())).thenAnswer(new Answer<Void>() {
92+
@Override
93+
public Void answer(InvocationOnMock in) {
94+
// Retain a few frames but not others.
95+
ByteBuf buf = (ByteBuf) in.getArguments()[0];
96+
if (count.incrementAndGet() % 2 == 0) {
97+
retained.add(buf);
98+
} else {
99+
buf.release();
100+
}
101+
return null;
102+
}
103+
});
104+
105+
ByteBuf data = createAndFeedFrames(100, decoder, ctx);
106+
try {
107+
// Verify all retained buffers are readable.
108+
for (ByteBuf b : retained) {
109+
byte[] tmp = new byte[b.readableBytes()];
110+
b.readBytes(tmp);
111+
b.release();
112+
}
113+
verifyAndCloseDecoder(decoder, ctx, data);
114+
} finally {
115+
for (ByteBuf b : retained) {
116+
release(b);
117+
}
83118
}
84119
}
85120

@@ -100,6 +135,47 @@ public void testLargeFrame() throws Exception {
100135
testInvalidFrame(Integer.MAX_VALUE + 9);
101136
}
102137

138+
/**
139+
* Creates a number of randomly sized frames and feeds them to the given decoder, verifying
140+
* that the frames were read.
141+
*/
142+
private ByteBuf createAndFeedFrames(
143+
int frameCount,
144+
TransportFrameDecoder decoder,
145+
ChannelHandlerContext ctx) throws Exception {
146+
ByteBuf data = Unpooled.buffer();
147+
for (int i = 0; i < frameCount; i++) {
148+
byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)];
149+
data.writeLong(frame.length + 8);
150+
data.writeBytes(frame);
151+
}
152+
153+
try {
154+
while (data.isReadable()) {
155+
int size = RND.nextInt(4 * 1024) + 256;
156+
decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain());
157+
}
158+
159+
verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class));
160+
} catch (Exception e) {
161+
release(data);
162+
throw e;
163+
}
164+
return data;
165+
}
166+
167+
private void verifyAndCloseDecoder(
168+
TransportFrameDecoder decoder,
169+
ChannelHandlerContext ctx,
170+
ByteBuf data) throws Exception {
171+
try {
172+
decoder.channelInactive(ctx);
173+
assertTrue("There shouldn't be dangling references to the data.", data.release());
174+
} finally {
175+
release(data);
176+
}
177+
}
178+
103179
private void testInvalidFrame(long size) throws Exception {
104180
TransportFrameDecoder decoder = new TransportFrameDecoder();
105181
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
@@ -111,6 +187,25 @@ private void testInvalidFrame(long size) throws Exception {
111187
}
112188
}
113189

190+
private ChannelHandlerContext mockChannelHandlerContext() {
191+
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
192+
when(ctx.fireChannelRead(any())).thenAnswer(new Answer<Void>() {
193+
@Override
194+
public Void answer(InvocationOnMock in) {
195+
ByteBuf buf = (ByteBuf) in.getArguments()[0];
196+
buf.release();
197+
return null;
198+
}
199+
});
200+
return ctx;
201+
}
202+
203+
private void release(ByteBuf buf) {
204+
if (buf.refCnt() > 0) {
205+
buf.release(buf.refCnt());
206+
}
207+
}
208+
114209
private static class MockInterceptor implements TransportFrameDecoder.Interceptor {
115210

116211
private int remainingReads;

0 commit comments

Comments
 (0)