@@ -28,35 +28,46 @@ import org.apache.spark.sql.execution.SparkSqlSerializer
2828class ColumnTypeSuite extends FunSuite {
2929 val DEFAULT_BUFFER_SIZE = 512
3030
31- val columnTypes = Seq (INT , SHORT , LONG , BYTE , DOUBLE , FLOAT , STRING , BINARY , GENERIC )
32-
3331 test(" defaultSize" ) {
34- val defaultSize = Seq (4 , 2 , 8 , 1 , 8 , 4 , 8 , 16 , 16 )
32+ val checks = Map (
33+ INT -> 4 , SHORT -> 2 , LONG -> 8 , BYTE -> 1 , DOUBLE -> 8 , FLOAT -> 4 ,
34+ BOOLEAN -> 1 , STRING -> 8 , BINARY -> 16 , GENERIC -> 16 )
3535
36- columnTypes.zip(defaultSize).foreach { case (columnType, size) =>
37- assert(columnType.defaultSize === size)
36+ checks.foreach { case (columnType, expectedSize) =>
37+ expectResult(expectedSize, s " Wrong defaultSize for $columnType" ) {
38+ columnType.defaultSize
39+ }
3840 }
3941 }
4042
4143 test(" actualSize" ) {
42- val expectedSizes = Seq (4 , 2 , 8 , 1 , 8 , 4 , 4 + 5 , 4 + 4 , 4 + 11 )
43- val actualSizes = Seq (
44- INT .actualSize(Int .MaxValue ),
45- SHORT .actualSize(Short .MaxValue ),
46- LONG .actualSize(Long .MaxValue ),
47- BYTE .actualSize(Byte .MaxValue ),
48- DOUBLE .actualSize(Double .MaxValue ),
49- FLOAT .actualSize(Float .MaxValue ),
50- STRING .actualSize(" hello" ),
51- BINARY .actualSize(new Array [Byte ](4 )),
52- GENERIC .actualSize(SparkSqlSerializer .serialize(Map (1 -> " a" ))))
53-
54- expectedSizes.zip(actualSizes).foreach { case (expected, actual) =>
55- assert(expected === actual)
44+ def checkActualSize [T <: DataType , JvmType ](
45+ columnType : ColumnType [T , JvmType ],
46+ value : JvmType ,
47+ expected : Int ) {
48+
49+ expectResult(expected, s " Wrong actualSize for $columnType" ) {
50+ columnType.actualSize(value)
51+ }
5652 }
53+
54+ checkActualSize(INT , Int .MaxValue , 4 )
55+ checkActualSize(SHORT , Short .MaxValue , 2 )
56+ checkActualSize(LONG , Long .MaxValue , 8 )
57+ checkActualSize(BYTE , Byte .MaxValue , 1 )
58+ checkActualSize(DOUBLE , Double .MaxValue , 8 )
59+ checkActualSize(FLOAT , Float .MaxValue , 4 )
60+ checkActualSize(BOOLEAN , true , 1 )
61+ checkActualSize(STRING , " hello" , 4 + 5 )
62+
63+ val binary = Array .fill[Byte ](4 )(0 : Byte )
64+ checkActualSize(BINARY , binary, 4 + 4 )
65+
66+ val generic = Map (1 -> " a" )
67+ checkActualSize(GENERIC , SparkSqlSerializer .serialize(generic), 4 + 11 )
5768 }
5869
59- testNativeColumnStats [BooleanType .type ](
70+ testNativeColumnType [BooleanType .type ](
6071 BOOLEAN ,
6172 (buffer : ByteBuffer , v : Boolean ) => {
6273 buffer.put((if (v) 1 else 0 ).toByte)
@@ -65,37 +76,19 @@ class ColumnTypeSuite extends FunSuite {
6576 buffer.get() == 1
6677 })
6778
68- testNativeColumnStats[IntegerType .type ](
69- INT ,
70- (_ : ByteBuffer ).putInt(_),
71- (_ : ByteBuffer ).getInt)
72-
73- testNativeColumnStats[ShortType .type ](
74- SHORT ,
75- (_ : ByteBuffer ).putShort(_),
76- (_ : ByteBuffer ).getShort)
77-
78- testNativeColumnStats[LongType .type ](
79- LONG ,
80- (_ : ByteBuffer ).putLong(_),
81- (_ : ByteBuffer ).getLong)
82-
83- testNativeColumnStats[ByteType .type ](
84- BYTE ,
85- (_ : ByteBuffer ).put(_),
86- (_ : ByteBuffer ).get)
87-
88- testNativeColumnStats[DoubleType .type ](
89- DOUBLE ,
90- (_ : ByteBuffer ).putDouble(_),
91- (_ : ByteBuffer ).getDouble)
92-
93- testNativeColumnStats[FloatType .type ](
94- FLOAT ,
95- (_ : ByteBuffer ).putFloat(_),
96- (_ : ByteBuffer ).getFloat)
97-
98- testNativeColumnStats[StringType .type ](
79+ testNativeColumnType[IntegerType .type ](INT , _.putInt(_), _.getInt)
80+
81+ testNativeColumnType[ShortType .type ](SHORT , _.putShort(_), _.getShort)
82+
83+ testNativeColumnType[LongType .type ](LONG , _.putLong(_), _.getLong)
84+
85+ testNativeColumnType[ByteType .type ](BYTE , _.put(_), _.get)
86+
87+ testNativeColumnType[DoubleType .type ](DOUBLE , _.putDouble(_), _.getDouble)
88+
89+ testNativeColumnType[FloatType .type ](FLOAT , _.putFloat(_), _.getFloat)
90+
91+ testNativeColumnType[StringType .type ](
9992 STRING ,
10093 (buffer : ByteBuffer , string : String ) => {
10194 val bytes = string.getBytes()
@@ -108,7 +101,7 @@ class ColumnTypeSuite extends FunSuite {
108101 new String (bytes)
109102 })
110103
111- testColumnStats [BinaryType .type , Array [Byte ]](
104+ testColumnType [BinaryType .type , Array [Byte ]](
112105 BINARY ,
113106 (buffer : ByteBuffer , bytes : Array [Byte ]) => {
114107 buffer.putInt(bytes.length).put(bytes)
@@ -131,51 +124,58 @@ class ColumnTypeSuite extends FunSuite {
131124 val length = buffer.getInt()
132125 assert(length === serializedObj.length)
133126
134- val bytes = new Array [Byte ](length)
135- buffer.get(bytes, 0 , length)
136- assert(obj === SparkSqlSerializer .deserialize(bytes))
127+ expectResult(obj, " Deserialized object didn't equal to the original object" ) {
128+ val bytes = new Array [Byte ](length)
129+ buffer.get(bytes, 0 , length)
130+ SparkSqlSerializer .deserialize(bytes)
131+ }
137132
138133 buffer.rewind()
139134 buffer.putInt(serializedObj.length).put(serializedObj)
140135
141- buffer.rewind()
142- assert(obj === SparkSqlSerializer .deserialize(GENERIC .extract(buffer)))
136+ expectResult(obj, " Deserialized object didn't equal to the original object" ) {
137+ buffer.rewind()
138+ SparkSqlSerializer .deserialize(GENERIC .extract(buffer))
139+ }
143140 }
144141
145- def testNativeColumnStats [T <: NativeType ](
142+ def testNativeColumnType [T <: NativeType ](
146143 columnType : NativeColumnType [T ],
147144 putter : (ByteBuffer , T # JvmType ) => Unit ,
148145 getter : (ByteBuffer ) => T # JvmType ) {
149146
150- testColumnStats [T , T # JvmType ](columnType, putter, getter)
147+ testColumnType [T , T # JvmType ](columnType, putter, getter)
151148 }
152149
153- def testColumnStats [T <: DataType , JvmType ](
150+ def testColumnType [T <: DataType , JvmType ](
154151 columnType : ColumnType [T , JvmType ],
155152 putter : (ByteBuffer , JvmType ) => Unit ,
156153 getter : (ByteBuffer ) => JvmType ) {
157154
158155 val buffer = ByteBuffer .allocate(DEFAULT_BUFFER_SIZE )
159- val columnTypeName = columnType.getClass.getSimpleName.stripSuffix(" $" )
160156 val seq = (0 until 4 ).map(_ => makeRandomValue(columnType))
161157
162- test(s " $columnTypeName .extract " ) {
158+ test(s " $columnType .extract " ) {
163159 buffer.rewind()
164160 seq.foreach(putter(buffer, _))
165161
166162 buffer.rewind()
167- seq.foreach { i =>
168- assert(i === columnType.extract(buffer))
163+ seq.foreach { expected =>
164+ assert(
165+ expected === columnType.extract(buffer),
166+ " Extracted value didn't equal to the original one" )
169167 }
170168 }
171169
172- test(s " $columnTypeName .append " ) {
170+ test(s " $columnType .append " ) {
173171 buffer.rewind()
174172 seq.foreach(columnType.append(_, buffer))
175173
176174 buffer.rewind()
177- seq.foreach { i =>
178- assert(i === getter(buffer))
175+ seq.foreach { expected =>
176+ assert(
177+ expected === getter(buffer),
178+ " Extracted value didn't equal to the original one" )
179179 }
180180 }
181181 }
0 commit comments