Skip to content

Commit

Permalink
Refactor optional column operations (#696)
Browse files Browse the repository at this point in the history
  • Loading branch information
cchantep authored Mar 16, 2023
1 parent a27b04b commit 65fad0a
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 16 deletions.
63 changes: 48 additions & 15 deletions dataset/src/main/scala/frameless/TypedColumn.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package frameless

import frameless.functions.{litAggr, lit => flit}
import frameless.syntax._

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.DecimalType
import org.apache.spark.sql.{Column, FramelessInternals}

import shapeless._
import shapeless.ops.record.Selector

Expand All @@ -27,9 +29,8 @@ sealed class TypedColumn[T, U](expr: Expression)(

type ThisType[A, B] = TypedColumn[A, B]

def this(column: Column)(implicit uencoder: TypedEncoder[U]) = {
def this(column: Column)(implicit uencoder: TypedEncoder[U]) =
this(FramelessInternals.expr(column))
}

override def typed[W, U1: TypedEncoder](c: Column): TypedColumn[W, U1] = c.typedColumn

Expand Down Expand Up @@ -140,8 +141,9 @@ abstract class AbstractTypedColumn[T, U]
equalsTo(other)

/** Inequality test.
*
* {{{
* df.filter( df.col('a) =!= df.col('b) )
* df.filter(df.col('a) =!= df.col('b))
* }}}
*
* apache/spark
Expand All @@ -150,28 +152,28 @@ abstract class AbstractTypedColumn[T, U]
typed(Not(equalsTo(other).expr))

/** Inequality test.
*
* {{{
* df.filter( df.col('a) =!= "a" )
* df.filter(df.col('a) =!= "a")
* }}}
*
* apache/spark
*/
def =!=(u: U): ThisType[T, Boolean] =
typed(Not(equalsTo(lit(u)).expr))
def =!=(u: U): ThisType[T, Boolean] = typed(Not(equalsTo(lit(u)).expr))

/** True if the current expression is an Option and it's None.
*
* apache/spark
*/
def isNone(implicit i0: U <:< Option[_]): ThisType[T, Boolean] =
equalsTo[T, T](lit[U](None.asInstanceOf[U]))
typed(IsNull(expr))

/** True if the current expression is an Option and it's not None.
*
* apache/spark
*/
def isNotNone(implicit i0: U <:< Option[_]): ThisType[T, Boolean] =
typed(Not(equalsTo(lit(None.asInstanceOf[U])).expr))
typed(IsNotNull(expr))

/** True if the current expression is a fractional number and is not NaN.
*
Expand All @@ -180,15 +182,43 @@ abstract class AbstractTypedColumn[T, U]
def isNaN(implicit n: CatalystNaN[U]): ThisType[T, Boolean] =
typed(self.untyped.isNaN)

/** Convert an Optional column by providing a default value
/**
* True if the value for this optional column `exists` as expected
* (see `Option.exists`).
*
* {{{
* df.col('opt).isSome(_ === someOtherCol)
* }}}
*/
def isSome[V](exists: ThisType[T, V] => ThisType[T, Boolean])(implicit i0: U <:< Option[V]): ThisType[T, Boolean] = someOr[V](exists, false)

/**
* True if the value for this optional column `exists` as expected,
* or is `None`. (see `Option.forall`).
*
* {{{
* df.col('opt).isSomeOrNone(_ === someOtherCol)
* }}}
*/
def isSomeOrNone[V](exists: ThisType[T, V] => ThisType[T, Boolean])(implicit i0: U <:< Option[V]): ThisType[T, Boolean] = someOr[V](exists, true)

private def someOr[V](exists: ThisType[T, V] => ThisType[T, Boolean], default: Boolean)(implicit i0: U <:< Option[V]): ThisType[T, Boolean] = {
val defaultExpr = if (default) Literal.TrueLiteral else Literal.FalseLiteral

typed(Coalesce(Seq(opt(i0).map(exists).expr, defaultExpr)))
}

/** Convert an Optional column by providing a default value.
*
* {{{
* df( df('opt).getOrElse(df('defaultValue)) )
* df(df('opt).getOrElse(df('defaultValue)))
* }}}
*/
def getOrElse[TT, W, Out](default: ThisType[TT, Out])(implicit i0: U =:= Option[Out], i1: With.Aux[T, TT, W]): ThisType[W, Out] =
typed(Coalesce(Seq(expr, default.expr)))(default.uencoder)

/** Convert an Optional column by providing a default value
/** Convert an Optional column by providing a default value.
*
* {{{
* df( df('opt).getOrElse(defaultConstant) )
* }}}
Expand All @@ -197,6 +227,7 @@ abstract class AbstractTypedColumn[T, U]
getOrElse(lit[Out](default))

/** Sum of this expression and another expression.
*
* {{{
* // The following selects the sum of a person's height and weight.
* people.select( people.col('height) plus people.col('weight) )
Expand Down Expand Up @@ -700,9 +731,10 @@ abstract class AbstractTypedColumn[T, U]
or(other)

/** Less than.
*
* {{{
* // The following selects people younger than the maxAge column.
* df.select( df('age) < df('maxAge) )
* // The following selects people younger than the maxAge column.
* df.select(df('age) < df('maxAge) )
* }}}
*
* @param other another column of the same type
Expand All @@ -712,9 +744,10 @@ abstract class AbstractTypedColumn[T, U]
typed(self.untyped < other.untyped)

/** Less than or equal to.
*
* {{{
* // The following selects people younger or equal than the maxAge column.
* df.select( df('age) <= df('maxAge)
* // The following selects people younger or equal than the maxAge column.
* df.select(df('age) <= df('maxAge)
* }}}
*
* @param other another column of the same type
Expand Down
1 change: 1 addition & 0 deletions dataset/src/main/scala/frameless/functions/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ package object functions extends Udf with UnaryFunctions {

if (ScalaReflection.isNativeType(encoder.jvmRepr) && encoder.catalystRepr == encoder.jvmRepr) {
val expr = Literal(value, encoder.catalystRepr)

new TypedColumn(expr)
} else {
val expr = new Literal(value, encoder.jvmRepr)
Expand Down
22 changes: 21 additions & 1 deletion dataset/src/test/scala/frameless/FilterTests.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package frameless

import org.scalatest.matchers.should.Matchers

import org.scalacheck.Prop
import org.scalacheck.Prop._

class FilterTests extends TypedDatasetSuite {
final class FilterTests extends TypedDatasetSuite with Matchers {
test("filter('a == lit(b))") {
def prop[A: TypedEncoder](elem: A, data: Vector[X1[A]])(implicit ex1: TypedEncoder[X1[A]]): Prop = {
val dataset = TypedDataset.create(data)
Expand Down Expand Up @@ -145,6 +147,24 @@ class FilterTests extends TypedDatasetSuite {
check(forAll(prop[Option[X1[X1[Vector[Option[Int]]]]]] _))
}

test("Option content filter") {
val data = (Option(1L), Option(2L)) :: (Option(0L), Option(1L)) :: (None, None) :: Nil

val ds = TypedDataset.create(data)

val l = functions.lit[Long, (Option[Long], Option[Long])](0L)
val exists = ds('_1).isSome[Long](_ <= l)
val forall = ds('_1).isSomeOrNone[Long](_ <= l)

ds.select(exists).collect().run() shouldEqual Seq(false, true, false)
ds.select(forall).collect().run() shouldEqual Seq(false, true, true)

ds.filter(exists).collect().run() shouldEqual Seq(Option(0L) -> Option(1L))

ds.filter(forall).collect().run() shouldEqual Seq(
Option(0L) -> Option(1L), (None -> None))
}

test("filter with isin values") {
def prop[A: TypedEncoder](data: Vector[X1[A]], values: Vector[A])(implicit a : CatalystIsin[A]): Prop = {
val ds = TypedDataset.create(data)
Expand Down

0 comments on commit 65fad0a

Please sign in to comment.