6 changes: 3 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/Transformer.scala

@@ -25,6 +25,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.UserDefinedFunction
 import org.apache.spark.sql.types._
 
 /**
@@ -90,7 +91,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
    * account of the embedded param map. So the param values should be determined solely by the input
    * param map.
    */
-  protected def createTransformFunc: IN => OUT
+  protected def transformFunc: UserDefinedFunction
 
   /**
    * Returns the data type of the output column.
@@ -115,8 +116,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
 
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    dataset.withColumn($(outputCol),
-      callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol))))
+    dataset.withColumn($(outputCol), this.transformFunc(col($(inputCol))))
   }
 
   override def copy(extra: ParamMap): T = defaultCopy(extra)
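For orientation, here is a minimal sketch of what the new contract asks of a subclass. The `StringReverser` class, its `strRev` uid, and the reversal logic are invented for illustration and are not part of this change; the point is that a subclass now packages its per-row logic as a `UserDefinedFunction` built with `udf`, whose SQL return type is derived from the Scala function type instead of being passed explicitly through `callUDF`.

```scala
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, StringType}

// Hypothetical subclass showing the new transformFunc contract.
class StringReverser(override val uid: String)
  extends UnaryTransformer[String, String, StringReverser] {

  def this() = this(Identifiable.randomUID("strRev"))

  // Per-row logic wrapped in a udf; the udf carries its own return type,
  // so transform() no longer threads outputDataType through callUDF.
  override protected def transformFunc: UserDefinedFunction = {
    udf { input: String => input.reverse }
  }

  // Still used by transformSchema to declare the output column's type.
  override protected def outputDataType: DataType = StringType
}
```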
14 changes: 9 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala

@@ -24,6 +24,8 @@ import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param.BooleanParam
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.UserDefinedFunction
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -57,11 +59,13 @@ class DCT(override val uid: String)
 
   setDefault(inverse -> false)
 
-  override protected def createTransformFunc: Vector => Vector = { vec =>
-    val result = vec.toArray
-    val jTransformer = new DoubleDCT_1D(result.length)
-    if ($(inverse)) jTransformer.inverse(result, true) else jTransformer.forward(result, true)
-    Vectors.dense(result)
+  override protected def transformFunc: UserDefinedFunction = {
+    udf { input: Vector =>
+      val result = input.toArray
+      val jTransformer = new DoubleDCT_1D(result.length)
+      if ($(inverse)) jTransformer.inverse(result, true) else jTransformer.forward(result, true)
+      Vectors.dense(result)
+    }
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {
10 changes: 7 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala

@@ -23,6 +23,8 @@ import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.UserDefinedFunction
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -49,10 +51,12 @@ class ElementwiseProduct(override val uid: String)
   /** @group getParam */
   def getScalingVec: Vector = getOrDefault(scalingVec)
 
-  override protected def createTransformFunc: Vector => Vector = {
+  override protected def transformFunc: UserDefinedFunction = {
     require(params.contains(scalingVec), s"transformation requires a weight vector")
-    val elemScaler = new feature.ElementwiseProduct($(scalingVec))
-    elemScaler.transform
+    udf { input: Vector =>
+      val elemScaler = new feature.ElementwiseProduct($(scalingVec))
+      elemScaler.transform(input)
+    }
   }
 
   override protected def outputDataType: DataType = new VectorUDT()
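From the caller's side nothing changes, since `UnaryTransformer.transform` hides the switch to a `UserDefinedFunction`. A usage sketch under Spark 1.x-era APIs (the `sqlContext`, column names, and values here are invented):

```scala
import org.apache.spark.ml.feature.ElementwiseProduct
import org.apache.spark.mllib.linalg.Vectors

// Toy input: an id column plus a vector column to rescale.
val df = sqlContext.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 2.0, 3.0)),
  (1, Vectors.dense(4.0, 5.0, 6.0))
)).toDF("id", "features")

val scaler = new ElementwiseProduct()
  .setScalingVec(Vectors.dense(0.0, 1.0, 2.0))
  .setInputCol("features")
  .setOutputCol("scaled")

// transform() applies the udf returned by transformFunc to the input column.
scaler.transform(df).show()
```

One visible difference in the rewritten body: the `require` check still runs once, when `transformFunc` builds the udf, while the wrapped `feature.ElementwiseProduct` is now constructed inside the lambda, i.e. once per invocation rather than once per transform call.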
8 changes: 6 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala

@@ -21,6 +21,8 @@ import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.UserDefinedFunction
 import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
 
 /**
@@ -56,8 +58,10 @@ class NGram(override val uid: String)
 
   setDefault(n -> 2)
 
-  override protected def createTransformFunc: Seq[String] => Seq[String] = {
-    _.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
+  override protected def transformFunc: UserDefinedFunction = {
+    udf { input: Seq[String] =>
+      input.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
+    }
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {
10 changes: 7 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala

@@ -23,6 +23,8 @@ import org.apache.spark.ml.param.{DoubleParam, ParamValidators}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.UserDefinedFunction
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -50,9 +52,11 @@ class Normalizer(override val uid: String)
   /** @group setParam */
   def setP(value: Double): this.type = set(p, value)
 
-  override protected def createTransformFunc: Vector => Vector = {
-    val normalizer = new feature.Normalizer($(p))
-    normalizer.transform
+  override protected def transformFunc: UserDefinedFunction = {
+    udf { input: Vector =>
+      val normalizer = new feature.Normalizer($(p))
+      normalizer.transform(input)
+    }
   }
 
   override protected def outputDataType: DataType = new VectorUDT()
8 changes: 6 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala

@@ -24,7 +24,9 @@ import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg._
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.UserDefinedFunction
 
 /**
  * :: Experimental ::
@@ -56,8 +58,10 @@ class PolynomialExpansion(override val uid: String)
   /** @group setParam */
   def setDegree(value: Int): this.type = set(degree, value)
 
-  override protected def createTransformFunc: Vector => Vector = { v =>
-    PolynomialExpansion.expand(v, $(degree))
+  override protected def transformFunc: UserDefinedFunction = {
+    udf { input: Vector =>
+      PolynomialExpansion.expand(input, $(degree))
+    }
   }
 
   override protected def outputDataType: DataType = new VectorUDT()
22 changes: 14 additions & 8 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala

@@ -21,6 +21,8 @@ import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.UserDefinedFunction
 import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
 
 /**
@@ -35,8 +37,10 @@ class Tokenizer(override val uid: String)
 
   def this() = this(Identifiable.randomUID("tok"))
 
-  override protected def createTransformFunc: String => Seq[String] = {
-    _.toLowerCase.split("\\s")
+  override protected def transformFunc: UserDefinedFunction = {
+    udf { input: String =>
+      input.toLowerCase.split("\\s")
+    }
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {
@@ -124,12 +128,14 @@ class RegexTokenizer(override val uid: String)
 
   setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+", toLowercase -> true)
 
-  override protected def createTransformFunc: String => Seq[String] = { originStr =>
-    val re = $(pattern).r
-    val str = if ($(toLowercase)) originStr.toLowerCase() else originStr
-    val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq
-    val minLength = $(minTokenLength)
-    tokens.filter(_.length >= minLength)
+  override protected def transformFunc: UserDefinedFunction = {
+    udf { input: String =>
+      val re = $(pattern).r
+      val str = if ($(toLowercase)) input.toLowerCase() else input
+      val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq
+      val minLength = $(minTokenLength)
+      tokens.filter(_.length >= minLength)
+    }
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {
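Closing with a usage sketch for the regex tokenizer (input rows invented for illustration): behavior is unchanged, because the udf body reads `$(pattern)`, `$(gaps)`, `$(toLowercase)`, and `$(minTokenLength)` from the same param map the old `createTransformFunc` consulted.

```scala
import org.apache.spark.ml.feature.RegexTokenizer

// Toy sentences; sqlContext is assumed in scope as in Spark 1.x shells.
val sentences = sqlContext.createDataFrame(Seq(
  (0, "Logistic regression models are neat"),
  (1, "Tokenization   handles   repeated   whitespace")
)).toDF("id", "sentence")

val tokenizer = new RegexTokenizer()
  .setInputCol("sentence")
  .setOutputCol("tokens")
  .setPattern("\\s+")    // gaps = true by default, so this splits on whitespace runs
  .setMinTokenLength(2)  // drop single-character tokens

tokenizer.transform(sentences).select("tokens").show(false)
```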