Refactor optional column operations #696

Merged: 1 commit, merged Mar 16, 2023
63 changes: 48 additions & 15 deletions dataset/src/main/scala/frameless/TypedColumn.scala
@@ -2,9 +2,11 @@ package frameless

import frameless.functions.{litAggr, lit => flit}
import frameless.syntax._

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.DecimalType
import org.apache.spark.sql.{Column, FramelessInternals}

import shapeless._
import shapeless.ops.record.Selector

@@ -27,9 +29,8 @@ sealed class TypedColumn[T, U](expr: Expression)(

type ThisType[A, B] = TypedColumn[A, B]

def this(column: Column)(implicit uencoder: TypedEncoder[U]) = {
def this(column: Column)(implicit uencoder: TypedEncoder[U]) =
this(FramelessInternals.expr(column))
}

override def typed[W, U1: TypedEncoder](c: Column): TypedColumn[W, U1] = c.typedColumn

@@ -140,8 +141,9 @@ abstract class AbstractTypedColumn[T, U]
equalsTo(other)

/** Inequality test.
*
* {{{
* df.filter( df.col('a) =!= df.col('b) )
* df.filter(df.col('a) =!= df.col('b))
* }}}
*
* apache/spark
@@ -150,28 +152,28 @@ abstract class AbstractTypedColumn[T, U]
typed(Not(equalsTo(other).expr))

/** Inequality test.
*
* {{{
* df.filter( df.col('a) =!= "a" )
* df.filter(df.col('a) =!= "a")
* }}}
*
* apache/spark
*/
def =!=(u: U): ThisType[T, Boolean] =
typed(Not(equalsTo(lit(u)).expr))
def =!=(u: U): ThisType[T, Boolean] = typed(Not(equalsTo(lit(u)).expr))

/** True if the current expression is an Option and it's None.
*
* apache/spark
*/
def isNone(implicit i0: U <:< Option[_]): ThisType[T, Boolean] =
equalsTo[T, T](lit[U](None.asInstanceOf[U]))
typed(IsNull(expr))

/** True if the current expression is an Option and it's not None.
*
* apache/spark
*/
def isNotNone(implicit i0: U <:< Option[_]): ThisType[T, Boolean] =
typed(Not(equalsTo(lit(None.asInstanceOf[U])).expr))
typed(IsNotNull(expr))
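
For context on this refactor: isNone and isNotNone previously compared the column against an encoded None literal; they now map straight onto Catalyst's IsNull and IsNotNull predicates, so the null check no longer depends on equality semantics for null values. A minimal usage sketch, not part of this diff (the Person case class and field names are made up; an implicit SparkSession is assumed, as in the frameless test suites):

    case class Person(name: String, nickname: Option[String])

    val people = TypedDataset.create(Seq(
      Person("Ada", Some("ada")),
      Person("Grace", None)))

    // isNone compiles to IsNull(expr), isNotNone to IsNotNull(expr)
    people.filter(people('nickname).isNone).collect().run()    // keeps only Person("Grace", None)
    people.filter(people('nickname).isNotNone).collect().run() // keeps only Person("Ada", Some("ada"))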

/** True if the current expression is a fractional number and is not NaN.
*
@@ -180,15 +182,43 @@ abstract class AbstractTypedColumn[T, U]
def isNaN(implicit n: CatalystNaN[U]): ThisType[T, Boolean] =
typed(self.untyped.isNaN)

/** Convert an Optional column by providing a default value
/**
* True if the value for this optional column `exists` as expected
* (see `Option.exists`).
*
* {{{
* df.col('opt).isSome(_ === someOtherCol)
* }}}
*/
def isSome[V](exists: ThisType[T, V] => ThisType[T, Boolean])(implicit i0: U <:< Option[V]): ThisType[T, Boolean] = someOr[V](exists, false)

/**
* True if the value for this optional column `exists` as expected,
* or is `None`. (see `Option.forall`).
*
* {{{
* df.col('opt).isSomeOrNone(_ === someOtherCol)
* }}}
*/
def isSomeOrNone[V](exists: ThisType[T, V] => ThisType[T, Boolean])(implicit i0: U <:< Option[V]): ThisType[T, Boolean] = someOr[V](exists, true)

private def someOr[V](exists: ThisType[T, V] => ThisType[T, Boolean], default: Boolean)(implicit i0: U <:< Option[V]): ThisType[T, Boolean] = {
val defaultExpr = if (default) Literal.TrueLiteral else Literal.FalseLiteral

typed(Coalesce(Seq(opt(i0).map(exists).expr, defaultExpr)))
}
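
The two new combinators mirror Option.exists and Option.forall: the user-supplied predicate runs on the unwrapped value, and the Coalesce in someOr supplies the default (false for isSome, true for isSomeOrNone) whenever the column is None. A rough per-row analogue in plain Scala, purely for intuition:

    // what the generated expression computes for a single row (illustrative names)
    def isSomeAnalogue[V](opt: Option[V])(p: V => Boolean): Boolean = opt.exists(p)       // None => false
    def isSomeOrNoneAnalogue[V](opt: Option[V])(p: V => Boolean): Boolean = opt.forall(p) // None => true

    isSomeAnalogue(Option(0L))(_ <= 0L)               // true
    isSomeAnalogue(Option.empty[Long])(_ <= 0L)       // false
    isSomeOrNoneAnalogue(Option.empty[Long])(_ <= 0L) // true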

/** Convert an Optional column by providing a default value.
*
* {{{
* df( df('opt).getOrElse(df('defaultValue)) )
* df(df('opt).getOrElse(df('defaultValue)))
* }}}
*/
def getOrElse[TT, W, Out](default: ThisType[TT, Out])(implicit i0: U =:= Option[Out], i1: With.Aux[T, TT, W]): ThisType[W, Out] =
typed(Coalesce(Seq(expr, default.expr)))(default.uencoder)
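
getOrElse behaves like Option.getOrElse at the column level and compiles to Coalesce(value, default); the default can be another column, as in the signature above, or a plain constant lifted with lit via the overload below. A small usage sketch with a hypothetical schema (an implicit SparkSession is assumed):

    case class Reading(value: Option[Double], fallback: Double)

    val readings = TypedDataset.create(Seq(Reading(Some(1.5), 0.0), Reading(None, 0.0)))

    // column default: Coalesce(value, fallback)
    readings.select(readings('value).getOrElse(readings('fallback))).collect().run() // 1.5, 0.0

    // constant default: Coalesce(value, lit(-1.0))
    readings.select(readings('value).getOrElse(-1.0)).collect().run()                // 1.5, -1.0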

/** Convert an Optional column by providing a default value
/** Convert an Optional column by providing a default value.
*
* {{{
* df( df('opt).getOrElse(defaultConstant) )
* }}}
@@ -197,6 +227,7 @@ abstract class AbstractTypedColumn[T, U]
getOrElse(lit[Out](default))

/** Sum of this expression and another expression.
*
* {{{
* // The following selects the sum of a person's height and weight.
* people.select( people.col('height) plus people.col('weight) )
@@ -700,9 +731,10 @@ abstract class AbstractTypedColumn[T, U]
or(other)

/** Less than.
*
* {{{
* // The following selects people younger than the maxAge column.
* df.select( df('age) < df('maxAge) )
* // The following selects people younger than the maxAge column.
* df.select(df('age) < df('maxAge) )
* }}}
*
* @param other another column of the same type
@@ -712,9 +744,10 @@ abstract class AbstractTypedColumn[T, U]
typed(self.untyped < other.untyped)

/** Less than or equal to.
*
* {{{
* // The following selects people younger or equal than the maxAge column.
* df.select( df('age) <= df('maxAge)
* // The following selects people younger or equal than the maxAge column.
* df.select(df('age) <= df('maxAge)
* }}}
*
* @param other another column of the same type
1 change: 1 addition & 0 deletions dataset/src/main/scala/frameless/functions/package.scala
@@ -36,6 +36,7 @@ package object functions extends Udf with UnaryFunctions {

if (ScalaReflection.isNativeType(encoder.jvmRepr) && encoder.catalystRepr == encoder.jvmRepr) {
val expr = Literal(value, encoder.catalystRepr)

new TypedColumn(expr)
} else {
val expr = new Literal(value, encoder.jvmRepr)
22 changes: 21 additions & 1 deletion dataset/src/test/scala/frameless/FilterTests.scala
@@ -1,9 +1,11 @@
package frameless

import org.scalatest.matchers.should.Matchers

import org.scalacheck.Prop
import org.scalacheck.Prop._

class FilterTests extends TypedDatasetSuite {
final class FilterTests extends TypedDatasetSuite with Matchers {
test("filter('a == lit(b))") {
def prop[A: TypedEncoder](elem: A, data: Vector[X1[A]])(implicit ex1: TypedEncoder[X1[A]]): Prop = {
val dataset = TypedDataset.create(data)
@@ -145,6 +147,24 @@ class FilterTests extends TypedDatasetSuite {
check(forAll(prop[Option[X1[X1[Vector[Option[Int]]]]]] _))
}

test("Option content filter") {
val data = (Option(1L), Option(2L)) :: (Option(0L), Option(1L)) :: (None, None) :: Nil

val ds = TypedDataset.create(data)

val l = functions.lit[Long, (Option[Long], Option[Long])](0L)
val exists = ds('_1).isSome[Long](_ <= l)
val forall = ds('_1).isSomeOrNone[Long](_ <= l)

ds.select(exists).collect().run() shouldEqual Seq(false, true, false)
ds.select(forall).collect().run() shouldEqual Seq(false, true, true)

ds.filter(exists).collect().run() shouldEqual Seq(Option(0L) -> Option(1L))

ds.filter(forall).collect().run() shouldEqual Seq(
Option(0L) -> Option(1L), (None -> None))
}

test("filter with isin values") {
def prop[A: TypedEncoder](data: Vector[X1[A]], values: Vector[A])(implicit a : CatalystIsin[A]): Prop = {
val ds = TypedDataset.create(data)
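
A note on why someOr wraps the predicate in Coalesce, which the new test exercises: under Spark's three-valued logic a comparison applied to a NULL value yields NULL, and Coalesce replaces that NULL with the chosen default. The same idea expressed with the untyped Spark API, assuming a DataFrame carrying the tuple columns _1 and _2 from the test data:

    import org.apache.spark.sql.functions.{coalesce, col, lit}

    // untyped analogue of ds('_1).isSome[Long](_ <= l): NULL rows fall back to false
    val existsCol = coalesce(col("_1") <= lit(0L), lit(false))

    // untyped analogue of ds('_1).isSomeOrNone[Long](_ <= l): NULL rows fall back to true
    val forallCol = coalesce(col("_1") <= lit(0L), lit(true))

    // df.filter(existsCol) keeps the (0, 1) row; df.filter(forallCol) also keeps the (null, null) row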