
Merge pull request #147 from JetBrains/scala-case-class-encoding
Rewrote product encoding to support scala case classes
Jolanrensen authored Apr 19, 2022
2 parents d62e3af + 2a409c6 commit aa6d3e5
Showing 4 changed files with 112 additions and 9 deletions.
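
For context (not part of the commit itself), here is a minimal sketch of what this change enables, mirroring the tests added below: a Scala case class can now be used directly as a Dataset element type from Kotlin. It assumes this repo's test classpath, with the `DemoCaseClass` helper added in this commit and the library's `withSpark`/`toDS` entry points.

```kotlin
import org.jetbrains.kotlinx.spark.api.*
import org.jetbrains.kotlinx.spark.extensions.DemoCaseClass

fun main() = withSpark {
    // Scala case classes now encode like Kotlin data classes,
    // including when nested inside data classes (and vice versa).
    val dataset = listOf(
        DemoCaseClass(1, "1"),
        DemoCaseClass(2, "2"),
        DemoCaseClass(3, "3"),
    ).toDS()
    dataset.show()
}
```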
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.expressions.{Expression, _}
 import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
-import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection, WalkedTypePath}
+import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, InternalRow, ScalaReflection, WalkedTypePath}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 import org.apache.spark.util.Utils
@@ -42,11 +42,12 @@ import java.lang.Exception
  * for classes whose fields are entirely defined by constructor params but should not be
  * case classes.
  */
-trait DefinedByConstructorParams
+//trait DefinedByConstructorParams
 
 /**
  * KotlinReflection is heavily inspired by ScalaReflection and even extends it just to add several methods
  */
+//noinspection RedundantBlock
 object KotlinReflection extends KotlinReflection {
   /**
    * Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping
@@ -916,9 +917,18 @@ object KotlinReflection extends KotlinReflection {
       }
       //</editor-fold>
 
-      case _ if predefinedDt.isDefined => {
+      // Kotlin specific cases
+      case t if predefinedDt.isDefined => {
 
+        // if (seenTypeSet.contains(t)) {
+        //   throw new UnsupportedOperationException(
+        //     s"cannot have circular references in class, but got the circular reference of class $t"
+        //   )
+        // }
+
         predefinedDt.get match {
+
+          // Kotlin data class
           case dataType: KDataTypeWrapper => {
             val cls = dataType.cls
             val properties = getJavaBeanReadableProperties(cls)
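
Note on the change above: the locally declared `DefinedByConstructorParams` trait is commented out in favour of the identically named trait now imported from `org.apache.spark.sql.catalyst`, so this file and Spark's own `ScalaReflection` agree on a single marker trait. Scala case classes themselves are recognized through `scala.Product`, which every case class implements. A minimal sketch of that fact (assuming this repo's test classpath, with the `DemoCaseClass` added in the next file):

```kotlin
import org.jetbrains.kotlinx.spark.extensions.DemoCaseClass
import scala.Product

fun main() {
    // Every Scala case class implements scala.Product, the marker
    // that the schema() change in the Kotlin file below keys on.
    val cc: Product = DemoCaseClass(1, "one")
    println(cc.productArity())    // 2
    println(cc.productElement(0)) // 1
}
```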
@@ -0,0 +1,3 @@
+package org.jetbrains.kotlinx.spark.extensions
+
+case class DemoCaseClass[T](a: Int, b: T)
@@ -271,8 +271,18 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
             KDataTypeWrapper(structType, klass.java, true)
         }
         klass.isSubclassOf(Product::class) -> {
-            val params = type.arguments.mapIndexed { i, it ->
-                "_${i + 1}" to it.type!!
+
+            // create map from T1, T2 to Int, String etc.
+            val typeMap = klass.constructors.first().typeParameters.map { it.name }
+                .zip(
+                    type.arguments.map { it.type }
+                )
+                .toMap()
+
+            // collect params by name and actual type
+            val params = klass.constructors.first().parameters.map {
+                val typeName = it.type.toString().replace("!", "")
+                it.name to (typeMap[typeName] ?: it.type)
             }
 
             val structType = DataTypes.createStructType(
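
To make the new type-resolution logic above concrete: the type-parameter names (`T`, `T1`, `T2`, ...) are zipped with the actual type arguments of the requested `KType`, and each constructor parameter whose declared type is a bare type variable is swapped for the concrete argument; `replace("!", "")` strips the `!` that Kotlin reflection appends to flexible (platform) types, which is how Scala-declared types surface. A self-contained, hedged sketch of the same technique outside Spark (`Box` is a hypothetical stand-in, `klass.typeParameters` stands in for the constructor's type parameters, and kotlin-reflect must be on the classpath):

```kotlin
import kotlin.reflect.typeOf

// Hypothetical stand-in for a generic Scala Product class.
class Box<T>(val a: Int, val b: T)

fun main() {
    val type = typeOf<Box<String>>()
    val klass = Box::class

    // Map type-parameter names ("T") to the actual arguments (kotlin.String).
    val typeMap = klass.typeParameters.map { it.name }
        .zip(type.arguments.map { it.type })
        .toMap()

    // Resolve each constructor parameter; a bare "T" becomes kotlin.String.
    // (For a pure Kotlin class there is no "!" to strip; it is kept here
    // to mirror the diff, where Scala types appear as platform types.)
    val params = klass.constructors.first().parameters.map {
        val typeName = it.type.toString().replace("!", "")
        it.name to (typeMap[typeName] ?: it.type)
    }
    println(params) // [(a, kotlin.Int), (b, kotlin.String)]
}
```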
@@ -27,10 +27,8 @@ import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.unsafe.types.CalendarInterval
 import org.jetbrains.kotlinx.spark.api.tuples.*
-import scala.Product
-import scala.Tuple1
-import scala.Tuple2
-import scala.Tuple3
+import org.jetbrains.kotlinx.spark.extensions.DemoCaseClass
+import scala.*
 import java.math.BigDecimal
 import java.sql.Date
 import java.sql.Timestamp
@@ -180,6 +178,88 @@ class EncodingTest : ShouldSpec({
     context("schema") {
         withSpark(props = mapOf("spark.sql.codegen.comments" to true)) {
 
+            should("handle Scala Case class datasets") {
+                val caseClasses = listOf(
+                    DemoCaseClass(1, "1"),
+                    DemoCaseClass(2, "2"),
+                    DemoCaseClass(3, "3"),
+                )
+                val dataset = caseClasses.toDS()
+                dataset.show()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
+            should("handle Scala Case class with data class datasets") {
+                val caseClasses = listOf(
+                    DemoCaseClass(1, "1" to 1L),
+                    DemoCaseClass(2, "2" to 2L),
+                    DemoCaseClass(3, "3" to 3L),
+                )
+                val dataset = caseClasses.toDS()
+                dataset.show()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
+            should("handle data class with Scala Case class datasets") {
+                val caseClasses = listOf(
+                    1 to DemoCaseClass(1, "1"),
+                    2 to DemoCaseClass(2, "2"),
+                    3 to DemoCaseClass(3, "3"),
+                )
+                val dataset = caseClasses.toDS()
+                dataset.show()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
+            should("handle data class with Scala Case class & deeper datasets") {
+                val caseClasses = listOf(
+                    1 to DemoCaseClass(1, "1" to DemoCaseClass(1, 1.0)),
+                    2 to DemoCaseClass(2, "2" to DemoCaseClass(2, 2.0)),
+                    3 to DemoCaseClass(3, "3" to DemoCaseClass(3, 3.0)),
+                )
+                val dataset = caseClasses.toDS()
+                dataset.show()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
+
+            xshould("handle Scala Option datasets") {
+                val caseClasses = listOf(Some(1), Some(2), Some(3))
+                val dataset = caseClasses.toDS()
+                dataset.show()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
+            xshould("handle Scala Option Option datasets") {
+                val caseClasses = listOf(
+                    Some(Some(1)),
+                    Some(Some(2)),
+                    Some(Some(3)),
+                )
+                val dataset = caseClasses.toDS()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
+            xshould("handle data class Scala Option datasets") {
+                val caseClasses = listOf(
+                    Some(1) to Some(2),
+                    Some(3) to Some(4),
+                    Some(5) to Some(6),
+                )
+                val dataset = caseClasses.toDS()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
+            xshould("handle Scala Option data class datasets") {
+                val caseClasses = listOf(
+                    Some(1 to 2),
+                    Some(3 to 4),
+                    Some(5 to 6),
+                )
+                val dataset = caseClasses.toDS()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
             should("collect data classes with doubles correctly") {
                 val ll1 = LonLat(1.0, 2.0)
                 val ll2 = LonLat(3.0, 4.0)
