10 changes: 4 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -223,20 +223,18 @@ class ImputerModel private[ml] (
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    var outputDF = dataset
     val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq
 
-    $(inputCols).zip($(outputCols)).zip(surrogates).foreach {
+    val newCols = $(inputCols).zip($(outputCols)).zip(surrogates).map {
       case ((inputCol, outputCol), surrogate) =>
         val inputType = dataset.schema(inputCol).dataType
         val ic = col(inputCol)
-        outputDF = outputDF.withColumn(outputCol,
-          when(ic.isNull, surrogate)
+        when(ic.isNull, surrogate)
Review comment (Contributor): style: indent

Reply (Member Author): This `when` is not a call chained from the previous line, so I don't think it needs the extra indent?

Reply (Contributor): Oh, I misread. The style is ok.
           .when(ic === $(missingValue), surrogate)
           .otherwise(ic)
-          .cast(inputType))
+          .cast(inputType)
     }
-    outputDF.toDF()
+    dataset.withColumns($(outputCols), newCols).toDF()
   }
 
   override def transformSchema(schema: StructType): StructType = {
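Side note on the Imputer change above: the old transform chained one withColumn call, and therefore one extra projection, per input column, while the new version builds every imputation expression up front and attaches them all in a single pass through the internal withColumns. Below is a minimal sketch of the same null/missing-value replacement pattern, written against the public API with invented column names and hard-coded surrogate values; the Seq-based withColumns used in the patch is private[spark], so the sketch reaches the same shape with one select.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, when}

object ImputeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("impute-sketch").getOrCreate()
    import spark.implicits._

    // Toy data: -1.0 stands in for $(missingValue); the surrogates (e.g. column means)
    // are hard-coded here instead of coming from the model's surrogateDF.
    val df = Seq((1.0, -1.0), (-1.0, 2.0), (3.0, 4.0)).toDF("a", "b")
    val surrogates = Map("a" -> 2.0, "b" -> 3.0)

    // One imputation expression per column, all applied in a single projection,
    // mirroring the one-pass Imputer.transform above.
    val imputedCols = surrogates.toSeq.map { case (name, surrogate) =>
      val ic = col(name)
      when(ic.isNull, surrogate)
        .when(ic === -1.0, surrogate)
        .otherwise(ic)
        .cast("double")
        .as(s"${name}_imputed")
    }

    df.select(col("*") +: imputedCols: _*).show()
    spark.stop()
  }
}
```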
42 changes: 30 additions & 12 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2083,22 +2083,40 @@ class Dataset[T] private[sql](
    * @group untypedrel
    * @since 2.0.0
    */
-  def withColumn(colName: String, col: Column): DataFrame = {
+  def withColumn(colName: String, col: Column): DataFrame = withColumns(Seq(colName), Seq(col))
+
+  /**
+   * Returns a new Dataset by adding columns or replacing the existing columns that has
+   * the same names.
+   */
+  private[spark] def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = {
+    require(colNames.size == cols.size,
+      s"The size of column names: ${colNames.size} isn't equal to " +
+        s"the size of columns: ${cols.size}")
+    SchemaUtils.checkColumnNameDuplication(
+      colNames,
+      "in given column names",
+      sparkSession.sessionState.conf.caseSensitiveAnalysis)
+
     val resolver = sparkSession.sessionState.analyzer.resolver
     val output = queryExecution.analyzed.output
-    val shouldReplace = output.exists(f => resolver(f.name, colName))
-    if (shouldReplace) {
-      val columns = output.map { field =>
-        if (resolver(field.name, colName)) {
-          col.as(colName)
-        } else {
-          Column(field)
-        }
+
+    val columnMap = colNames.zip(cols).toMap
+
+    val replacedAndExistingColumns = output.map { field =>
+      columnMap.find { case (colName, _) =>
+        resolver(field.name, colName)
+      } match {
+        case Some((colName: String, col: Column)) => col.as(colName)
+        case _ => Column(field)
       }
-      select(columns : _*)
-    } else {
-      select(Column("*"), col.as(colName))
     }
+
+    val newColumns = columnMap.filter { case (colName, col) =>
+      !output.exists(f => resolver(f.name, colName))
+    }.map { case (colName, col) => col.as(colName) }
+
+    select(replacedAndExistingColumns ++ newColumns : _*)
   }
 
   /**
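For readers skimming the new method: names in colNames that resolve to existing output attributes replace those columns in place (keeping their schema position), names that resolve to nothing are appended at the end, and duplicates within colNames are rejected up front according to spark.sql.caseSensitive. Here is a small sketch of that behaviour; the column names are invented, and because this withColumns overload is private[spark], the sketch uses chained public withColumn calls, ordered so that both expressions still see the original `x`, just as they would in the single-projection withColumns call.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object WithColumnsSemanticsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("withColumns-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (2, "b")).toDF("x", "y")

    // Intended equivalent of df.withColumns(Seq("x", "z"), Seq(col("x") + 1, col("x") * 10)):
    // "z" is appended first, computed from the original "x", and only then is "x"
    // replaced in place, so both expressions are based on the original column.
    val out = df
      .withColumn("z", col("x") * 10) // new name: appended after the existing columns
      .withColumn("x", col("x") + 1)  // existing name: replaced, position preserved

    out.printSchema() // x, y, z
    out.show()        // (2, "a", 10), (3, "b", 20)
    spark.stop()
  }
}
```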
52 changes: 52 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -641,6 +641,49 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(df.schema.map(_.name) === Seq("key", "value", "newCol"))
}

test("withColumns") {
val df = testData.toDF().withColumns(Seq("newCol1", "newCol2"),
Seq(col("key") + 1, col("key") + 2))
checkAnswer(
df,
testData.collect().map { case Row(key: Int, value: String) =>
Row(key, value, key + 1, key + 2)
}.toSeq)
assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCol2"))

val err = intercept[IllegalArgumentException] {
testData.toDF().withColumns(Seq("newCol1"),
Seq(col("key") + 1, col("key") + 2))
}
assert(
err.getMessage.contains("The size of column names: 1 isn't equal to the size of columns: 2"))

val err2 = intercept[AnalysisException] {
testData.toDF().withColumns(Seq("newCol1", "newCOL1"),
Seq(col("key") + 1, col("key") + 2))
}
assert(err2.getMessage.contains("Found duplicate column(s)"))
}

test("withColumns: case sensitive") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
val df = testData.toDF().withColumns(Seq("newCol1", "newCOL1"),
Seq(col("key") + 1, col("key") + 2))
checkAnswer(
df,
testData.collect().map { case Row(key: Int, value: String) =>
Row(key, value, key + 1, key + 2)
}.toSeq)
assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCOL1"))

val err = intercept[AnalysisException] {
testData.toDF().withColumns(Seq("newCol1", "newCol1"),
Seq(col("key") + 1, col("key") + 2))
}
assert(err.getMessage.contains("Found duplicate column(s)"))
}
}

test("replace column using withColumn") {
val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
val df3 = df2.withColumn("x", df2("x") + 1)
@@ -649,6 +692,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Row(2) :: Row(3) :: Row(4) :: Nil)
}

test("replace column using withColumns") {
val df2 = sparkContext.parallelize(Array((1, 2), (2, 3), (3, 4))).toDF("x", "y")
val df3 = df2.withColumns(Seq("x", "newCol1", "newCol2"),
Seq(df2("x") + 1, df2("y"), df2("y") + 1))
checkAnswer(
df3.select("x", "newCol1", "newCol2"),
Row(2, 2, 3) :: Row(3, 3, 4) :: Row(4, 4, 5) :: Nil)
}

test("drop column using drop") {
val df = testData.drop("key")
checkAnswer(