Commit f4d2cbb

update

1 parent 55d3178 commit f4d2cbb

5 files changed, +150 -76 lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala

Lines changed: 17 additions & 9 deletions

@@ -27,8 +27,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{linalg => newlinalg}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.catalyst.util.{DoubleArrayData, IntArrayData}
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData}
 import org.apache.spark.sql.types._

 /**
@@ -194,9 +193,9 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
         row.setByte(0, 0)
         row.setInt(1, sm.numRows)
         row.setInt(2, sm.numCols)
-        row.update(3, new IntArrayData(sm.colPtrs))
-        row.update(4, new IntArrayData(sm.rowIndices))
-        row.update(5, new DoubleArrayData(sm.values))
+        row.update(3, UnsafeArrayData.fromPrimitiveArray(sm.colPtrs))
+        row.update(4, UnsafeArrayData.fromPrimitiveArray(sm.rowIndices))
+        row.update(5, UnsafeArrayData.fromPrimitiveArray(sm.values))
         row.setBoolean(6, sm.isTransposed)

       case dm: DenseMatrix =>
@@ -205,7 +204,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
         row.setInt(2, dm.numCols)
         row.setNullAt(3)
         row.setNullAt(4)
-        row.update(5, new DoubleArrayData(dm.values))
+        row.update(5, UnsafeArrayData.fromPrimitiveArray(dm.values))
         row.setBoolean(6, dm.isTransposed)
     }
     row
@@ -219,12 +218,21 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
         val tpe = row.getByte(0)
         val numRows = row.getInt(1)
         val numCols = row.getInt(2)
-        val values = row.getArray(5).toDoubleArray()
+        val values = row.getArray(5) match {
+          case u: UnsafeArrayData => u.toPrimitiveDoubleArray
+          case a => a.toDoubleArray()
+        }
         val isTransposed = row.getBoolean(6)
         tpe match {
           case 0 =>
-            val colPtrs = row.getArray(3).toIntArray()
-            val rowIndices = row.getArray(4).toIntArray()
+            val colPtrs = row.getArray(3) match {
+              case u: UnsafeArrayData => u.toPrimitiveIntArray
+              case a => a.toIntArray()
+            }
+            val rowIndices = row.getArray(4) match {
+              case u: UnsafeArrayData => u.toPrimitiveIntArray
+              case a => a.toIntArray()
+            }
             new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
           case 1 =>
             new DenseMatrix(numRows, numCols, values, isTransposed)
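
For illustration, a minimal round trip through MatrixUDT after this change (a sketch, not part of the commit; MatrixUDT is private[spark], so code like this would have to live inside Spark's own packages or test suites):

import org.apache.spark.mllib.linalg.{MatrixUDT, SparseMatrix}

// A 3x2 CSC matrix: colPtrs has numCols + 1 entries; rowIndices/values have one entry per nonzero.
val sm = new SparseMatrix(3, 2, Array(0, 1, 2), Array(0, 2), Array(1.0, 2.0))
val udt = new MatrixUDT
// serialize now stores colPtrs/rowIndices/values as UnsafeArrayData, and
// deserialize takes the toPrimitiveIntArray/toPrimitiveDoubleArray fast path.
val roundTripped = udt.deserialize(udt.serialize(sm))
assert(roundTripped.toArray.toSeq == sm.toArray.toSeq)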

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 16 additions & 8 deletions

@@ -33,8 +33,7 @@ import org.apache.spark.annotation.{AlphaComponent, Since}
 import org.apache.spark.ml.{linalg => newlinalg}
 import org.apache.spark.mllib.util.NumericParser
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.catalyst.util.{DoubleArrayData, IntArrayData}
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData}
 import org.apache.spark.sql.types._

 /**
@@ -216,15 +215,15 @@ class VectorUDT extends UserDefinedType[Vector] {
         val row = new GenericMutableRow(4)
         row.setByte(0, 0)
         row.setInt(1, size)
-        row.update(2, new IntArrayData(indices))
-        row.update(3, new DoubleArrayData(values))
+        row.update(2, UnsafeArrayData.fromPrimitiveArray(indices))
+        row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
         row
       case DenseVector(values) =>
         val row = new GenericMutableRow(4)
         row.setByte(0, 1)
         row.setNullAt(1)
         row.setNullAt(2)
-        row.update(3, new DoubleArrayData(values))
+        row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
         row
     }
   }
@@ -238,11 +237,20 @@ class VectorUDT extends UserDefinedType[Vector] {
         tpe match {
           case 0 =>
             val size = row.getInt(1)
-            val indices = row.getArray(2).toIntArray()
-            val values = row.getArray(3).toDoubleArray()
+            val indices = row.getArray(2) match {
+              case u: UnsafeArrayData => u.toPrimitiveIntArray
+              case a => a.toIntArray()
+            }
+            val values = row.getArray(3) match {
+              case u: UnsafeArrayData => u.toPrimitiveDoubleArray
+              case a => a.toDoubleArray()
+            }
             new SparseVector(size, indices, values)
           case 1 =>
-            val values = row.getArray(3).toDoubleArray()
+            val values = row.getArray(3) match {
+              case u: UnsafeArrayData => u.toPrimitiveDoubleArray
+              case a => a.toDoubleArray()
+            }
             new DenseVector(values)
         }
     }
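
The same round trip on the vector side, again as a sketch rather than part of the commit (this only assumes a build containing this change):

import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, VectorUDT}

val udt = new VectorUDT
val sv = new SparseVector(5, Array(1, 3), Array(10.0, 30.0))
val dv = new DenseVector(Array(1.0, 2.0, 3.0))
// Both branches now write indices/values with UnsafeArrayData.fromPrimitiveArray
// and read them back through the primitive-array fast path when possible.
assert(udt.deserialize(udt.serialize(sv)) == sv)
assert(udt.deserialize(udt.serialize(dv)) == dv)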

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java

Lines changed: 59 additions & 1 deletion

@@ -81,7 +81,7 @@ private void assertIndexIsValid(int ordinal) {
   }

   public Object[] array() {
-    throw new UnsupportedOperationException("Only supported on GenericArrayData.");
+    throw new UnsupportedOperationException("Not supported on UnsafeArrayData.");
   }

   /**
@@ -336,4 +336,62 @@ public UnsafeArrayData copy() {
     arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
     return arrayCopy;
   }
+
+  public int[] toPrimitiveIntArray() {
+    int[] result = new int[numElements];
+    Platform.copyMemory(baseObject, baseOffset + 4 + 4 * numElements,
+      result, Platform.INT_ARRAY_OFFSET, 4 * numElements);
+    return result;
+  }
+
+  public double[] toPrimitiveDoubleArray() {
+    double[] result = new double[numElements];
+    Platform.copyMemory(baseObject, baseOffset + 4 + 4 * numElements,
+      result, Platform.DOUBLE_ARRAY_OFFSET, 8 * numElements);
+    return result;
+  }
+
+  public static UnsafeArrayData fromPrimitiveArray(int[] arr) {
+    int offsetRegionSize = 4 * arr.length;
+    int valueRegionSize = 4 * arr.length;
+    int totalSize = 4 + offsetRegionSize + valueRegionSize;
+    byte[] data = new byte[totalSize];
+
+    Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
+
+    int elementOffsetStart = 4 + offsetRegionSize;
+    for (int i = 0; i < arr.length; i++) {
+      Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET + 4 + i * 4, elementOffsetStart + i * 4);
+    }
+
+    Platform.copyMemory(arr, Platform.INT_ARRAY_OFFSET, data,
+      Platform.BYTE_ARRAY_OFFSET + elementOffsetStart, valueRegionSize);
+
+    UnsafeArrayData result = new UnsafeArrayData();
+    result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize);
+    return result;
+  }
+
+  public static UnsafeArrayData fromPrimitiveArray(double[] arr) {
+    int offsetRegionSize = 4 * arr.length;
+    int valueRegionSize = 8 * arr.length;
+    int totalSize = 4 + offsetRegionSize + valueRegionSize;
+    byte[] data = new byte[totalSize];
+
+    Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
+
+    int elementOffsetStart = 4 + offsetRegionSize;
+    for (int i = 0; i < arr.length; i++) {
+      Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET + 4 + i * 4, elementOffsetStart + i * 8);
+    }
+
+    Platform.copyMemory(arr, Platform.DOUBLE_ARRAY_OFFSET, data,
+      Platform.BYTE_ARRAY_OFFSET + elementOffsetStart, valueRegionSize);
+
+    UnsafeArrayData result = new UnsafeArrayData();
+    result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize);
+    return result;
+  }
+
+  // TODO: add more specialized methods.
 }
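
The layout written by fromPrimitiveArray is a 4-byte element count, then one 4-byte offset per element, then the packed values. A worked sizing example for Array(1, 10, 100), mirroring the arithmetic above (sketch only):

val numElements = 3
val offsetRegionSize = 4 * numElements                  // 12 bytes: one int offset per element
val valueRegionSize = 4 * numElements                   // 12 bytes of int values (8 * n for doubles)
val totalSize = 4 + offsetRegionSize + valueRegionSize  // 28 bytes, what getSizeInBytes reports
// Stored offsets for elements 0..2 are 16, 20, 24 (4 + offsetRegionSize + i * 4),
// which is why the test suite below asserts getSizeInBytes == 4 + 4 * 3 + 4 * 3.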

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala

Lines changed: 0 additions & 58 deletions

@@ -137,61 +137,3 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {
     result
   }
 }
-
-abstract class SpecializedArrayData extends ArrayData {
-  // Primitive arrays can't haven null elements.
-  override def isNullAt(ordinal: Int): Boolean = false
-
-  private def fail() = {
-    throw new UnsupportedOperationException(
-      "Specialized array data should implement its corresponding get method")
-  }
-
-  override def get(ordinal: Int, elementType: DataType): AnyRef = fail()
-  override def getBoolean(ordinal: Int): Boolean = fail()
-  override def getByte(ordinal: Int): Byte = fail()
-  override def getShort(ordinal: Int): Short = fail()
-  override def getInt(ordinal: Int): Int = fail()
-  override def getLong(ordinal: Int): Long = fail()
-  override def getFloat(ordinal: Int): Float = fail()
-  override def getDouble(ordinal: Int): Double = fail()
-  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = fail()
-  override def getUTF8String(ordinal: Int): UTF8String = fail()
-  override def getBinary(ordinal: Int): Array[Byte] = fail()
-  override def getInterval(ordinal: Int): CalendarInterval = fail()
-  override def getStruct(ordinal: Int, numFields: Int): InternalRow = fail()
-  override def getArray(ordinal: Int): ArrayData = fail()
-  override def getMap(ordinal: Int): MapData = fail()
-}
-
-class IntArrayData(val values: Array[Int]) extends SpecializedArrayData {
-
-  override def array(): Array[Any] = values.map(_.asInstanceOf[Any])
-
-  override def numElements(): Int = values.length
-
-  override def get(ordinal: Int, elementType: DataType): AnyRef =
-    values(ordinal).asInstanceOf[AnyRef]
-
-  override def getInt(ordinal: Int): Int = values(ordinal)
-
-  override def toIntArray(): Array[Int] = values
-
-  override def copy(): IntArrayData = new IntArrayData(values.clone())
-}
-
-class DoubleArrayData(val values: Array[Double]) extends SpecializedArrayData {
-
-  override def array(): Array[Any] = values.map(_.asInstanceOf[Any])
-
-  override def numElements(): Int = values.length
-
-  override def get(ordinal: Int, elementType: DataType): AnyRef =
-    values(ordinal).asInstanceOf[AnyRef]
-
-  override def getDouble(ordinal: Int): Double = values(ordinal)
-
-  override def toDoubleArray(): Array[Double] = values
-
-  override def copy(): DoubleArrayData = new DoubleArrayData(values.clone())
-}
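
The deleted wrappers existed to dodge the boxing that GenericArrayData forces on primitives; UnsafeArrayData.fromPrimitiveArray now covers that without a parallel class hierarchy. A rough illustration of the overhead being avoided (illustrative only, not from this commit):

import org.apache.spark.sql.catalyst.util.GenericArrayData

val primitives = Array.tabulate(1000)(_.toDouble)
// GenericArrayData holds Array[Any], so every double is boxed on the way in...
val boxed = new GenericArrayData(primitives.map(_.asInstanceOf[Any]))
// ...and unboxed element by element on the way out.
val copy = boxed.toDoubleArray()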
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala

Lines changed: 58 additions & 0 deletions

@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
+
+class UnsafeArraySuite extends SparkFunSuite {
+
+  test("from primitive int array") {
+    val array = Array(1, 10, 100)
+    val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
+    assert(unsafe.numElements == 3)
+    assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 4 * 3)
+    assert(unsafe.getInt(0) == 1)
+    assert(unsafe.getInt(1) == 10)
+    assert(unsafe.getInt(2) == 100)
+  }
+
+  test("from primitive double array") {
+    val array = Array(1.1, 2.2, 3.3)
+    val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
+    assert(unsafe.numElements == 3)
+    assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 8 * 3)
+    assert(unsafe.getDouble(0) == 1.1)
+    assert(unsafe.getDouble(1) == 2.2)
+    assert(unsafe.getDouble(2) == 3.3)
+  }
+
+  test("to primitive int array") {
+    val array = Array(1, 10, 100)
+    val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
+    val array2 = unsafe.toPrimitiveIntArray
+    assert(array.toSeq == array2.toSeq)
+  }
+
+  test("to primitive double array") {
+    val array = Array(1.1, 2.2, 3.3)
+    val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
+    val array2 = unsafe.toPrimitiveDoubleArray
+    assert(array.toSeq == array2.toSeq)
+  }
+}
