@@ -25,19 +25,24 @@ import org.apache.spark.sql.types._
2525import org .apache .spark .unsafe .types .UTF8String
2626
2727class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
28-
29- var testVector : WritableColumnVector = _
30-
31- private def allocate (capacity : Int , dt : DataType ): WritableColumnVector = {
32- new OnHeapColumnVector (capacity, dt)
28+ private def withVector (
29+ vector : WritableColumnVector )(
30+ block : WritableColumnVector => Unit ): Unit = {
31+ try block(vector) finally vector.close()
3332 }
3433
35- override def afterEach (): Unit = {
36- testVector.close()
34+ private def testVectors (
35+ name : String ,
36+ size : Int ,
37+ dt : DataType )(
38+ block : WritableColumnVector => Unit ): Unit = {
39+ test(name) {
40+ withVector(new OnHeapColumnVector (size, dt))(block)
41+ withVector(new OffHeapColumnVector (size, dt))(block)
42+ }
3743 }
3844
39- test(" boolean" ) {
40- testVector = allocate(10 , BooleanType )
45+ testVectors(" boolean" , 10 , BooleanType ) { testVector =>
4146 (0 until 10 ).foreach { i =>
4247 testVector.appendBoolean(i % 2 == 0 )
4348 }
@@ -49,34 +54,31 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
4954 }
5055 }
5156
52- test(" byte" ) {
53- testVector = allocate(10 , ByteType )
57+ testVectors(" byte" , 10 , ByteType ) { testVector =>
5458 (0 until 10 ).foreach { i =>
5559 testVector.appendByte(i.toByte)
5660 }
5761
5862 val array = new ColumnVector .Array (testVector)
5963
6064 (0 until 10 ).foreach { i =>
61- assert(array.get(i, ByteType ) === ( i.toByte) )
65+ assert(array.get(i, ByteType ) === i.toByte)
6266 }
6367 }
6468
65- test(" short" ) {
66- testVector = allocate(10 , ShortType )
69+ testVectors(" short" , 10 , ShortType ) { testVector =>
6770 (0 until 10 ).foreach { i =>
6871 testVector.appendShort(i.toShort)
6972 }
7073
7174 val array = new ColumnVector .Array (testVector)
7275
7376 (0 until 10 ).foreach { i =>
74- assert(array.get(i, ShortType ) === ( i.toShort) )
77+ assert(array.get(i, ShortType ) === i.toShort)
7578 }
7679 }
7780
78- test(" int" ) {
79- testVector = allocate(10 , IntegerType )
81+ testVectors(" int" , 10 , IntegerType ) { testVector =>
8082 (0 until 10 ).foreach { i =>
8183 testVector.appendInt(i)
8284 }
@@ -88,8 +90,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
8890 }
8991 }
9092
91- test(" long" ) {
92- testVector = allocate(10 , LongType )
93+ testVectors(" long" , 10 , LongType ) { testVector =>
9394 (0 until 10 ).foreach { i =>
9495 testVector.appendLong(i)
9596 }
@@ -101,8 +102,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
101102 }
102103 }
103104
104- test(" float" ) {
105- testVector = allocate(10 , FloatType )
105+ testVectors(" float" , 10 , FloatType ) { testVector =>
106106 (0 until 10 ).foreach { i =>
107107 testVector.appendFloat(i.toFloat)
108108 }
@@ -114,8 +114,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
114114 }
115115 }
116116
117- test(" double" ) {
118- testVector = allocate(10 , DoubleType )
117+ testVectors(" double" , 10 , DoubleType ) { testVector =>
119118 (0 until 10 ).foreach { i =>
120119 testVector.appendDouble(i.toDouble)
121120 }
@@ -127,8 +126,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
127126 }
128127 }
129128
130- test(" string" ) {
131- testVector = allocate(10 , StringType )
129+ testVectors(" string" , 10 , StringType ) { testVector =>
132130 (0 until 10 ).map { i =>
133131 val utf8 = s " str $i" .getBytes(" utf8" )
134132 testVector.appendByteArray(utf8, 0 , utf8.length)
@@ -141,8 +139,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
141139 }
142140 }
143141
144- test(" binary" ) {
145- testVector = allocate(10 , BinaryType )
142+ testVectors(" binary" , 10 , BinaryType ) { testVector =>
146143 (0 until 10 ).map { i =>
147144 val utf8 = s " str $i" .getBytes(" utf8" )
148145 testVector.appendByteArray(utf8, 0 , utf8.length)
@@ -156,9 +153,8 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
156153 }
157154 }
158155
159- test(" array" ) {
160- val arrayType = ArrayType (IntegerType , true )
161- testVector = allocate(10 , arrayType)
156+ val arrayType : ArrayType = ArrayType (IntegerType , containsNull = true )
157+ testVectors(" array" , 10 , arrayType) { testVector =>
162158
163159 val data = testVector.arrayData()
164160 var i = 0
@@ -181,9 +177,8 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
181177 assert(array.get(3 , arrayType).asInstanceOf [ArrayData ].toIntArray() === Array (3 , 4 , 5 ))
182178 }
183179
184- test(" struct" ) {
185- val schema = new StructType ().add(" int" , IntegerType ).add(" double" , DoubleType )
186- testVector = allocate(10 , schema)
180+ val structType : StructType = new StructType ().add(" int" , IntegerType ).add(" double" , DoubleType )
181+ testVectors(" struct" , 10 , structType) { testVector =>
187182 val c1 = testVector.getChildColumn(0 )
188183 val c2 = testVector.getChildColumn(1 )
189184 c1.putInt(0 , 123 )
@@ -193,35 +188,34 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
193188
194189 val array = new ColumnVector .Array (testVector)
195190
196- assert(array.get(0 , schema ).asInstanceOf [ColumnarBatch .Row ].get(0 , IntegerType ) === 123 )
197- assert(array.get(0 , schema ).asInstanceOf [ColumnarBatch .Row ].get(1 , DoubleType ) === 3.45 )
198- assert(array.get(1 , schema ).asInstanceOf [ColumnarBatch .Row ].get(0 , IntegerType ) === 456 )
199- assert(array.get(1 , schema ).asInstanceOf [ColumnarBatch .Row ].get(1 , DoubleType ) === 5.67 )
191+ assert(array.get(0 , structType ).asInstanceOf [ColumnarBatch .Row ].get(0 , IntegerType ) === 123 )
192+ assert(array.get(0 , structType ).asInstanceOf [ColumnarBatch .Row ].get(1 , DoubleType ) === 3.45 )
193+ assert(array.get(1 , structType ).asInstanceOf [ColumnarBatch .Row ].get(0 , IntegerType ) === 456 )
194+ assert(array.get(1 , structType ).asInstanceOf [ColumnarBatch .Row ].get(1 , DoubleType ) === 5.67 )
200195 }
201196
202197 test(" [SPARK-22092] off-heap column vector reallocation corrupts array data" ) {
203- val arrayType = ArrayType (IntegerType , true )
204- testVector = new OffHeapColumnVector (8 , arrayType)
198+ withVector(new OffHeapColumnVector (8 , arrayType)) { testVector =>
199+ val data = testVector.arrayData()
200+ (0 until 8 ).foreach(i => data.putInt(i, i))
201+ (0 until 8 ).foreach(i => testVector.putArray(i, i, 1 ))
205202
206- val data = testVector.arrayData()
207- (0 until 8 ).foreach(i => data.putInt(i, i))
208- (0 until 8 ).foreach(i => testVector.putArray(i, i, 1 ))
203+ // Increase vector's capacity and reallocate the data to new bigger buffers.
204+ testVector.reserve(16 )
209205
210- // Increase vector's capacity and reallocate the data to new bigger buffers.
211- testVector.reserve(16 )
212-
213- // Check that none of the values got lost/overwritten.
214- val array = new ColumnVector .Array (testVector)
215- (0 until 8 ).foreach { i =>
216- assert(array.get(i, arrayType).asInstanceOf [ArrayData ].toIntArray() === Array (i))
206+ // Check that none of the values got lost/overwritten.
207+ val array = new ColumnVector .Array (testVector)
208+ (0 until 8 ).foreach { i =>
209+ assert(array.get(i, arrayType).asInstanceOf [ArrayData ].toIntArray() === Array (i))
210+ }
217211 }
218212 }
219213
220214 test(" [SPARK-22092] off-heap column vector reallocation corrupts struct nullability" ) {
221- val structType = new StructType ().add( " int " , IntegerType ).add( " double " , DoubleType )
222- testVector = new OffHeapColumnVector ( 8 , structType )
223- ( 0 until 8 ).foreach(i => if (i % 2 == 0 ) testVector.putNull(i) else testVector.putNotNull(i) )
224- testVector.reserve( 16 )
225- ( 0 until 8 ).foreach(i => assert(testVector.isNullAt(i) == (i % 2 == 0 )))
215+ withVector( new OffHeapColumnVector ( 8 , structType)) { testVector =>
216+ ( 0 until 8 ).foreach(i => if (i % 2 == 0 ) testVector.putNull(i) else testVector.putNotNull(i) )
217+ testVector.reserve( 16 )
218+ ( 0 until 8 ).foreach(i => assert( testVector.isNullAt(i) == (i % 2 == 0 )) )
219+ }
226220 }
227221}
0 commit comments