diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 50a90ae40497a..3723680ff5db2 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -683,11 +683,22 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi trait String2TrimExpression extends Expression with ImplicitCastInputTypes { + protected def srcStr: Expression + protected def trimStr: Option[Expression] + protected def direction: String + + override def children: Seq[Expression] = srcStr +: trimStr.toSeq override def dataType: DataType = StringType override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) override def nullable: Boolean = children.exists(_.nullable) override def foldable: Boolean = children.forall(_.foldable) + + override def sql: String = if (trimStr.isDefined) { + s"TRIM($direction ${trimStr.get.sql} FROM ${srcStr.sql})" + } else { + super.sql + } } object StringTrim { @@ -769,11 +780,8 @@ case class StringTrim( override def prettyName: String = "trim" - override def children: Seq[Expression] = if (trimStr.isDefined) { - srcStr :: trimStr.get :: Nil - } else { - srcStr :: Nil - } + override protected def direction: String = "BOTH" + override def eval(input: InternalRow): Any = { val srcString = srcStr.eval(input).asInstanceOf[UTF8String] if (srcString == null) { @@ -865,11 +873,7 @@ case class StringTrimLeft( override def prettyName: String = "ltrim" - override def children: Seq[Expression] = if (trimStr.isDefined) { - srcStr :: trimStr.get :: Nil - } else { - srcStr :: Nil - } + override protected def direction: String = "LEADING" override def eval(input: InternalRow): Any = { val srcString = srcStr.eval(input).asInstanceOf[UTF8String] @@ -964,11 +968,7 @@ case class StringTrimRight( override def prettyName: String = "rtrim" - override def children: Seq[Expression] = if (trimStr.isDefined) { - srcStr :: trimStr.get :: Nil - } else { - srcStr :: Nil - } + override protected def direction: String = "TRAILING" override def eval(input: InternalRow): Any = { val srcString = srcStr.eval(input).asInstanceOf[UTF8String] diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index fd6cc4d811045..8e33471e8b129 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -40,6 +40,7 @@ SELECT substring('Spark SQL' from -3); SELECT substring('Spark SQL' from 5 for 1); -- trim +SELECT trim(" xyz "), ltrim(" xyz "), rtrim(" xyz "); SELECT trim(BOTH 'xyz' FROM 'yxTomxx'), trim('xyz' FROM 'yxTomxx'); SELECT trim(BOTH 'x' FROM 'xxxbarxxx'), trim('x' FROM 'xxxbarxxx'); SELECT trim(LEADING 'xyz' FROM 'zzzytest'); diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/strings.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/strings.sql.out index 5f89c799498ac..e8a3a9b9731a6 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/strings.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/strings.sql.out @@ -977,7 +977,7 @@ struct -- !query SELECT trim(binary('\\000') from binary('\\000Tom\\000')) -- !query schema -struct +struct -- !query output Tom diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 042d332bdb5c2..43c18f5417110 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 33 +-- Number of queries: 34 -- !query @@ -204,10 +204,18 @@ struct k +-- !query +SELECT trim(" xyz "), ltrim(" xyz "), rtrim(" xyz ") +-- !query schema +struct +-- !query output +xyz xyz xyz + + -- !query SELECT trim(BOTH 'xyz' FROM 'yxTomxx'), trim('xyz' FROM 'yxTomxx') -- !query schema -struct +struct -- !query output Tom Tom @@ -215,7 +223,7 @@ Tom Tom -- !query SELECT trim(BOTH 'x' FROM 'xxxbarxxx'), trim('x' FROM 'xxxbarxxx') -- !query schema -struct +struct -- !query output bar bar @@ -223,7 +231,7 @@ bar bar -- !query SELECT trim(LEADING 'xyz' FROM 'zzzytest') -- !query schema -struct +struct -- !query output test @@ -231,7 +239,7 @@ test -- !query SELECT trim(LEADING 'xyz' FROM 'zzzytestxyz') -- !query schema -struct +struct -- !query output testxyz @@ -239,7 +247,7 @@ testxyz -- !query SELECT trim(LEADING 'xy' FROM 'xyxXxyLAST WORD') -- !query schema -struct +struct -- !query output XxyLAST WORD @@ -247,7 +255,7 @@ XxyLAST WORD -- !query SELECT trim(TRAILING 'xyz' FROM 'testxxzx') -- !query schema -struct +struct -- !query output test @@ -255,7 +263,7 @@ test -- !query SELECT trim(TRAILING 'xyz' FROM 'xyztestxxzx') -- !query schema -struct +struct -- !query output xyztest @@ -263,6 +271,6 @@ xyztest -- !query SELECT trim(TRAILING 'xy' FROM 'TURNERyxXxy') -- !query schema -struct +struct -- !query output TURNERyxX