diff --git a/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala b/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala index 50721ed70..5aaf6c317 100644 --- a/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala +++ b/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala @@ -7,12 +7,15 @@ import org.apache.spark.sql.{functions => untyped} trait AggregateFunctions { + private def typedColumnToAggregate[A: TypedEncoder, T](a: TypedColumn[T, A]): TypedAggregate[T, A] = + new TypedAggregate[T,A](a.expr) + /** Creates a [[frameless.TypedColumn]] of literal value. If A is to be encoded using an Injection make * sure the injection instance is in scope. * * apache/spark */ - def lit[A: TypedEncoder, T](value: A): TypedColumn[T, A] = frameless.functions.lit(value) + def litAggr[A: TypedEncoder, T](value: A): TypedAggregate[T, A] = typedColumnToAggregate(lit(value)) /** Aggregate function: returns the number of items in a group. * diff --git a/dataset/src/test/scala/frameless/SchemaTests.scala b/dataset/src/test/scala/frameless/SchemaTests.scala index aaef496c2..a08ebcbcd 100644 --- a/dataset/src/test/scala/frameless/SchemaTests.scala +++ b/dataset/src/test/scala/frameless/SchemaTests.scala @@ -1,5 +1,6 @@ package frameless +import frameless.functions.lit import frameless.functions.aggregate._ import org.scalacheck.Prop import org.scalacheck.Prop._ diff --git a/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala b/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala index d9aa8ffb7..320dd2d49 100644 --- a/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala +++ b/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala @@ -168,6 +168,17 @@ class AggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[Double] _)) } + test("litAggr") { + def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](xs: List[A], b: B, c: C): Prop = { + val dataset = TypedDataset.create(xs) + val (r1, rb, rc, rcount) = dataset.agg(litAggr(1), litAggr(b), litAggr(c), count()).collect().run().head + (rcount ?= xs.size.toLong) && (r1 ?= 1) && (rb ?= b) && (rc ?= c) + } + + check(forAll(prop[Boolean, Int, String] _)) + check(forAll(prop[Option[Boolean], Vector[Option[Vector[Char]]], Long] _)) + } + test("count") { def prop[A: TypedEncoder](xs: List[A]): Prop = { val dataset = TypedDataset.create(xs)