diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index dc1d349f10f1b..715e448f822ab 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -277,7 +277,8 @@ class SqlParser extends AbstractSparkSQLParser { | SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } | COUNT ~ "(" ~> "*" <~ ")" ^^ { case _ => Count(Literal(1)) } | COUNT ~ "(" ~> expression <~ ")" ^^ { case exp => Count(exp) } - | COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } + | COUNT ~> "(" ~> DISTINCT ~> repsep(expression, ",") <~ ")" ^^ + { case exps => CountDistinct(exps) } | APPROXIMATE ~ COUNT ~ "(" ~ DISTINCT ~> expression <~ ")" ^^ { case exp => ApproxCountDistinct(exp) } | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 84ee3051eb682..7d35691b372e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -992,4 +992,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { "nulldata2 on nulldata1.value <=> nulldata2.value"), (1 to 2).map(i => Seq(i))) } + + test("Supporting multi column support for count(distinct ..) function in Spark SQL") { + val data = TestData(1,"val_1") :: TestData(2,"val_2") :: Nil + val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) + rdd.registerTempTable("distinctData") + checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), 2) + } }