diff --git a/dataset/src/main/scala/frameless/FramelessSyntax.scala b/dataset/src/main/scala/frameless/FramelessSyntax.scala index fa102200e..5ba294921 100644 --- a/dataset/src/main/scala/frameless/FramelessSyntax.scala +++ b/dataset/src/main/scala/frameless/FramelessSyntax.scala @@ -4,7 +4,8 @@ import org.apache.spark.sql.{Column, DataFrame, Dataset} trait FramelessSyntax { implicit class ColumnSyntax(self: Column) { - def typed[T, U: TypedEncoder]: TypedColumn[T, U] = new TypedColumn[T, U](self) + def typedColumn[T, U: TypedEncoder]: TypedColumn[T, U] = new TypedColumn[T, U](self) + def typedAggregate[T, U: TypedEncoder]: TypedAggregate[T, U] = new TypedAggregate[T, U](self) } implicit class DatasetSyntax[T: TypedEncoder](self: Dataset[T]) { diff --git a/dataset/src/main/scala/frameless/TypedColumn.scala b/dataset/src/main/scala/frameless/TypedColumn.scala index b955305ca..c494a0c26 100644 --- a/dataset/src/main/scala/frameless/TypedColumn.scala +++ b/dataset/src/main/scala/frameless/TypedColumn.scala @@ -1,16 +1,15 @@ package frameless +import frameless.functions.{lit => flit, litAggr} import frameless.syntax._ -import frameless.functions._ - import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.{Column, FramelessInternals} import org.apache.spark.sql.types.DecimalType -import shapeless.ops.record.Selector +import org.apache.spark.sql.{Column, FramelessInternals} import shapeless._ +import shapeless.ops.record.Selector -import scala.reflect.ClassTag import scala.annotation.implicitNotFound +import scala.reflect.ClassTag sealed trait UntypedExpression[T] { def expr: Expression @@ -19,6 +18,39 @@ sealed trait UntypedExpression[T] { } /** Expression used in `select`-like constructions. + */ +sealed class TypedColumn[T, U](expr: Expression)( + implicit val uenc: TypedEncoder[U] +) extends AbstractTypedColumn[T, U](expr) { + + type ThisType[A, B] = TypedColumn[A, B] + + def this(column: Column)(implicit uencoder: TypedEncoder[U]) { + this(FramelessInternals.expr(column)) + } + + override def typed[U1: TypedEncoder](c: Column): TypedColumn[T, U1] = c.typedColumn + override def lit[U1: TypedEncoder](c: U1): TypedColumn[T,U1] = flit(c) +} + +/** Expression used in `agg`-like constructions. + */ +sealed class TypedAggregate[T, U](expr: Expression)( + implicit val uenc: TypedEncoder[U] +) extends AbstractTypedColumn[T, U](expr) { + + type ThisType[A, B] = TypedAggregate[A, B] + + def this(column: Column)(implicit uencoder: TypedEncoder[U]) { + this(FramelessInternals.expr(column)) + } + + override def typed[U1: TypedEncoder](c: Column): TypedAggregate[T,U1] = c.typedAggregate + override def lit[U1: TypedEncoder](c: U1): TypedAggregate[T,U1] = litAggr(c) +} + +/** Generic representation of a typed column. A typed column can either be a [[TypedAggregate]] or + * a [[frameless.TypedColumn]]. 
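The renamed syntax above is the entry point for lifting untyped Spark columns back into the typed API: `typedColumn` for `select`-like positions and `typedAggregate` for `agg`-like positions. A minimal sketch of the intended usage (the `Person` row type and the column expressions are illustrative, not part of this change set):

```scala
import frameless.{TypedAggregate, TypedColumn}
import frameless.syntax._                                  // brings ColumnSyntax into scope
import org.apache.spark.sql.{functions => untyped}

case class Person(name: String, age: Long)                 // hypothetical row type

// An untyped Column meant for a select-like position becomes a TypedColumn...
val upperName: TypedColumn[Person, String] =
  untyped.upper(untyped.col("name")).typedColumn[Person, String]

// ...while one meant for an agg-like position becomes a TypedAggregate.
val rowCount: TypedAggregate[Person, Long] =
  untyped.count(untyped.lit(1)).typedAggregate[Person, Long]
```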
* * Documentation marked "apache/spark" is thanks to apache/spark Contributors * at https://github.com/apache/spark, licensed under Apache v2.0 available at @@ -27,31 +59,33 @@ sealed trait UntypedExpression[T] { * @tparam T type of dataset * @tparam U type of column */ -sealed class TypedColumn[T, U]( - val expr: Expression)( - implicit - val uencoder: TypedEncoder[U] -) extends UntypedExpression[T] { self => +abstract class AbstractTypedColumn[T, U] + (val expr: Expression) + (implicit val uencoder: TypedEncoder[U]) + extends UntypedExpression[T] { self => - /** From an untyped Column to a [[TypedColumn]] - * - * @param column a spark.sql Column - * @param uencoder encoder of the resulting type U - */ - def this(column: Column)(implicit uencoder: TypedEncoder[U]) { - this(FramelessInternals.expr(column)) - } + type ThisType[A, B] <: AbstractTypedColumn[A, B] /** Fall back to an untyped Column */ def untyped: Column = new Column(expr) - private def withExpr(newExpr: Expression): Column = new Column(newExpr) - - private def equalsTo(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = withExpr { + private def equalsTo(other: ThisType[T, U]): ThisType[T, Boolean] = typed { if (uencoder.nullable && uencoder.catalystRepr.typeName != "struct") EqualNullSafe(self.expr, other.expr) else EqualTo(self.expr, other.expr) - }.typed + } + + /** Creates a typed column of either TypedColumn or TypedAggregate from an expression. + */ + protected def typed[U1: TypedEncoder](e: Expression): ThisType[T, U1] = typed(new Column(e)) + + /** Creates a typed column of either TypedColumn or TypedAggregate. + */ + def typed[U1: TypedEncoder](c: Column): ThisType[T, U1] + + /** Creates a typed column of either TypedColumn or TypedAggregate. + */ + def lit[U1: TypedEncoder](c: U1): ThisType[T, U1] /** Equality test. * {{{ @@ -60,7 +94,7 @@ sealed class TypedColumn[T, U]( * * apache/spark */ - def ===(other: U): TypedColumn[T, Boolean] = equalsTo(lit(other)) + def ===(other: U): ThisType[T, Boolean] = equalsTo(lit(other)) /** Equality test. * {{{ @@ -69,7 +103,7 @@ sealed class TypedColumn[T, U]( * * apache/spark */ - def ===(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = equalsTo(other) + def ===(other: ThisType[T, U]): ThisType[T, Boolean] = equalsTo(other) /** Inequality test. * {{{ @@ -78,9 +112,7 @@ sealed class TypedColumn[T, U]( * * apache/spark */ - def =!=(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = withExpr { - Not(equalsTo(other).expr) - }.typed + def =!=(other: ThisType[T, U]): ThisType[T, Boolean] = typed(Not(equalsTo(other).expr)) /** Inequality test. * {{{ @@ -89,41 +121,37 @@ sealed class TypedColumn[T, U]( * * apache/spark */ - def =!=(other: U): TypedColumn[T, Boolean] = withExpr { - Not(equalsTo(lit(other)).expr) - }.typed + def =!=(other: U): ThisType[T, Boolean] = typed(Not(equalsTo(lit(other)).expr)) /** True if the current expression is an Option and it's None. * * apache/spark */ - def isNone(implicit isOption: U <:< Option[_]): TypedColumn[T, Boolean] = - equalsTo(lit[U,T](None.asInstanceOf[U])) + def isNone(implicit isOption: U <:< Option[_]): ThisType[T, Boolean] = + equalsTo(lit[U](None.asInstanceOf[U])) /** True if the current expression is an Option and it's not None. 
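Since every operator on `AbstractTypedColumn` now returns `ThisType[T, _]`, combining `TypedColumn`s still produces a `TypedColumn`, so existing `filter`/`select` call sites keep compiling as before. A small sketch, assuming an implicit `SparkSession` is in scope as in `TypedDatasetSuite` (the `X2` case class mirrors the test fixture):

```scala
import frameless.TypedDataset

case class X2[A, B](a: A, b: B)                            // mirrors the test fixture

// assuming an implicit SparkSession, as in TypedDatasetSuite
val ds = TypedDataset.create(Seq(X2(1, 2), X2(3, 3)))

val a = ds.col('a)                                         // TypedColumn[X2[Int, Int], Int]
val b = ds.col('b)

// =!= is defined once on AbstractTypedColumn and returns ThisType[T, Boolean],
// which here is a TypedColumn[X2[Int, Int], Boolean], exactly what filter expects
// (the behaviour exercised by the new filter('a =!= 'b) test further down).
val differing = ds.filter(a =!= b)
```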
* * apache/spark */ - def isNotNone(implicit isOption: U <:< Option[_]): TypedColumn[T, Boolean] = withExpr { - Not(equalsTo(lit(None.asInstanceOf[U])).expr) - }.typed + def isNotNone(implicit isOption: U <:< Option[_]): ThisType[T, Boolean] = + typed(Not(equalsTo(lit(None.asInstanceOf[U])).expr)) /** Convert an Optional column by providing a default value * {{{ * df( df('opt).getOrElse(df('defaultValue)) ) * }}} */ - def getOrElse[Out](default: TypedColumn[T, Out])(implicit isOption: U =:= Option[Out]): TypedColumn[T, Out] = withExpr { - Coalesce(Seq(expr, default.expr)) - }.typed(default.uencoder) + def getOrElse[Out](default: ThisType[T, Out])(implicit isOption: U =:= Option[Out]): ThisType[T, Out] = + typed(Coalesce(Seq(expr, default.expr)))(default.uencoder) /** Convert an Optional column by providing a default value * {{{ * df( df('opt).getOrElse(defaultConstant) ) * }}} */ - def getOrElse[Out: TypedEncoder](default: Out)(implicit isOption: U =:= Option[Out]): TypedColumn[T, Out] = - getOrElse(lit[Out, T](default)) + def getOrElse[Out: TypedEncoder](default: Out)(implicit isOption: U =:= Option[Out]): ThisType[T, Out] = + getOrElse(lit[Out](default)) /** Sum of this expression and another expression. * {{{ @@ -133,8 +161,8 @@ sealed class TypedColumn[T, U]( * * apache/spark */ - def plus(other: TypedColumn[T, U])(implicit n: CatalystNumeric[U]): TypedColumn[T, U] = - self.untyped.plus(other.untyped).typed + def plus(other: ThisType[T, U])(implicit n: CatalystNumeric[U]): ThisType[T, U] = + typed(self.untyped.plus(other.untyped)) /** Sum of this expression and another expression. * {{{ @@ -144,7 +172,7 @@ sealed class TypedColumn[T, U]( * * apache/spark */ - def +(u: TypedColumn[T, U])(implicit n: CatalystNumeric[U]): TypedColumn[T, U] = plus(u) + def +(u: ThisType[T, U])(implicit n: CatalystNumeric[U]): ThisType[T, U] = plus(u) /** Sum of this expression (column) with a constant. * {{{ @@ -155,7 +183,7 @@ sealed class TypedColumn[T, U]( * @param u a constant of the same type * apache/spark */ - def +(u: U)(implicit n: CatalystNumeric[U]): TypedColumn[T, U] = self.untyped.plus(u).typed + def +(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, U] = typed(self.untyped.plus(u)) /** Unary minus, i.e. negate the expression. * {{{ @@ -165,7 +193,7 @@ sealed class TypedColumn[T, U]( * * apache/spark */ - def unary_-(implicit n: CatalystNumeric[U]): TypedColumn[T, U] = (-self.untyped).typed + def unary_-(implicit n: CatalystNumeric[U]): ThisType[T, U] = typed(-self.untyped) /** Subtraction. Subtract the other expression from this expression. * {{{ @@ -175,8 +203,7 @@ sealed class TypedColumn[T, U]( * * apache/spark */ - def minus(u: TypedColumn[T, U])(implicit n: CatalystNumeric[U]): TypedColumn[T, U] = - self.untyped.minus(u.untyped).typed + def minus(u: ThisType[T, U])(implicit n: CatalystNumeric[U]): ThisType[T, U] = typed(self.untyped.minus(u.untyped)) /** Subtraction. Subtract the other expression from this expression. * {{{ @@ -186,7 +213,7 @@ sealed class TypedColumn[T, U]( * * apache/spark */ - def -(u: TypedColumn[T, U])(implicit n: CatalystNumeric[U]): TypedColumn[T, U] = minus(u) + def -(u: ThisType[T, U])(implicit n: CatalystNumeric[U]): ThisType[T, U] = minus(u) /** Subtraction. Subtract the other expression from this expression. 
* {{{ @@ -197,7 +224,7 @@ sealed class TypedColumn[T, U]( * @param u a constant of the same type * apache/spark */ - def -(u: U)(implicit n: CatalystNumeric[U]): TypedColumn[T, U] = self.untyped.minus(u).typed + def -(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, U] = typed(self.untyped.minus(u)) /** Multiplication of this expression and another expression. * {{{ @@ -207,14 +234,14 @@ sealed class TypedColumn[T, U]( * * apache/spark */ - def multiply(u: TypedColumn[T, U])(implicit n: CatalystNumeric[U], ct: ClassTag[U]): TypedColumn[T, U] = { + def multiply(u: ThisType[T, U])(implicit n: CatalystNumeric[U], ct: ClassTag[U]): ThisType[T, U] = typed { if (ct.runtimeClass == BigDecimal(0).getClass) { // That's apparently the only way to get sound multiplication. // See https://issues.apache.org/jira/browse/SPARK-22036 val dt = DecimalType(20, 14) - self.untyped.cast(dt).multiply(u.untyped.cast(dt)).typed + self.untyped.cast(dt).multiply(u.untyped.cast(dt)) } else { - self.untyped.multiply(u.untyped).typed + self.untyped.multiply(u.untyped) } } @@ -226,7 +253,7 @@ sealed class TypedColumn[T, U]( * * apache/spark */ - def *(u: TypedColumn[T, U])(implicit n: CatalystNumeric[U], tt: ClassTag[U]): TypedColumn[T, U] = multiply(u) + def *(u: ThisType[T, U])(implicit n: CatalystNumeric[U], tt: ClassTag[U]): ThisType[T, U] = multiply(u) /** Multiplication of this expression a constant. * {{{ @@ -236,7 +263,7 @@ sealed class TypedColumn[T, U]( * * apache/spark */ - def *(u: U)(implicit n: CatalystNumeric[U]): TypedColumn[T, U] = self.untyped.multiply(u).typed + def *(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, U] = typed(self.untyped.multiply(u)) /** * Division this expression by another expression. @@ -248,8 +275,8 @@ sealed class TypedColumn[T, U]( * @param other another column of the same type * apache/spark */ - def divide[Out: TypedEncoder](other: TypedColumn[T, U])(implicit n: CatalystDivisible[U, Out]): TypedColumn[T, Out] = - self.untyped.divide(other.untyped).typed + def divide[Out: TypedEncoder](other: ThisType[T, U])(implicit n: CatalystDivisible[U, Out]): ThisType[T, Out] = + typed(self.untyped.divide(other.untyped)) /** * Division this expression by another expression. @@ -261,10 +288,10 @@ sealed class TypedColumn[T, U]( * @param other another column of the same type * apache/spark */ - def /[Out](other: TypedColumn[T, U]) + def /[Out](other: ThisType[T, U]) (implicit n: CatalystDivisible[U, Out], - e: TypedEncoder[Out]): TypedColumn[T, Out] = divide(other) + e: TypedEncoder[Out]): ThisType[T, Out] = divide(other) /** * Division this expression by another expression. @@ -276,7 +303,7 @@ sealed class TypedColumn[T, U]( * @param u a constant of the same type * apache/spark */ - def /(u: U)(implicit n: CatalystNumeric[U]): TypedColumn[T, Double] = self.untyped.divide(u).typed + def /(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, Double] = typed(self.untyped.divide(u)) /** * Bitwise AND this expression and another expression. @@ -287,7 +314,8 @@ sealed class TypedColumn[T, U]( * @param u a constant of the same type * apache/spark */ - def bitwiseAND(u: U)(implicit n: CatalystBitwise[U]): TypedColumn[T, U] = self.untyped.bitwiseAND(u).typed + def bitwiseAND(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] = + typed(self.untyped.bitwiseAND(u)) /** * Bitwise AND this expression and another expression. 
@@ -298,8 +326,8 @@ sealed class TypedColumn[T, U]( * @param u a constant of the same type * apache/spark */ - def bitwiseAND(u: TypedColumn[T, U])(implicit n: CatalystBitwise[U]): TypedColumn[T, U] = - self.untyped.bitwiseAND(u.untyped).typed + def bitwiseAND(u: ThisType[T, U])(implicit n: CatalystBitwise[U]): ThisType[T, U] = + typed(self.untyped.bitwiseAND(u.untyped)) /** * Bitwise AND this expression and another expression (of same type). @@ -310,7 +338,7 @@ sealed class TypedColumn[T, U]( * @param u a constant of the same type * apache/spark */ - def &(u: U)(implicit n: CatalystBitwise[U]): TypedColumn[T, U] = bitwiseAND(u) + def &(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] = bitwiseAND(u) /** * Bitwise AND this expression and another expression. @@ -321,7 +349,7 @@ sealed class TypedColumn[T, U]( * @param u a constant of the same type * apache/spark */ - def &(u: TypedColumn[T, U])(implicit n: CatalystBitwise[U]): TypedColumn[T, U] = bitwiseAND(u) + def &(u: ThisType[T, U])(implicit n: CatalystBitwise[U]): ThisType[T, U] = bitwiseAND(u) /** * Bitwise OR this expression and another expression. @@ -332,7 +360,7 @@ sealed class TypedColumn[T, U]( * @param u a constant of the same type * apache/spark */ - def bitwiseOR(u: U)(implicit n: CatalystBitwise[U]): TypedColumn[T, U] = self.untyped.bitwiseOR(u).typed + def bitwiseOR(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] = typed(self.untyped.bitwiseOR(u)) /** * Bitwise OR this expression and another expression. @@ -343,8 +371,8 @@ sealed class TypedColumn[T, U]( * @param u a constant of the same type * apache/spark */ - def bitwiseOR(u: TypedColumn[T, U])(implicit n: CatalystBitwise[U]): TypedColumn[T, U] = - self.untyped.bitwiseOR(u.untyped).typed + def bitwiseOR(u: ThisType[T, U])(implicit n: CatalystBitwise[U]): ThisType[T, U] = + typed(self.untyped.bitwiseOR(u.untyped)) /** * Bitwise OR this expression and another expression (of same type). @@ -355,7 +383,7 @@ sealed class TypedColumn[T, U]( * @param u a constant of the same type * apache/spark */ - def |(u: U)(implicit n: CatalystBitwise[U]): TypedColumn[T, U] = bitwiseOR(u) + def |(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] = bitwiseOR(u) /** * Bitwise OR this expression and another expression. @@ -366,7 +394,7 @@ sealed class TypedColumn[T, U]( * @param u a constant of the same type * apache/spark */ - def |(u: TypedColumn[T, U])(implicit n: CatalystBitwise[U]): TypedColumn[T, U] = bitwiseOR(u) + def |(u: ThisType[T, U])(implicit n: CatalystBitwise[U]): ThisType[T, U] = bitwiseOR(u) /** * Bitwise XOR this expression and another expression. @@ -377,7 +405,8 @@ sealed class TypedColumn[T, U]( * @param u a constant of the same type * apache/spark */ - def bitwiseXOR(u: U)(implicit n: CatalystBitwise[U]): TypedColumn[T, U] = self.untyped.bitwiseXOR(u).typed + def bitwiseXOR(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] = + typed(self.untyped.bitwiseXOR(u)) /** * Bitwise XOR this expression and another expression. @@ -388,8 +417,8 @@ sealed class TypedColumn[T, U]( * @param u a constant of the same type * apache/spark */ - def bitwiseXOR(u: TypedColumn[T, U])(implicit n: CatalystBitwise[U]): TypedColumn[T, U] = - self.untyped.bitwiseXOR(u.untyped).typed + def bitwiseXOR(u: ThisType[T, U])(implicit n: CatalystBitwise[U]): ThisType[T, U] = + typed(self.untyped.bitwiseXOR(u.untyped)) /** * Bitwise XOR this expression and another expression (of same type). 
@@ -400,7 +429,7 @@ sealed class TypedColumn[T, U]( * @param u a constant of the same type * apache/spark */ - def ^(u: U)(implicit n: CatalystBitwise[U]): TypedColumn[T, U] = bitwiseXOR(u) + def ^(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] = bitwiseXOR(u) /** * Bitwise XOR this expression and another expression. @@ -411,80 +440,168 @@ sealed class TypedColumn[T, U]( * @param u a constant of the same type * apache/spark */ - def ^(u: TypedColumn[T, U])(implicit n: CatalystBitwise[U]): TypedColumn[T, U] = bitwiseXOR(u) + def ^(u: ThisType[T, U])(implicit n: CatalystBitwise[U]): ThisType[T, U] = bitwiseXOR(u) /** Casts the column to a different type. * {{{ * df.select(df('a).cast[Int]) * }}} */ - def cast[A: TypedEncoder](implicit c: CatalystCast[U, A]): TypedColumn[T, A] = - self.untyped.cast(TypedEncoder[A].catalystRepr).typed + def cast[A: TypedEncoder](implicit c: CatalystCast[U, A]): ThisType[T, A] = + typed(self.untyped.cast(TypedEncoder[A].catalystRepr)) /** Contains test. * {{{ * df.filter ( df.col('a).contains("foo") ) * }}} */ - def contains(other: String)(implicit ev: U =:= String): TypedColumn[T, Boolean] = - self.untyped.contains(other).typed + def contains(other: String)(implicit ev: U =:= String): ThisType[T, Boolean] = + typed(self.untyped.contains(other)) /** Contains test. * {{{ * df.filter ( df.col('a).contains(df.col('b) ) * }}} */ - def contains(other: TypedColumn[T, U])(implicit ev: U =:= String): TypedColumn[T, Boolean] = - self.untyped.contains(other.untyped).typed + def contains(other: ThisType[T, U])(implicit ev: U =:= String): ThisType[T, Boolean] = + typed(self.untyped.contains(other.untyped)) /** Boolean AND. * {{{ * df.filter ( (df.col('a) === 1).and(df.col('b) > 5) ) * }}} */ - def and(other: TypedColumn[T, Boolean]): TypedColumn[T, Boolean] = - self.untyped.and(other.untyped).typed + def and(other: ThisType[T, Boolean]): ThisType[T, Boolean] = + typed(self.untyped.and(other.untyped)) /** Boolean AND. * {{{ * df.filter ( df.col('a) === 1 && df.col('b) > 5) * }}} */ - def && (other: TypedColumn[T, Boolean]): TypedColumn[T, Boolean] = - and(other) + def && (other: ThisType[T, Boolean]): ThisType[T, Boolean] = and(other) /** Boolean OR. * {{{ * df.filter ( (df.col('a) === 1).or(df.col('b) > 5) ) * }}} */ - def or(other: TypedColumn[T, Boolean]): TypedColumn[T, Boolean] = - self.untyped.or(other.untyped).typed + def or(other: ThisType[T, Boolean]): ThisType[T, Boolean] = + typed(self.untyped.or(other.untyped)) /** Boolean OR. * {{{ * df.filter ( df.col('a) === 1 || df.col('b) > 5) * }}} */ - def || (other: TypedColumn[T, Boolean]): TypedColumn[T, Boolean] = - or(other) -} + def || (other: ThisType[T, Boolean]): ThisType[T, Boolean] = or(other) -/** Expression used in `groupBy`-like constructions. - * - * @tparam T type of dataset - * @tparam U type of column for `groupBy` - */ -sealed class TypedAggregate[T, U](val expr: Expression)( - implicit - val uencoder: TypedEncoder[U] -) extends UntypedExpression[T] { + /** + * Less than. + * {{{ + * // The following selects people younger than the maxAge column. + * df.select( df('age) < df('maxAge) ) + * }}} + * + * @param u another column of the same type + * apache/spark + */ + def <(u: ThisType[T, U])(implicit canOrder: CatalystOrdered[U]): ThisType[T, Boolean] = + typed(self.untyped < u.untyped) - def this(column: Column)(implicit e: TypedEncoder[U]) { - this(FramelessInternals.expr(column)) - } + /** + * Less than or equal to. 
+ * {{{ + * // The following selects people younger or equal than the maxAge column. + * df.select( df('age) <= df('maxAge) + * }}} + * + * @param u another column of the same type + * apache/spark + */ + def <=(u: ThisType[T, U])(implicit canOrder: CatalystOrdered[U]): ThisType[T, Boolean] = + typed(self.untyped <= u.untyped) + + /** + * Greater than. + * {{{ + * // The following selects people older than the maxAge column. + * df.select( df('age) > df('maxAge) ) + * }}} + * + * @param u another column of the same type + * apache/spark + */ + def >(u: ThisType[T, U])(implicit canOrder: CatalystOrdered[U]): ThisType[T, Boolean] = + typed(self.untyped > u.untyped) + + /** + * Greater than or equal. + * {{{ + * // The following selects people older or equal than the maxAge column. + * df.select( df('age) >= df('maxAge) ) + * }}} + * + * @param u another column of the same type + * apache/spark + */ + def >=(u: ThisType[T, U])(implicit canOrder: CatalystOrdered[U]): ThisType[T, Boolean] = + typed(self.untyped >= u.untyped) + + /** + * Less than. + * {{{ + * // The following selects people younger than 21. + * df.select( df('age) < 21 ) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def <(u: U)(implicit canOrder: CatalystOrdered[U]): ThisType[T, Boolean] = + typed(self.untyped < lit(u)(self.uencoder).untyped) + + /** + * Less than or equal to. + * {{{ + * // The following selects people younger than 22. + * df.select( df('age) <= 2 ) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def <=(u: U)(implicit canOrder: CatalystOrdered[U]): ThisType[T, Boolean] = + typed(self.untyped <= lit(u)(self.uencoder).untyped) + + /** + * Greater than. + * {{{ + * // The following selects people older than 21. + * df.select( df('age) > 21 ) + * }}} + * + * @param u another column of the same type + * apache/spark + */ + def >(u: U)(implicit canOrder: CatalystOrdered[U]): ThisType[T, Boolean] = + typed(self.untyped > lit(u)(self.uencoder).untyped) + + /** + * Greater than or equal. + * {{{ + * // The following selects people older than 20. + * df.select( df('age) >= 21 ) + * }}} + * + * @param u another column of the same type + * apache/spark + */ + def >=(u: U)(implicit canOrder: CatalystOrdered[U]): ThisType[T, Boolean] = + typed(self.untyped >= lit(u)(self.uencoder).untyped) } + object TypedColumn { /** * Evidence that type `T` has column `K` with type `V`. 
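With the `CatalystOrdered` comparisons defined directly on `AbstractTypedColumn` (they previously lived in the `OrderedTypedColumnSyntax` implicit class, removed below), both column-to-constant and column-to-column comparisons are ordinary methods. A sketch with a hypothetical `Person` dataset, assuming an implicit `SparkSession`:

```scala
import frameless.TypedDataset

case class Person(name: String, age: Long)                 // hypothetical row type

// assuming an implicit SparkSession, as in TypedDatasetSuite
val people = TypedDataset.create(Seq(Person("Ana", 17L), Person("Bo", 42L)))

val adults = people.filter(people('age) >= 18L)            // column vs constant
val minors = people.select(people('age) < 18L)             // TypedDataset[Boolean]
```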
@@ -516,16 +633,4 @@ object TypedColumn { i1: Selector.Aux[H, K, V] ): Exists[T, K, V] = new Exists[T, K, V] {} } - - implicit class OrderedTypedColumnSyntax[T, U: CatalystOrdered](col: TypedColumn[T, U]) { - def <(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = (col.untyped < other.untyped).typed - def <=(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = (col.untyped <= other.untyped).typed - def >(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = (col.untyped > other.untyped).typed - def >=(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = (col.untyped >= other.untyped).typed - - def <(other: U): TypedColumn[T, Boolean] = (col.untyped < lit(other)(col.uencoder).untyped).typed - def <=(other: U): TypedColumn[T, Boolean] = (col.untyped <= lit(other)(col.uencoder).untyped).typed - def >(other: U): TypedColumn[T, Boolean] = (col.untyped > lit(other)(col.uencoder).untyped).typed - def >=(other: U): TypedColumn[T, Boolean] = (col.untyped >= lit(other)(col.uencoder).untyped).typed - } } diff --git a/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala b/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala index f25221248..79e19ff34 100644 --- a/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala +++ b/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala @@ -4,45 +4,34 @@ package functions import org.apache.spark.sql.FramelessInternals.expr import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.{functions => untyped} +import frameless.syntax._ trait AggregateFunctions { - - /** Creates a [[frameless.TypedColumn]] of literal value. If A is to be encoded using an Injection make - * sure the injection instance is in scope. - * - * apache/spark - */ - def lit[A: TypedEncoder, T](value: A): TypedColumn[T, A] = frameless.functions.lit(value) - /** Aggregate function: returns the number of items in a group. * * apache/spark */ - def count[T](): TypedAggregate[T, Long] = { - new TypedAggregate(untyped.count(untyped.lit(1))) - } + def count[T](): TypedAggregate[T, Long] = + untyped.count(untyped.lit(1)).typedAggregate /** Aggregate function: returns the number of items in a group for which the selected column is not null. * * apache/spark */ - def count[T](column: TypedColumn[T, _]): TypedAggregate[T, Long] = { - new TypedAggregate[T, Long](untyped.count(column.untyped)) - } + def count[T](column: TypedColumn[T, _]): TypedAggregate[T, Long] = + untyped.count(column.untyped).typedAggregate /** Aggregate function: returns the number of distinct items in a group. * * apache/spark */ - def countDistinct[T](column: TypedColumn[T, _]): TypedAggregate[T, Long] = { - new TypedAggregate[T, Long](untyped.countDistinct(column.untyped)) - } + def countDistinct[T](column: TypedColumn[T, _]): TypedAggregate[T, Long] = + untyped.countDistinct(column.untyped).typedAggregate /** Aggregate function: returns the approximate number of distinct items in a group. */ - def approxCountDistinct[T](column: TypedColumn[T, _]): TypedAggregate[T, Long] = { - new TypedAggregate[T, Long](untyped.approx_count_distinct(column.untyped)) - } + def approxCountDistinct[T](column: TypedColumn[T, _]): TypedAggregate[T, Long] = + untyped.approx_count_distinct(column.untyped).typedAggregate /** Aggregate function: returns the approximate number of distinct items in a group. 
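The literal helper moves out of `AggregateFunctions` (see the package object later in this diff), and the aggregates themselves are now built with the `typedAggregate` syntax. Their use is unchanged; a brief sketch with a hypothetical `Visit` dataset, assuming an implicit `SparkSession`:

```scala
import frameless.TypedDataset
import frameless.functions.aggregate._

case class Visit(user: String, page: String)               // hypothetical row type

// assuming an implicit SparkSession, as in TypedDatasetSuite
val visits = TypedDataset.create(Seq(Visit("u1", "/"), Visit("u1", "/about")))

// Both expressions are TypedAggregate[Visit, Long], so they compose inside agg:
val stats = visits.agg(count(), countDistinct(visits('user)))
// stats: TypedDataset[(Long, Long)]
```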
* @@ -50,25 +39,22 @@ trait AggregateFunctions { * * apache/spark */ - def approxCountDistinct[T](column: TypedColumn[T, _], rsd: Double): TypedAggregate[T, Long] = { - new TypedAggregate[T, Long](untyped.approx_count_distinct(column.untyped, rsd)) - } + def approxCountDistinct[T](column: TypedColumn[T, _], rsd: Double): TypedAggregate[T, Long] = + untyped.approx_count_distinct(column.untyped, rsd).typedAggregate /** Aggregate function: returns a list of objects with duplicates. * * apache/spark */ - def collectList[T, A: TypedEncoder](column: TypedColumn[T, A]): TypedAggregate[T, Vector[A]] = { - new TypedAggregate[T, Vector[A]](untyped.collect_list(column.untyped)) - } + def collectList[T, A: TypedEncoder](column: TypedColumn[T, A]): TypedAggregate[T, Vector[A]] = + untyped.collect_list(column.untyped).typedAggregate /** Aggregate function: returns a set of objects with duplicate elements eliminated. * * apache/spark */ - def collectSet[T, A: TypedEncoder](column: TypedColumn[T, A]): TypedAggregate[T, Vector[A]] = { - new TypedAggregate[T, Vector[A]](untyped.collect_set(column.untyped)) - } + def collectSet[T, A: TypedEncoder](column: TypedColumn[T, A]): TypedAggregate[T, Vector[A]] = + untyped.collect_set(column.untyped).typedAggregate /** Aggregate function: returns the sum of all values in the given column. * @@ -114,7 +100,6 @@ trait AggregateFunctions { new TypedAggregate[T, Out](untyped.avg(column.untyped)) } - /** Aggregate function: returns the unbiased variance of the values in a group. * * @note In Spark variance always returns Double @@ -122,9 +107,8 @@ trait AggregateFunctions { * * apache/spark */ - def variance[A: CatalystVariance, T](column: TypedColumn[T, A]): TypedAggregate[T, Double] = { - new TypedAggregate[T, Double](untyped.variance(column.untyped)) - } + def variance[A: CatalystVariance, T](column: TypedColumn[T, A]): TypedAggregate[T, Double] = + untyped.variance(column.untyped).typedAggregate /** Aggregate function: returns the sample standard deviation. * @@ -133,9 +117,8 @@ trait AggregateFunctions { * * apache/spark */ - def stddev[A: CatalystVariance, T](column: TypedColumn[T, A]): TypedAggregate[T, Double] = { - new TypedAggregate[T, Double](untyped.stddev(column.untyped)) - } + def stddev[A: CatalystVariance, T](column: TypedColumn[T, A]): TypedAggregate[T, Double] = + untyped.stddev(column.untyped).typedAggregate /** * Aggregate function: returns the standard deviation of a column by population. @@ -175,7 +158,7 @@ trait AggregateFunctions { */ def max[A: CatalystOrdered, T](column: TypedColumn[T, A]): TypedAggregate[T, A] = { implicit val c = column.uencoder - new TypedAggregate[T, A](untyped.max(column.untyped)) + untyped.max(column.untyped).typedAggregate } /** Aggregate function: returns the minimum value of the column in a group. @@ -184,7 +167,7 @@ trait AggregateFunctions { */ def min[A: CatalystOrdered, T](column: TypedColumn[T, A]): TypedAggregate[T, A] = { implicit val c = column.uencoder - new TypedAggregate[T, A](untyped.min(column.untyped)) + untyped.min(column.untyped).typedAggregate } /** Aggregate function: returns the first value in a group. 
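Because `max` returns a `TypedAggregate` whose arithmetic operators also return `TypedAggregate`, a derived aggregate can still be handed to `agg`; this is the case added by the new "max with follow up multiplication" test further down. A sketch, assuming an implicit `SparkSession` (the `X1` case class mirrors the test fixture):

```scala
import frameless.TypedDataset
import frameless.functions.aggregate._

case class X1[A](a: A)                                     // mirrors the test fixture

// assuming an implicit SparkSession, as in TypedDatasetSuite
val ds = TypedDataset.create(Seq(X1(1L), X1(5L)))
val a  = ds.col[Long]('a)

// max(a) is a TypedAggregate[X1[Long], Long]; multiplying it keeps that type,
// so the whole expression is still a valid argument to agg.
val doubledMax = ds.agg(max(a) * 2L).collect().run()       // Seq(10L)
```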
@@ -196,7 +179,7 @@ trait AggregateFunctions { */ def first[A, T](column: TypedColumn[T, A]): TypedAggregate[T, A] = { implicit val c = column.uencoder - new TypedAggregate[T, A](untyped.first(column.untyped)) + untyped.first(column.untyped).typedAggregate(column.uencoder) } /** @@ -209,7 +192,7 @@ trait AggregateFunctions { */ def last[A, T](column: TypedColumn[T, A]): TypedAggregate[T, A] = { implicit val c = column.uencoder - new TypedAggregate[T, A](untyped.last(column.untyped)) + untyped.last(column.untyped).typedAggregate } /** diff --git a/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala b/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala index 211a7f8bb..61c73f5bd 100644 --- a/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala +++ b/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala @@ -11,10 +11,12 @@ trait NonAggregateFunctions { * * apache/spark */ - def abs[A, B, T](column: TypedColumn[T, A])(implicit evAbs: CatalystAbsolute[A, B], enc:TypedEncoder[B]):TypedColumn[T, B] = { - implicit val c = column.uencoder - new TypedColumn[T, B](untyped.abs(column.untyped)) - } + def abs[A, B, T](column: AbstractTypedColumn[T, A]) + (implicit + i0: CatalystAbsolute[A, B], + i1: TypedEncoder[B] + ): column.ThisType[T, B] = + column.typed(untyped.abs(column.untyped))(i1) /** Non-Aggregate function: returns the acos of a numeric column * @@ -22,20 +24,16 @@ trait NonAggregateFunctions { * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] * apache/spark */ - def acos[A, T](column: TypedColumn[T, A]) - (implicit evCanBeDouble: CatalystCast[A, Double]): TypedColumn[T, Double] = { - implicit val c = column.uencoder - new TypedColumn[T, Double](untyped.acos(column.cast[Double].untyped)) - } + def acos[A, T](column: AbstractTypedColumn[T, A]) + (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = + column.typed(untyped.acos(column.cast[Double].untyped)) /** Non-Aggregate function: returns true if value is contained with in the array in the specified column * * apache/spark */ - def arrayContains[C[_]: CatalystCollection, A, T](column: TypedColumn[T, C[A]], value: A): TypedColumn[T, Boolean] = { - implicit val c = column.uencoder - new TypedColumn[T, Boolean](untyped.array_contains(column.untyped, value)) - } + def arrayContains[C[_]: CatalystCollection, A, T](column: AbstractTypedColumn[T, C[A]], value: A): column.ThisType[T, Boolean] = + column.typed(untyped.array_contains(column.untyped, value)) /** Non-Aggregate function: returns the atan of a numeric column * @@ -43,11 +41,9 @@ trait NonAggregateFunctions { * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] * apache/spark */ - def atan[A, T](column: TypedColumn[T,A]) - (implicit evCanBeDouble: CatalystCast[A, Double]): TypedColumn[T, Double] = { - implicit val c = column.uencoder - new TypedColumn[T, Double](untyped.atan(column.cast[Double].untyped)) - } + def atan[A, T](column: AbstractTypedColumn[T,A]) + (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = + column.typed(untyped.atan(column.cast[Double].untyped)) /** Non-Aggregate function: returns the asin of a numeric column * @@ -55,11 +51,9 @@ trait NonAggregateFunctions { * 
[[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] * apache/spark */ - def asin[A, T](column: TypedColumn[T, A]) - (implicit evCanBeDouble: CatalystCast[A, Double]): TypedColumn[T, Double] = { - implicit val c = column.uencoder - new TypedColumn[T, Double](untyped.asin(column.cast[Double].untyped)) - } + def asin[A, T](column: AbstractTypedColumn[T, A]) + (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = + column.typed(untyped.asin(column.cast[Double].untyped)) /** Non-Aggregate function: returns the angle theta from the conversion of rectangular coordinates (x, y) to * polar coordinates (r, theta). @@ -68,40 +62,57 @@ trait NonAggregateFunctions { * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] * apache/spark */ - def atan2[A, B, T](l: TypedColumn[T,A], r: TypedColumn[T, B]) + def atan2[A, B, T](l: TypedColumn[T, A], r: TypedColumn[T, B]) (implicit - evCanBeDoubleL: CatalystCast[A, Double], - evCanBeDoubleR: CatalystCast[B, Double] - ): TypedColumn[T, Double] = { - implicit val lUnencoder = l.uencoder - implicit val rUnencoder = r.uencoder - new TypedColumn[T, Double](untyped.atan2(l.cast[Double].untyped, r.cast[Double].untyped)) - } + i0: CatalystCast[A, Double], + i1: CatalystCast[B, Double] + ): TypedColumn[T, Double] = + r.typed(untyped.atan2(l.cast[Double].untyped, r.cast[Double].untyped)) - def atan2[B, T](l: Double, r: TypedColumn[T, B])(implicit evCanBeDoubleR: CatalystCast[B, Double]): TypedColumn[T, Double] = - atan2(lit(l): TypedColumn[T, Double], r) + /** Non-Aggregate function: returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def atan2[A, B, T](l: TypedAggregate[T, A], r: TypedAggregate[T, B]) + (implicit + i0: CatalystCast[A, Double], + i1: CatalystCast[B, Double] + ): TypedAggregate[T, Double] = + r.typed(untyped.atan2(l.cast[Double].untyped, r.cast[Double].untyped)) - def atan2[A, T](l: TypedColumn[T, A], r: Double)(implicit evCanBeDoubleL: CatalystCast[A, Double]): TypedColumn[T, Double] = - atan2(l, lit(r): TypedColumn[T, Double]) + def atan2[B, T](l: Double, r: TypedColumn[T, B]) + (implicit i0: CatalystCast[B, Double]): TypedColumn[T, Double] = + atan2(r.lit(l), r) + + def atan2[A, T](l: TypedColumn[T, A], r: Double) + (implicit i0: CatalystCast[A, Double]): TypedColumn[T, Double] = + atan2(l, l.lit(r)) + + def atan2[B, T](l: Double, r: TypedAggregate[T, B]) + (implicit i0: CatalystCast[B, Double]): TypedAggregate[T, Double] = + atan2(r.lit(l), r) + + def atan2[A, T](l: TypedAggregate[T, A], r: Double) + (implicit i0: CatalystCast[A, Double]): TypedAggregate[T, Double] = + atan2(l, l.lit(r)) /** Non-Aggregate function: Returns the string representation of the binary value of the given long * column. For example, bin("12") returns "1100". 
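`atan2` is provided separately for `TypedColumn` and `TypedAggregate` arguments, with the `Double`-constant variants reusing the column's own `lit`. A sketch of the aggregate form, mirroring the updated test (hypothetical `X2` fixture, implicit `SparkSession` assumed):

```scala
import frameless.TypedDataset
import frameless.functions.aggregate.first
import frameless.functions.nonAggregate.atan2

case class X2[A, B](a: A, b: B)                            // mirrors the test fixture

// assuming an implicit SparkSession, as in TypedDatasetSuite
val ds = TypedDataset.create(Seq(X2(1, 2), X2(3, 4)))

// first(...) yields TypedAggregate columns, so the TypedAggregate overload is
// selected and the result can be used directly inside agg:
val theta = ds.agg(atan2(first(ds('a)), first(ds('b)))).firstOption().run()
```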
* * apache/spark */ - def bin[T](column: TypedColumn[T, Long]): TypedColumn[T, String] = { - implicit val c = column.uencoder - new TypedColumn[T, String](untyped.bin(column.untyped)) - } + def bin[T](column: AbstractTypedColumn[T, Long]): column.ThisType[T, String] = + column.typed(untyped.bin(column.untyped)) /** Non-Aggregate function: Computes bitwise NOT. * * apache/spark */ - def bitwiseNOT[A: CatalystBitwise, T](column: TypedColumn[T, A]): TypedColumn[T, A] = { - implicit val c = column.uencoder - new TypedColumn[T, A](untyped.bitwiseNOT(column.untyped)) - } + def bitwiseNOT[A: CatalystBitwise, T](column: AbstractTypedColumn[T, A]): column.ThisType[T, A] = + column.typed(untyped.bitwiseNOT(column.untyped))(column.uencoder) /** Non-Aggregate function: file name of the current Spark task. Empty string if row did not originate from * a file @@ -129,19 +140,18 @@ trait NonAggregateFunctions { * }}} * apache/spark */ - def when[T, A](condition: TypedColumn[T, Boolean], value: TypedColumn[T, A]): When[T, A] = + def when[T, A](condition: AbstractTypedColumn[T, Boolean], value: AbstractTypedColumn[T, A]): When[T, A] = new When[T, A](condition, value) class When[T, A] private (untypedC: Column) { - private[functions] def this(condition: TypedColumn[T, Boolean], value: TypedColumn[T, A]) = + private[functions] def this(condition: AbstractTypedColumn[T, Boolean], value: AbstractTypedColumn[T, A]) = this(untyped.when(condition.untyped, value.untyped)) - def when(condition: TypedColumn[T, Boolean], value: TypedColumn[T, A]): When[T, A] = new When[T, A]( - untypedC.when(condition.untyped, value.untyped) - ) + def when(condition: AbstractTypedColumn[T, Boolean], value: AbstractTypedColumn[T, A]): When[T, A] = + new When[T, A](untypedC.when(condition.untyped, value.untyped)) - def otherwise(value: TypedColumn[T, A]): TypedColumn[T, A] = - new TypedColumn[T, A](untypedC.otherwise(value.untyped))(value.uencoder) + def otherwise(value: AbstractTypedColumn[T, A]): value.ThisType[T, A] = + value.typed(untypedC.otherwise(value.untyped))(value.uencoder) } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -153,35 +163,50 @@ trait NonAggregateFunctions { * * apache/spark */ - def ascii[T](column: TypedColumn[T, String]): TypedColumn[T, Int] = { - new TypedColumn[T, Int](untyped.ascii(column.untyped)) - } + def ascii[T](column: AbstractTypedColumn[T, String]): column.ThisType[T, Int] = + column.typed(untyped.ascii(column.untyped)) /** Non-Aggregate function: Computes the BASE64 encoding of a binary column and returns it as a string column. * This is the reverse of unbase64. * * apache/spark */ - def base64[T](column: TypedColumn[T, Array[Byte]]): TypedColumn[T, String] = { - new TypedColumn[T, String](untyped.base64(column.untyped)) - } + def base64[T](column: AbstractTypedColumn[T, Array[Byte]]): column.ThisType[T, String] = + column.typed(untyped.base64(column.untyped)) /** Non-Aggregate function: Concatenates multiple input string columns together into a single string column. 
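`when`/`otherwise` now accept any `AbstractTypedColumn`, and `otherwise` returns `value.ThisType[T, A]`, so the conditional keeps the flavour of its branches. A sketch with a hypothetical `Reading` dataset, assuming an implicit `SparkSession`:

```scala
import frameless.TypedDataset
import frameless.functions.nonAggregate.when

case class Reading(valid: Boolean, value: Double, fallback: Double)   // hypothetical row type

// assuming an implicit SparkSession, as in TypedDatasetSuite
val readings = TypedDataset.create(Seq(Reading(true, 1.5, 0.0), Reading(false, 9.9, 0.0)))

// All branches are TypedColumns here, so otherwise(...) is a TypedColumn too
// and fits straight into select:
val cleaned =
  readings.select(when(readings('valid), readings('value)).otherwise(readings('fallback)))
```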
+ * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] * * apache/spark */ - def concat[T](columns: TypedColumn[T, String]*): TypedColumn[T, String] = { - new TypedColumn[T, String](untyped.concat(columns.map(_.untyped):_*)) - } + def concat[T](columns: TypedColumn[T, String]*): TypedColumn[T, String] = + new TypedColumn(untyped.concat(columns.map(_.untyped): _*)) + + /** Non-Aggregate function: Concatenates multiple input string columns together into a single string column. + * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] + * + * apache/spark + */ + def concat[T](columns: TypedAggregate[T, String]*): TypedAggregate[T, String] = + new TypedAggregate(untyped.concat(columns.map(_.untyped): _*)) /** Non-Aggregate function: Concatenates multiple input string columns together into a single string column, * using the given separator. + * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] * * apache/spark */ - def concatWs[T](sep: String, columns: TypedColumn[T, String]*): TypedColumn[T, String] = { - new TypedColumn[T, String](untyped.concat_ws(sep, columns.map(_.untyped):_*)) - } + def concatWs[T](sep: String, columns: TypedAggregate[T, String]*): TypedAggregate[T, String] = + new TypedAggregate(untyped.concat_ws(sep, columns.map(_.untyped): _*)) + + /** Non-Aggregate function: Concatenates multiple input string columns together into a single string column, + * using the given separator. + * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] + * + * apache/spark + */ + def concatWs[T](sep: String, columns: TypedColumn[T, String]*): TypedColumn[T, String] = + new TypedColumn(untyped.concat_ws(sep, columns.map(_.untyped): _*)) /** Non-Aggregate function: Locates the position of the first occurrence of substring column * in given string @@ -191,107 +216,106 @@ trait NonAggregateFunctions { * * apache/spark */ - def instr[T](column: TypedColumn[T, String], substring: String): TypedColumn[T, Int] = { - new TypedColumn[T, Int](untyped.instr(column.untyped, substring)) - } + def instr[T](str: AbstractTypedColumn[T, String], substring: String): str.ThisType[T, Int] = + str.typed(untyped.instr(str.untyped, substring)) /** Non-Aggregate function: Computes the length of a given string. * * apache/spark */ //TODO: Also for binary - def length[T](column: TypedColumn[T, String]): TypedColumn[T, Int] = { - new TypedColumn[T, Int](untyped.length(column.untyped)) - } + def length[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Int] = + str.typed(untyped.length(str.untyped)) /** Non-Aggregate function: Computes the Levenshtein distance of the two given string columns. * * apache/spark */ - def levenshtein[T](l: TypedColumn[T, String], r: TypedColumn[T, String]): TypedColumn[T, Int] = { - new TypedColumn[T, Int](untyped.levenshtein(l.untyped, r.untyped)) - } + def levenshtein[T](l: TypedColumn[T, String], r: TypedColumn[T, String]): TypedColumn[T, Int] = + l.typed(untyped.levenshtein(l.untyped, r.untyped)) + + /** Non-Aggregate function: Computes the Levenshtein distance of the two given string columns. + * + * apache/spark + */ + def levenshtein[T](l: TypedAggregate[T, String], r: TypedAggregate[T, String]): TypedAggregate[T, Int] = + l.typed(untyped.levenshtein(l.untyped, r.untyped)) /** Non-Aggregate function: Converts a string column to lower case. 
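As the `@note` says, the varargs signatures keep `concat` and `concatWs` from being written once against `AbstractTypedColumn`, so each gets a `TypedColumn` and a `TypedAggregate` overload. A sketch of the aggregate form, mirroring the new "concat for TypedAggregate" test (hypothetical `X2` fixture, implicit `SparkSession` assumed):

```scala
import frameless.TypedDataset
import frameless.functions.aggregate.first
import frameless.functions.nonAggregate.concat

case class X2[A, B](a: A, b: B)                            // mirrors the test fixture

// assuming an implicit SparkSession, as in TypedDatasetSuite
val ds = TypedDataset.create(Seq(X2("foo", "bar")))

// Every argument is a TypedAggregate[_, String], so the TypedAggregate overload
// is chosen and the concatenation is itself usable inside agg:
val joined = ds.agg(concat(first(ds('a)), first(ds('b)))).collect().run()   // Seq("foobar")
```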
* * apache/spark */ - def lower[T](e: TypedColumn[T, String]): TypedColumn[T, String] = { - new TypedColumn[T, String](untyped.lower(e.untyped)) - } + def lower[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] = + str.typed(untyped.lower(str.untyped)) /** Non-Aggregate function: Left-pad the string column with pad to a length of len. If the string column is longer * than len, the return value is shortened to len characters. * * apache/spark */ - def lpad[T](str: TypedColumn[T, String], len: Int, pad: String): TypedColumn[T, String] = { - new TypedColumn[T, String](untyped.lpad(str.untyped, len, pad)) - } + def lpad[T](str: AbstractTypedColumn[T, String], + len: Int, + pad: String): str.ThisType[T, String] = + str.typed(untyped.lpad(str.untyped, len, pad)) /** Non-Aggregate function: Trim the spaces from left end for the specified string value. * * apache/spark */ - def ltrim[T](str: TypedColumn[T, String]): TypedColumn[T, String] = { - new TypedColumn[T, String](untyped.ltrim(str.untyped)) - } + def ltrim[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] = + str.typed(untyped.ltrim(str.untyped)) /** Non-Aggregate function: Replace all substrings of the specified string value that match regexp with rep. * * apache/spark */ - def regexpReplace[T](str: TypedColumn[T, String], pattern: Regex, replacement: String): TypedColumn[T, String] = { - new TypedColumn[T, String](untyped.regexp_replace(str.untyped, pattern.regex, replacement)) - } + def regexpReplace[T](str: AbstractTypedColumn[T, String], + pattern: Regex, + replacement: String): str.ThisType[T, String] = + str.typed(untyped.regexp_replace(str.untyped, pattern.regex, replacement)) + /** Non-Aggregate function: Reverses the string column and returns it as a new string column. * * apache/spark */ - def reverse[T](str: TypedColumn[T, String]): TypedColumn[T, String] = { - new TypedColumn[T, String](untyped.reverse(str.untyped)) - } + def reverse[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] = + str.typed(untyped.reverse(str.untyped)) /** Non-Aggregate function: Right-pad the string column with pad to a length of len. * If the string column is longer than len, the return value is shortened to len characters. * * apache/spark */ - def rpad[T](str: TypedColumn[T, String], len: Int, pad: String): TypedColumn[T, String] = { - new TypedColumn[T, String](untyped.rpad(str.untyped, len, pad)) - } + def rpad[T](str: AbstractTypedColumn[T, String], len: Int, pad: String): str.ThisType[T, String] = + str.typed(untyped.rpad(str.untyped, len, pad)) /** Non-Aggregate function: Trim the spaces from right end for the specified string value. * * apache/spark */ - def rtrim[T](e: TypedColumn[T, String]): TypedColumn[T, String] = { - new TypedColumn[T, String](untyped.rtrim(e.untyped)) - } + def rtrim[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] = + str.typed(untyped.rtrim(str.untyped)) /** Non-Aggregate function: Substring starts at `pos` and is of length `len` * * apache/spark */ //TODO: Also for byte array - def substring[T](str: TypedColumn[T, String], pos: Int, len: Int): TypedColumn[T, String] = { - new TypedColumn[T, String](untyped.substring(str.untyped, pos, len)) - } + def substring[T](str: AbstractTypedColumn[T, String], pos: Int, len: Int): str.ThisType[T, String] = + str.typed(untyped.substring(str.untyped, pos, len)) /** Non-Aggregate function: Trim the spaces from both ends for the specified string column. 
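The string functions now take an `AbstractTypedColumn` and return `str.ThisType[T, String]`, so one definition serves both flavours. A sketch showing the same function in a `select` and in an `agg` (hypothetical `User` row type, implicit `SparkSession` assumed):

```scala
import frameless.TypedDataset
import frameless.functions.aggregate.first
import frameless.functions.nonAggregate.lower

case class User(name: String)                              // hypothetical row type

// assuming an implicit SparkSession, as in TypedDatasetSuite
val users = TypedDataset.create(Seq(User("Ada"), User("Bob")))

// The path-dependent result type follows the argument's flavour:
val lowered    = users.select(lower(users('name)))         // TypedColumn result
val firstLower = users.agg(lower(first(users('name))))     // TypedAggregate result
```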
* * apache/spark */ - def trim[T](str: TypedColumn[T, String]): TypedColumn[T, String] = { - new TypedColumn[T, String](untyped.trim(str.untyped)) - } + def trim[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] = + str.typed(untyped.trim(str.untyped)) /** Non-Aggregate function: Converts a string column to upper case. * * apache/spark */ - def upper[T](str: TypedColumn[T, String]): TypedColumn[T, String] = { - new TypedColumn[T, String](untyped.upper(str.untyped)) - } + def upper[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] = + str.typed(untyped.upper(str.untyped)) } diff --git a/dataset/src/main/scala/frameless/functions/package.scala b/dataset/src/main/scala/frameless/functions/package.scala index 286deadbd..f1e72a0e6 100644 --- a/dataset/src/main/scala/frameless/functions/package.scala +++ b/dataset/src/main/scala/frameless/functions/package.scala @@ -7,6 +7,20 @@ package object functions extends Udf with UnaryFunctions { object aggregate extends AggregateFunctions object nonAggregate extends NonAggregateFunctions + /** Creates a [[frameless.TypedAggregate]] of literal value. If A is to be encoded using an Injection make + * sure the injection instance is in scope. + * + * apache/spark + */ + def litAggr[A: TypedEncoder, T](value: A): TypedAggregate[T, A] = + new TypedAggregate[T,A](lit(value).expr) + + + /** Creates a [[frameless.TypedColumn]] of literal value. If A is to be encoded using an Injection make + * sure the injection instance is in scope. + * + * apache/spark + */ def lit[A: TypedEncoder, T](value: A): TypedColumn[T, A] = { val encoder = TypedEncoder[A] diff --git a/dataset/src/test/scala/frameless/FilterTests.scala b/dataset/src/test/scala/frameless/FilterTests.scala index 1cce64ed4..c068f13f3 100644 --- a/dataset/src/test/scala/frameless/FilterTests.scala +++ b/dataset/src/test/scala/frameless/FilterTests.scala @@ -33,6 +33,27 @@ class FilterTests extends TypedDatasetSuite { check(forAll(prop[Int] _)) check(forAll(prop[String] _)) check(forAll(prop[Char] _)) + check(forAll(prop[Boolean] _)) + check(forAll(prop[SQLTimestamp] _)) + //check(forAll(prop[Vector[SQLTimestamp]] _)) // Commenting out since this fails randomly due to frameless Issue #124 + } + + test("filter('a =!= 'b)") { + def prop[A: TypedEncoder](data: Vector[X2[A, A]]): Prop = { + val dataset = TypedDataset.create(data) + val A = dataset.col('a) + val B = dataset.col('b) + + val dataset2 = dataset.filter(A =!= B).collect().run().toVector + val data2 = data.filter(x => x.a != x.b) + + dataset2 ?= data2 + } + + check(forAll(prop[Int] _)) + check(forAll(prop[String] _)) + check(forAll(prop[Char] _)) + check(forAll(prop[Boolean] _)) check(forAll(prop[SQLTimestamp] _)) //check(forAll(prop[Vector[SQLTimestamp]] _)) // Commenting out since this fails randomly due to frameless Issue #124 } diff --git a/dataset/src/test/scala/frameless/SchemaTests.scala b/dataset/src/test/scala/frameless/SchemaTests.scala index 21fa363de..762ab19b2 100644 --- a/dataset/src/test/scala/frameless/SchemaTests.scala +++ b/dataset/src/test/scala/frameless/SchemaTests.scala @@ -1,6 +1,7 @@ package frameless import frameless.functions.aggregate._ +import frameless.functions._ import org.scalacheck.Prop import org.scalacheck.Prop._ import org.scalatest.Matchers diff --git a/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala b/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala index fa19f2f6b..8a5479eb6 100644 --- 
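`lit` keeps producing literals for `select`-like positions while the new `litAggr` produces them for `agg`-like positions, as exercised by the `litAggr` test that follows. A sketch with a hypothetical `Event` dataset, assuming an implicit `SparkSession`:

```scala
import frameless.TypedDataset
import frameless.functions.{lit, litAggr}
import frameless.functions.aggregate.count

case class Event(id: Long)                                 // hypothetical row type

// assuming an implicit SparkSession, as in TypedDatasetSuite
val events = TypedDataset.create(Seq(Event(1L), Event(2L)))

// lit for select-like positions, litAggr for agg-like positions:
val tagged  = events.select(events('id), lit("static"))
val summary = events.agg(count(), litAggr("static"))
```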
a/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala +++ b/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala @@ -155,6 +155,17 @@ class AggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[Double] _)) } + test("litAggr") { + def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](xs: List[A], b: B, c: C): Prop = { + val dataset = TypedDataset.create(xs) + val (r1, rb, rc, rcount) = dataset.agg(count().lit(1), litAggr(b), litAggr(c), count()).collect().run().head + (rcount ?= xs.size.toLong) && (r1 ?= 1) && (rb ?= b) && (rc ?= c) + } + + check(forAll(prop[Boolean, Int, String] _)) + check(forAll(prop[Option[Boolean], Vector[Option[Vector[Char]]], Long] _)) + } + test("count") { def prop[A: TypedEncoder](xs: List[A]): Prop = { val dataset = TypedDataset.create(xs) @@ -197,6 +208,18 @@ class AggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[String] _)) } + test("max with follow up multiplication") { + def prop(xs: List[Long]): Prop = { + val dataset = TypedDataset.create(xs.map(X1(_))) + val A = dataset.col[Long]('a) + val datasetMax = dataset.agg(max(A) * 2).collect().run().headOption + + datasetMax ?= (if(xs.isEmpty) None else Some(xs.max * 2)) + } + + check(forAll(prop _)) + } + test("min") { def prop[A: TypedEncoder: CatalystOrdered](xs: List[A])(implicit o: Ordering[A]): Prop = { val dataset = TypedDataset.create(xs.map(X1(_))) diff --git a/dataset/src/test/scala/frameless/functions/DoubleBehaviourUtils.scala b/dataset/src/test/scala/frameless/functions/DoubleBehaviourUtils.scala index 3bf5a3634..f3a8be581 100644 --- a/dataset/src/test/scala/frameless/functions/DoubleBehaviourUtils.scala +++ b/dataset/src/test/scala/frameless/functions/DoubleBehaviourUtils.scala @@ -1,11 +1,15 @@ -package frameless.functions +package frameless +package functions /** - * Some statistical functions in Spark can result in Double, Double.NaN or Null. This tends to break ?= of the property based testing. - * Use the nanNullHandler function here to alleviate this by mapping this NaN and Null to None. This will result in functioning comparison again. + * Some statistical functions in Spark can result in Double, Double.NaN or Null. + * This tends to break ?= of the property based testing. Use the nanNullHandler function + * here to alleviate this by mapping this NaN and Null to None. This will result in + * functioning comparison again. */ object DoubleBehaviourUtils { - // Mapping with this function is needed because spark uses Double.NaN for some semantics in the correlation function. ?= for prop testing will use == underlying and will break because Double.NaN != Double.NaN + // Mapping with this function is needed because spark uses Double.NaN for some semantics in the + // correlation function. ?= for prop testing will use == underlying and will break because Double.NaN != Double.NaN private val nanHandler: Double => Option[Double] = value => if (!value.equals(Double.NaN)) Option(value) else None // Making sure that null => None and does not result in 0.0d because of row.getAs[Double]'s use of .asInstanceOf val nanNullHandler: Any => Option[Double] = { @@ -13,5 +17,4 @@ object DoubleBehaviourUtils { case d: Double => nanHandler(d) case _ => ??? 
   }
-  }
diff --git a/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala b/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala
index 961620c7e..2640fb89c 100644
--- a/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala
+++ b/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala
@@ -1,10 +1,11 @@
 package frameless
 package functions
+
 import java.io.File
 
 import frameless.functions.nonAggregate._
 import org.apache.commons.io.FileUtils
-import org.apache.spark.sql.{ Column, Encoder, SaveMode, functions => untyped }
+import org.apache.spark.sql.{Encoder, SaveMode, functions => untyped}
 import org.scalacheck.Gen
 import org.scalacheck.Prop._
@@ -54,35 +55,33 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
     import spark.implicits._
 
     def prop[A: TypedEncoder : Encoder]
-    (values: List[X1[A]])
-    (
-      implicit catalystAbsolute: CatalystAbsolute[A, A],
-      encX1:Encoder[X1[A]]
-    )= {
-      val cDS = session.createDataset(values)
-      val resCompare = cDS
-        .select(org.apache.spark.sql.functions.abs(cDS("a")))
-        .map(_.getAs[A](0))
-        .collect().toList
+      (values: List[X1[A]])
+      (
+        implicit catalystAbsolute: CatalystAbsolute[A, A],
+        encX1: Encoder[X1[A]]
+      ) = {
+        val cDS = session.createDataset(values)
+        val resCompare = cDS
+          .select(org.apache.spark.sql.functions.abs(cDS("a")))
+          .map(_.getAs[A](0))
+          .collect().toList
 
-      val typedDS = TypedDataset.create(values)
-      val res = typedDS
-        .select(abs(typedDS('a)))
-        .collect()
-        .run()
-        .toList
-
-      res ?= resCompare
-    }
+        val typedDS = TypedDataset.create(values)
+        val res = typedDS
+          .select(abs(typedDS('a)))
+          .collect()
+          .run()
+          .toList
 
-    check(forAll(prop[Int] _))
-    check(forAll(prop[Long] _))
-    check(forAll(prop[Short] _))
-    check(forAll(prop[Double] _))
+        res ?= resCompare
     }
-
+    check(forAll(prop[Int] _))
+    check(forAll(prop[Long] _))
+    check(forAll(prop[Short] _))
+    check(forAll(prop[Double] _))
+  }
 
   test("acos") {
     val spark = session
@@ -112,7 +111,6 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
 
       res ?= resCompare
     }
-
     check(forAll(prop[Int] _))
     check(forAll(prop[Long] _))
     check(forAll(prop[Short] _))
@@ -130,8 +128,6 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
     * [[https://issues.apache.org/jira/browse/SPARK-21204]]
     */
   test("arrayContains"){
-
-
     val spark = session
     import spark.implicits._
 
@@ -174,7 +170,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
 
       val cDS = session.createDataset(List(values))
       val resCompare = cDS
-        .select(org.apache.spark.sql.functions.array_contains(cDS("value"), contained))
+        .select(untyped.array_contains(cDS("value"), contained))
        .map(_.getAs[Boolean](0))
        .collect().toList
 
@@ -227,16 +223,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
     val spark = session
     import spark.implicits._
 
-    def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = {
-      val cDS = session.createDataset(values)
+    def prop[A: CatalystNumeric : TypedEncoder : Encoder]
+      (na: A, values: List[X1[A]])(implicit encX1: Encoder[X1[A]]) = {
+      val cDS = session.createDataset(X1(na) :: values)
       val resCompare = cDS
-        .select(org.apache.spark.sql.functions.atan(cDS("a")))
+        .select(untyped.atan(cDS("a")))
         .map(_.getAs[Double](0))
         .map(DoubleBehaviourUtils.nanNullHandler)
         .collect().toList
-
-      val typedDS = TypedDataset.create(values)
+      val typedDS = TypedDataset.create(cDS)
       val res = typedDS
         .select(atan(typedDS('a)))
         .deserialized
@@ -245,9 +241,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
         .run()
         .toList
 
-      res ?= resCompare
-    }
+      val aggrTyped = typedDS.agg(atan(
+        frameless.functions.aggregate.first(typedDS('a)))
+      ).firstOption().run().get
+      val aggrSpark = cDS.select(
+        untyped.atan(untyped.first("a")).as[Double]
+      ).first()
+
+      (res ?= resCompare).&&(aggrTyped ?= aggrSpark)
+    }
 
     check(forAll(prop[Int] _))
     check(forAll(prop[Long] _))
@@ -261,16 +264,17 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
     val spark = session
     import spark.implicits._
 
-    def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = {
+    def prop[A: CatalystNumeric : TypedEncoder : Encoder]
+      (values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = {
       val cDS = session.createDataset(values)
       val resCompare = cDS
-        .select(org.apache.spark.sql.functions.asin(cDS("a")))
+        .select(untyped.asin(cDS("a")))
         .map(_.getAs[Double](0))
         .map(DoubleBehaviourUtils.nanNullHandler)
         .collect().toList
 
-      val typedDS = TypedDataset.create(values)
+      val typedDS = TypedDataset.create(cDS)
       val res = typedDS
         .select(asin(typedDS('a)))
         .deserialized
@@ -295,17 +299,18 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
     val spark = session
     import spark.implicits._
 
-    def prop[A: CatalystNumeric : TypedEncoder : Encoder, B: CatalystNumeric : TypedEncoder : Encoder](values: List[X2[A, B]])
+    def prop[A: CatalystNumeric : TypedEncoder : Encoder,
+             B: CatalystNumeric : TypedEncoder : Encoder](na: X2[A, B], values: List[X2[A, B]])
       (implicit encEv: Encoder[X2[A,B]]) = {
-      val cDS = session.createDataset(values)
+      val cDS = session.createDataset(na +: values)
       val resCompare = cDS
-        .select(org.apache.spark.sql.functions.atan2(cDS("a"), cDS("b")))
+        .select(untyped.atan2(cDS("a"), cDS("b")))
         .map(_.getAs[Double](0))
         .map(DoubleBehaviourUtils.nanNullHandler)
         .collect().toList
 
-      val typedDS = TypedDataset.create(values)
+      val typedDS = TypedDataset.create(cDS)
       val res = typedDS
         .select(atan2(typedDS('a), typedDS('b)))
         .deserialized
@@ -314,7 +319,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
         .run()
         .toList
 
-      res ?= resCompare
+      val aggrTyped = typedDS.agg(atan2(
+        frameless.functions.aggregate.first(typedDS('a)),
+        frameless.functions.aggregate.first(typedDS('b)))
+      ).firstOption().run().get
+
+      val aggrSpark = cDS.select(
+        untyped.atan2(untyped.first("a"),untyped.first("b")).as[Double]
+      ).first()
+
+      (res ?= resCompare).&&(aggrTyped ?= aggrSpark)
     }
 
@@ -330,16 +344,17 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
     val spark = session
     import spark.implicits._
 
-    def prop[A: CatalystNumeric : TypedEncoder : Encoder](value: List[X1[A]], lit:Double)(implicit encX1:Encoder[X1[A]]) = {
-      val cDS = session.createDataset(value)
+    def prop[A: CatalystNumeric : TypedEncoder : Encoder]
+      (na: X1[A], value: List[X1[A]], lit:Double)(implicit encX1:Encoder[X1[A]]) = {
+      val cDS = session.createDataset(na +: value)
       val resCompare = cDS
-        .select(org.apache.spark.sql.functions.atan2(lit, cDS("a")))
+        .select(untyped.atan2(lit, cDS("a")))
         .map(_.getAs[Double](0))
         .map(DoubleBehaviourUtils.nanNullHandler)
         .collect().toList
 
-      val typedDS = TypedDataset.create(value)
+      val typedDS = TypedDataset.create(cDS)
       val res = typedDS
         .select(atan2(lit, typedDS('a)))
         .deserialized
@@ -348,7 +363,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
         .run()
         .toList
 
-      res ?= resCompare
+      val aggrTyped = typedDS.agg(atan2(
+        lit,
+        frameless.functions.aggregate.first(typedDS('a)))
+      ).firstOption().run().get
+
+      val aggrSpark = cDS.select(
+        untyped.atan2(lit, untyped.first("a")).as[Double]
+      ).first()
+
+      (res ?= resCompare).&&(aggrTyped ?= aggrSpark)
     }
 
@@ -364,16 +388,17 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
     val spark = session
     import spark.implicits._
 
-    def prop[A: CatalystNumeric : TypedEncoder : Encoder](value: List[X1[A]], lit:Double)(implicit encX1:Encoder[X1[A]]) = {
-      val cDS = session.createDataset(value)
+    def prop[A: CatalystNumeric : TypedEncoder : Encoder]
+      (na: X1[A], value: List[X1[A]], lit:Double)(implicit encX1:Encoder[X1[A]]) = {
+      val cDS = session.createDataset(na +: value)
       val resCompare = cDS
-        .select(org.apache.spark.sql.functions.atan2(cDS("a"), lit))
+        .select(untyped.atan2(cDS("a"), lit))
         .map(_.getAs[Double](0))
         .map(DoubleBehaviourUtils.nanNullHandler)
         .collect().toList
 
-      val typedDS = TypedDataset.create(value)
+      val typedDS = TypedDataset.create(cDS)
      val res = typedDS
        .select(atan2(typedDS('a), lit))
        .deserialized
@@ -382,7 +407,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
         .run()
         .toList
 
-      res ?= resCompare
+      val aggrTyped = typedDS.agg(atan2(
+        frameless.functions.aggregate.first(typedDS('a)),
+        lit)
+      ).firstOption().run().get
+
+      val aggrSpark = cDS.select(
+        untyped.atan2(untyped.first("a"), lit).as[Double]
+      ).first()
+
+      (res ?= resCompare).&&(aggrTyped ?= aggrSpark)
     }
 
@@ -401,7 +435,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
     def prop(values:List[X1[Array[Byte]]])(implicit encX1:Encoder[X1[Array[Byte]]]) = {
       val cDS = session.createDataset(values)
       val resCompare = cDS
-        .select(org.apache.spark.sql.functions.base64(cDS("a")))
+        .select(untyped.base64(cDS("a")))
        .map(_.getAs[String](0))
        .collect().toList
 
@@ -425,7 +459,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
     def prop(values:List[X1[Long]])(implicit encX1:Encoder[X1[Long]]) = {
       val cDS = session.createDataset(values)
       val resCompare = cDS
-        .select(org.apache.spark.sql.functions.bin(cDS("a")))
+        .select(untyped.bin(cDS("a")))
        .map(_.getAs[String](0))
        .collect().toList
 
@@ -442,15 +476,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
     check(forAll(prop _))
   }
 
-
   test("bitwiseNOT"){
     val spark = session
     import spark.implicits._
 
-    def prop[A: CatalystBitwise : TypedEncoder : Encoder](values:List[X1[A]])(implicit encX1:Encoder[X1[A]]) = {
+    def prop[A: CatalystBitwise : TypedEncoder : Encoder]
+      (values:List[X1[A]])(implicit encX1:Encoder[X1[A]]) = {
       val cDS = session.createDataset(values)
       val resCompare = cDS
-        .select(org.apache.spark.sql.functions.bitwiseNOT(cDS("a")))
+        .select(untyped.bitwiseNOT(cDS("a")))
        .map(_.getAs[A](0))
        .collect().toList
 
@@ -540,7 +574,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
     val spark = session
     import spark.implicits._
 
-    def prop[A : TypedEncoder : Encoder](condition1: Boolean, condition2: Boolean, value1: A, value2: A, otherwise: A) = {
+    def prop[A : TypedEncoder : Encoder]
+      (condition1: Boolean, condition2: Boolean, value1: A, value2: A, otherwise: A) = {
       val ds = TypedDataset.create(X5(condition1, condition2, value1, value2, otherwise) :: Nil)
 
       val untypedWhen = ds.toDF()
@@ -576,132 +611,429 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
   test("ascii") {
     val spark = session
     import spark.implicits._
+    check(forAll { values: List[X1[String]] =>
+      val ds = TypedDataset.create(values)
+
+      val sparkResult = ds.toDF()
+        .select(untyped.ascii($"a"))
+        .map(_.getAs[Int](0))
+        .collect()
+        .toVector
+
+      val typed = ds
+        .select(ascii(ds('a)))
+        .collect()
+        .run()
+        .toVector
 
-    check(stringFuncProp(ascii, untyped.ascii))
+      typed ?= sparkResult
+    })
   }
 
   test("concat") {
     val spark = session
     import spark.implicits._
 
-    check(stringFuncProp(concat(_, lit("hello")), untyped.concat(_, untyped.lit("hello"))))
-  }
+    val pairs = for {
+      y <- Gen.alphaStr
+      x <- Gen.nonEmptyListOf(X2(y, y))
+    } yield x
 
-  test("concat_ws") {
-    val spark = session
-    import spark.implicits._
+    check(forAll(pairs) { values: List[X2[String, String]] =>
+      val ds = TypedDataset.create(values)
 
-    check(stringFuncProp(concatWs(",", _, lit("hello")), untyped.concat_ws(",", _, untyped.lit("hello"))))
+      val sparkResult = ds.toDF()
+        .select(untyped.concat($"a", $"b"))
+        .map(_.getAs[String](0))
+        .collect()
+        .toVector
+
+      val typed = ds
+        .select(concat(ds('a), ds('b)))
+        .collect()
+        .run()
+        .toVector
+
+      (typed ?= sparkResult).&&(typed ?= values.map(x => s"${x.a}${x.b}").toVector)
+    })
   }
 
-  test("instr") {
+  test("concat for TypedAggregate") {
     val spark = session
     import spark.implicits._
 
-    check(stringFuncProp(instr(_, "hello"), untyped.instr(_, "hello")))
+    import frameless.functions.aggregate._
+    val pairs = for {
+      y <- Gen.alphaStr
+      x <- Gen.nonEmptyListOf(X2(y, y))
+    } yield x
+
+    check(forAll(pairs) { values: List[X2[String, String]] =>
+      val ds = TypedDataset.create(values)
+      val td = ds.agg(concat(first(ds('a)),first(ds('b)))).collect().run().toVector
+      val spark = ds.dataset.select(untyped.concat(
+        untyped.first($"a").as[String],
+        untyped.first($"b").as[String])).as[String].collect().toVector
+      td ?= spark
+    })
  }
 
-  test("length") {
+  test("concat_ws") {
     val spark = session
     import spark.implicits._
 
-    check(stringFuncProp(length, untyped.length))
+    val pairs = for {
+      y <- Gen.alphaStr
+      x <- Gen.nonEmptyListOf(X2(y, y))
+    } yield x
+
+    check(forAll(pairs) { values: List[X2[String, String]] =>
+      val ds = TypedDataset.create(values)
+
+      val sparkResult = ds.toDF()
+        .select(untyped.concat_ws(",", $"a", $"b"))
+        .map(_.getAs[String](0))
+        .collect()
+        .toVector
+
+      val typed = ds
+        .select(concatWs(",", ds('a), ds('b)))
+        .collect()
+        .run()
+        .toVector
+
+      typed ?= sparkResult
+    })
  }
 
-  test("levenshtein") {
+  test("concat_ws for TypedAggregate") {
     val spark = session
     import spark.implicits._
 
-    check(stringFuncProp(levenshtein(_, lit("hello")), untyped.levenshtein(_, untyped.lit("hello"))))
+    import frameless.functions.aggregate._
+    val pairs = for {
+      y <- Gen.alphaStr
+      x <- Gen.listOfN(10, X2(y, y))
+    } yield x
+
+    check(forAll(pairs) { values: List[X2[String, String]] =>
+      val ds = TypedDataset.create(values)
+      val td = ds.agg(concatWs(",",first(ds('a)),first(ds('b)), last(ds('b)))).collect().run().toVector
+      val spark = ds.dataset.select(untyped.concat_ws(",",
+        untyped.first($"a").as[String],
+        untyped.first($"b").as[String],
+        untyped.last($"b").as[String])).as[String].collect().toVector
+      td ?= spark
+    })
  }
 
-  test("lower") {
+  test("instr") {
     val spark = session
     import spark.implicits._
+    check(forAll(Gen.nonEmptyListOf(Gen.alphaStr)) { values: List[String] =>
+      val ds = TypedDataset.create(values.map(x => X1(x + values.head)))
+
+      val sparkResult = ds.toDF()
+        .select(untyped.instr($"a", values.head))
+        .map(_.getAs[Int](0))
+        .collect()
+        .toVector
+
+      val typed = ds
+        .select(instr(ds('a), values.head))
+        .collect()
+        .run()
+        .toVector
 
-    check(stringFuncProp(lower, untyped.lower))
+      typed ?= sparkResult
+    })
  }
 
-  test("lpad") {
+  test("length") {
     val spark = session
     import spark.implicits._
+    check(forAll { values: List[X1[String]] =>
+      val ds = TypedDataset.create(values)
+
+      val sparkResult = ds.toDF()
+        .select(untyped.length($"a"))
+        .map(_.getAs[Int](0))
+        .collect()
+        .toVector
 
-    check(stringFuncProp(lpad(_, 5, "hello"), untyped.lpad(_, 5, "hello")))
+      val typed = ds
+        .select(length(ds[String]('a)))
+        .collect()
+        .run()
+        .toVector
+
+      (typed ?= sparkResult).&&(values.map(_.a.length).toVector ?= typed)
+    })
  }
 
-  test("ltrim") {
+  test("levenshtein") {
     val spark = session
     import spark.implicits._
+    check(forAll { (na: X1[String], values: List[X1[String]]) =>
+      val ds = TypedDataset.create(na +: values)
+
+      val sparkResult = ds.toDF()
+        .select(untyped.levenshtein($"a", untyped.concat($"a",untyped.lit("Hello"))))
+        .map(_.getAs[Int](0))
+        .collect()
+        .toVector
+
+      val typed = ds
+        .select(levenshtein(ds('a), concat(ds('a),lit("Hello"))))
+        .collect()
+        .run()
+        .toVector
+
+      val cDS = ds.dataset
+      val aggrTyped = ds.agg(
+        levenshtein(frameless.functions.aggregate.first(ds('a)), litAggr("Hello"))
+      ).firstOption().run().get
+
+      val aggrSpark = cDS.select(
+        untyped.levenshtein(untyped.first("a"), untyped.lit("Hello")).as[Int]
+      ).first()
 
-    check(stringFuncProp(ltrim, untyped.ltrim))
+      (typed ?= sparkResult).&&(aggrTyped ?= aggrSpark)
+    })
  }
 
   test("regexp_replace") {
     val spark = session
     import spark.implicits._
+    check(forAll { (values: List[X1[String]], n: Int) =>
+      val ds = TypedDataset.create(values.map(x => X1(s"$n${x.a}-$n$n")))
+
+      val sparkResult = ds.toDF()
+        .select(untyped.regexp_replace($"a", "\\d+", "n"))
+        .map(_.getAs[String](0))
+        .collect()
+        .toVector
 
-    check(stringFuncProp(regexpReplace(_, "\\d+".r, "n"), untyped.regexp_replace(_, "\\d+", "n")))
+      val typed = ds
+        .select(regexpReplace(ds[String]('a), "\\d+".r, "n"))
+        .collect()
+        .run()
+        .toVector
+
+      typed ?= sparkResult
+    })
  }
 
   test("reverse") {
     val spark = session
     import spark.implicits._
+    check(forAll { values: List[X1[String]] =>
+      val ds = TypedDataset.create(values)
 
-    check(stringFuncProp(reverse, untyped.reverse))
+      val sparkResult = ds.toDF()
+        .select(untyped.reverse($"a"))
+        .map(_.getAs[String](0))
+        .collect()
+        .toVector
+
+      val typed = ds
+        .select(reverse(ds[String]('a)))
+        .collect()
+        .run()
+        .toVector
+
+      (typed ?= sparkResult).&&(values.map(_.a.reverse).toVector ?= typed)
+    })
  }
 
   test("rpad") {
     val spark = session
     import spark.implicits._
+    check(forAll { values: List[X1[String]] =>
+      val ds = TypedDataset.create(values)
 
-    check(stringFuncProp(rpad(_, 5, "hello"), untyped.rpad(_, 5, "hello")))
+      val sparkResult = ds.toDF()
+        .select(untyped.rpad($"a", 5, "hello"))
+        .map(_.getAs[String](0))
+        .collect()
+        .toVector
+
+      val typed = ds
+        .select(rpad(ds[String]('a), 5, "hello"))
+        .collect()
+        .run()
+        .toVector
+
+      typed ?= sparkResult
+    })
+  }
+
+  test("lpad") {
+    val spark = session
+    import spark.implicits._
+    check(forAll { values: List[X1[String]] =>
+      val ds = TypedDataset.create(values)
+
+      val sparkResult = ds.toDF()
+        .select(untyped.lpad($"a", 5, "hello"))
+        .map(_.getAs[String](0))
+        .collect()
+        .toVector
+
+      val typed = ds
+        .select(lpad(ds[String]('a), 5, "hello"))
+        .collect()
+        .run()
+        .toVector
+
+      typed ?= sparkResult
+    })
  }
 
   test("rtrim") {
     val spark = session
     import spark.implicits._
+    check(forAll { values: List[X1[String]] =>
+      val ds = TypedDataset.create(values.map(x => X1(s" ${x.a} ")))
 
-    check(stringFuncProp(rtrim, untyped.rtrim))
+      val sparkResult = ds.toDF()
+        .select(untyped.rtrim($"a"))
+        .map(_.getAs[String](0))
+        .collect()
+        .toVector
+
+      val typed = ds
+        .select(rtrim(ds[String]('a)))
+        .collect()
+        .run()
+        .toVector
+
+      typed ?= sparkResult
+    })
+  }
+
+  test("ltrim") {
+    val spark = session
+    import spark.implicits._
+    check(forAll { values: List[X1[String]] =>
+      val ds = TypedDataset.create(values.map(x => X1(s" ${x.a} ")))
+
+      val sparkResult = ds.toDF()
+        .select(untyped.ltrim($"a"))
+        .map(_.getAs[String](0))
+        .collect()
+        .toVector
+
+      val typed = ds
+        .select(ltrim(ds[String]('a)))
+        .collect()
+        .run()
+        .toVector
+
+      typed ?= sparkResult
+    })
  }
 
   test("substring") {
     val spark = session
     import spark.implicits._
+    check(forAll { values: List[X1[String]] =>
+      val ds = TypedDataset.create(values)
 
-    check(stringFuncProp(substring(_, 5, 3), untyped.substring(_, 5, 3)))
+      val sparkResult = ds.toDF()
+        .select(untyped.substring($"a", 5, 3))
+        .map(_.getAs[String](0))
+        .collect()
+        .toVector
+
+      val typed = ds
+        .select(substring(ds[String]('a), 5, 3))
+        .collect()
+        .run()
+        .toVector
+
+      typed ?= sparkResult
+    })
  }
 
   test("trim") {
     val spark = session
     import spark.implicits._
+    check(forAll { values: List[X1[String]] =>
+      val ds = TypedDataset.create(values.map(x => X1(s" ${x.a} ")))
+
+      val sparkResult = ds.toDF()
+        .select(untyped.trim($"a"))
+        .map(_.getAs[String](0))
+        .collect()
+        .toVector
+
+      val typed = ds
+        .select(trim(ds[String]('a)))
+        .collect()
+        .run()
+        .toVector
 
-    check(stringFuncProp(trim, untyped.trim))
+      typed ?= sparkResult
+    })
  }
 
   test("upper") {
     val spark = session
     import spark.implicits._
+    check(forAll(Gen.listOf(Gen.alphaStr)) { values: List[String] =>
+      val ds = TypedDataset.create(values.map(X1(_)))
+
+      val sparkResult = ds.toDF()
+        .select(untyped.upper($"a"))
+        .map(_.getAs[String](0))
+        .collect()
+        .toVector
+
+      val typed = ds
+        .select(upper(ds[String]('a)))
+        .collect()
+        .run()
+        .toVector
 
-    check(stringFuncProp(upper, untyped.upper))
+      typed ?= sparkResult
+    })
  }
 
-  def stringFuncProp[A : Encoder](strFunc: TypedColumn[X1[String], String] => TypedColumn[X1[String], A], sparkFunc: Column => Column) = {
-    forAll { values: List[X1[String]] =>
-      val ds = TypedDataset.create(values)
+  test("lower") {
+    val spark = session
+    import spark.implicits._
+    check(forAll(Gen.listOf(Gen.alphaStr)) { values: List[String] =>
+      val ds = TypedDataset.create(values.map(X1(_)))
 
-      val sparkResult: List[A] = ds.toDF()
-        .select(sparkFunc(untyped.col("a")))
-        .map(_.getAs[A](0))
+      val sparkResult = ds.toDF()
+        .select(untyped.lower($"a"))
+        .map(_.getAs[String](0))
         .collect()
-        .toList
+        .toVector
 
-      val typed: List[A] = ds
-        .select(strFunc(ds[String]('a)))
+      val typed = ds
+        .select(lower(ds[String]('a)))
         .collect()
         .run()
-        .toList
+        .toVector
 
       typed ?= sparkResult
-    }
+    })
+  }
+
+  test("Empty vararg tests") {
+    import frameless.functions.aggregate._
+    def prop[A : TypedEncoder, B: TypedEncoder](data: Vector[X2[A, B]]) = {
+      val ds = TypedDataset.create(data)
+      val frameless = ds.select(ds('a), concat(), ds('b), concatWs(":")).collect().run().toVector
+      val framelessAggr = ds.agg(first(ds('a)), concat(), concatWs("x"), litAggr(2)).collect().run().toVector
+      val scala = data.map(x => (x.a, "", x.b, ""))
+      val scalaAggr = if (data.nonEmpty) Vector((data.head.a, "", "", 2)) else Vector.empty
+      (frameless ?= scala).&&(framelessAggr ?= scalaAggr)
    }
+
+    check(forAll(prop[Long, Long] _))
+    check(forAll(prop[Option[Vector[Boolean]], Long] _))
  }
 }
diff --git a/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala b/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala
index 2033c3c47..5108ed581 100644
--- a/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala
+++ b/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala
@@ -1,8 +1,9 @@
-package frameless.syntax
+package frameless
+package syntax
 
-import frameless.{TypedDataset, TypedDatasetSuite, TypedEncoder, X2}
 import org.scalacheck.Prop
 import org.scalacheck.Prop._
+import frameless.functions.aggregate._
 
 class FramelessSyntaxTests extends TypedDatasetSuite {
   // Hide the implicit SparkDelay[Job] on TypedDatasetSuite to avoid ambiguous implicits
@@ -21,7 +22,28 @@ class FramelessSyntaxTests extends TypedDatasetSuite {
   }
 
   test("dataset typed - toTyped") {
+    def prop[A, B](data: Vector[X2[A, B]])(
+      implicit ev: TypedEncoder[X2[A, B]]
+    ): Prop = {
+      val dataset = session.createDataset(data)(TypedExpressionEncoder(ev)).typed
+      val dataframe = dataset.toDF()
+
+      dataset.collect().run().toVector ?= dataframe.unsafeTyped[X2[A, B]].collect().run().toVector
+    }
 
+    check(forAll(prop[Int, String] _))
+    check(forAll(prop[X1[Long], String] _))
   }
 
+  test("frameless typed column and aggregate") {
+    def prop[A: TypedEncoder](a: A, b: A): Prop = {
+      val d = TypedDataset.create((a, b) :: Nil)
+      (d.select(d('_1).untyped.typedColumn).collect().run ?= d.select(d('_1)).collect().run).&&(
+        d.agg(first(d('_1))).collect().run() ?= d.agg(first(d('_1)).untyped.typedAggregate).collect().run()
+      )
+    }
+
+    check(forAll(prop[Int] _))
+    check(forAll(prop[X1[Long]] _))
+  }
 }