1818package org .apache .spark .network .util ;
1919
2020import java .nio .ByteBuffer ;
21+ import java .util .ArrayList ;
22+ import java .util .List ;
2123import java .util .Random ;
24+ import java .util .concurrent .atomic .AtomicInteger ;
2225
2326import io .netty .buffer .ByteBuf ;
2427import io .netty .buffer .Unpooled ;
2528import io .netty .channel .ChannelHandlerContext ;
29+ import org .junit .AfterClass ;
2630import org .junit .Test ;
31+ import org .mockito .invocation .InvocationOnMock ;
32+ import org .mockito .stubbing .Answer ;
2733import static org .junit .Assert .*;
2834import static org .mockito .Mockito .*;
2935
3036public class TransportFrameDecoderSuite {
3137
38+ private static Random RND = new Random ();
39+
40+ @ AfterClass
41+ public static void cleanup () {
42+ RND = null ;
43+ }
44+
3245 @ Test
3346 public void testFrameDecoding () throws Exception {
34- Random rnd = new Random ();
3547 TransportFrameDecoder decoder = new TransportFrameDecoder ();
36- ChannelHandlerContext ctx = mock (ChannelHandlerContext .class );
37-
38- final int frameCount = 100 ;
39- ByteBuf data = Unpooled .buffer ();
40- try {
41- for (int i = 0 ; i < frameCount ; i ++) {
42- byte [] frame = new byte [1024 * (rnd .nextInt (31 ) + 1 )];
43- data .writeLong (frame .length + 8 );
44- data .writeBytes (frame );
45- }
46-
47- while (data .isReadable ()) {
48- int size = rnd .nextInt (16 * 1024 ) + 256 ;
49- decoder .channelRead (ctx , data .readSlice (Math .min (data .readableBytes (), size )));
50- }
51-
52- verify (ctx , times (frameCount )).fireChannelRead (any (ByteBuf .class ));
53- } finally {
54- data .release ();
55- }
48+ ChannelHandlerContext ctx = mockChannelHandlerContext ();
49+ ByteBuf data = createAndFeedFrames (100 , decoder , ctx );
50+ verifyAndCloseDecoder (decoder , ctx , data );
5651 }
5752
5853 @ Test
5954 public void testInterception () throws Exception {
6055 final int interceptedReads = 3 ;
6156 TransportFrameDecoder decoder = new TransportFrameDecoder ();
6257 TransportFrameDecoder .Interceptor interceptor = spy (new MockInterceptor (interceptedReads ));
63- ChannelHandlerContext ctx = mock ( ChannelHandlerContext . class );
58+ ChannelHandlerContext ctx = mockChannelHandlerContext ( );
6459
6560 byte [] data = new byte [8 ];
6661 ByteBuf len = Unpooled .copyLong (8 + data .length );
@@ -70,16 +65,56 @@ public void testInterception() throws Exception {
7065 decoder .setInterceptor (interceptor );
7166 for (int i = 0 ; i < interceptedReads ; i ++) {
7267 decoder .channelRead (ctx , dataBuf );
73- dataBuf .release ( );
68+ assertEquals ( 0 , dataBuf .refCnt () );
7469 dataBuf = Unpooled .wrappedBuffer (data );
7570 }
7671 decoder .channelRead (ctx , len );
7772 decoder .channelRead (ctx , dataBuf );
7873 verify (interceptor , times (interceptedReads )).handle (any (ByteBuf .class ));
7974 verify (ctx ).fireChannelRead (any (ByteBuffer .class ));
75+ assertEquals (0 , len .refCnt ());
76+ assertEquals (0 , dataBuf .refCnt ());
8077 } finally {
81- len .release ();
82- dataBuf .release ();
78+ release (len );
79+ release (dataBuf );
80+ }
81+ }
82+
83+ @ Test
84+ public void testRetainedFrames () throws Exception {
85+ TransportFrameDecoder decoder = new TransportFrameDecoder ();
86+
87+ final AtomicInteger count = new AtomicInteger ();
88+ final List <ByteBuf > retained = new ArrayList <>();
89+
90+ ChannelHandlerContext ctx = mock (ChannelHandlerContext .class );
91+ when (ctx .fireChannelRead (any ())).thenAnswer (new Answer <Void >() {
92+ @ Override
93+ public Void answer (InvocationOnMock in ) {
94+ // Retain a few frames but not others.
95+ ByteBuf buf = (ByteBuf ) in .getArguments ()[0 ];
96+ if (count .incrementAndGet () % 2 == 0 ) {
97+ retained .add (buf );
98+ } else {
99+ buf .release ();
100+ }
101+ return null ;
102+ }
103+ });
104+
105+ ByteBuf data = createAndFeedFrames (100 , decoder , ctx );
106+ try {
107+ // Verify all retained buffers are readable.
108+ for (ByteBuf b : retained ) {
109+ byte [] tmp = new byte [b .readableBytes ()];
110+ b .readBytes (tmp );
111+ b .release ();
112+ }
113+ verifyAndCloseDecoder (decoder , ctx , data );
114+ } finally {
115+ for (ByteBuf b : retained ) {
116+ release (b );
117+ }
83118 }
84119 }
85120
@@ -100,6 +135,47 @@ public void testLargeFrame() throws Exception {
100135 testInvalidFrame (Integer .MAX_VALUE + 9 );
101136 }
102137
138+ /**
139+ * Creates a number of randomly sized frames and feed them to the given decoder, verifying
140+ * that the frames were read.
141+ */
142+ private ByteBuf createAndFeedFrames (
143+ int frameCount ,
144+ TransportFrameDecoder decoder ,
145+ ChannelHandlerContext ctx ) throws Exception {
146+ ByteBuf data = Unpooled .buffer ();
147+ for (int i = 0 ; i < frameCount ; i ++) {
148+ byte [] frame = new byte [1024 * (RND .nextInt (31 ) + 1 )];
149+ data .writeLong (frame .length + 8 );
150+ data .writeBytes (frame );
151+ }
152+
153+ try {
154+ while (data .isReadable ()) {
155+ int size = RND .nextInt (4 * 1024 ) + 256 ;
156+ decoder .channelRead (ctx , data .readSlice (Math .min (data .readableBytes (), size )).retain ());
157+ }
158+
159+ verify (ctx , times (frameCount )).fireChannelRead (any (ByteBuf .class ));
160+ } catch (Exception e ) {
161+ release (data );
162+ throw e ;
163+ }
164+ return data ;
165+ }
166+
167+ private void verifyAndCloseDecoder (
168+ TransportFrameDecoder decoder ,
169+ ChannelHandlerContext ctx ,
170+ ByteBuf data ) throws Exception {
171+ try {
172+ decoder .channelInactive (ctx );
173+ assertTrue ("There shouldn't be dangling references to the data." , data .release ());
174+ } finally {
175+ release (data );
176+ }
177+ }
178+
103179 private void testInvalidFrame (long size ) throws Exception {
104180 TransportFrameDecoder decoder = new TransportFrameDecoder ();
105181 ChannelHandlerContext ctx = mock (ChannelHandlerContext .class );
@@ -111,6 +187,25 @@ private void testInvalidFrame(long size) throws Exception {
111187 }
112188 }
113189
190+ private ChannelHandlerContext mockChannelHandlerContext () {
191+ ChannelHandlerContext ctx = mock (ChannelHandlerContext .class );
192+ when (ctx .fireChannelRead (any ())).thenAnswer (new Answer <Void >() {
193+ @ Override
194+ public Void answer (InvocationOnMock in ) {
195+ ByteBuf buf = (ByteBuf ) in .getArguments ()[0 ];
196+ buf .release ();
197+ return null ;
198+ }
199+ });
200+ return ctx ;
201+ }
202+
203+ private void release (ByteBuf buf ) {
204+ if (buf .refCnt () > 0 ) {
205+ buf .release (buf .refCnt ());
206+ }
207+ }
208+
114209 private static class MockInterceptor implements TransportFrameDecoder .Interceptor {
115210
116211 private int remainingReads ;
0 commit comments