diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 3f9227a8ae00..6bf8162da3fb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -283,6 +283,7 @@ object FunctionRegistry {
     expression[StringTrimLeft]("ltrim"),
     expression[JsonTuple]("json_tuple"),
     expression[FormatString]("printf"),
+    expression[ParseUrl]("parse_url"),
     expression[RegExpExtract]("regexp_extract"),
     expression[RegExpReplace]("regexp_replace"),
     expression[StringRepeat]("repeat"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 44ff7fda8ef4..b62f425b16a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -17,15 +17,18 @@
 package org.apache.spark.sql.catalyst.expressions
 
+import java.net.{MalformedURLException, URL}
 import java.text.{DecimalFormat, DecimalFormatSymbols}
 import java.util.{HashMap, Locale, Map => JMap}
+import java.util.regex.Pattern
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // This file defines expressions for string operations.
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -611,6 +614,88 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
   override def prettyName: String = "rpad"
 }
 
+/**
+ * Extracts a part from a URL.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(url, partToExtract[, key]) - extracts a part from a URL",
+  extended = "Parts: HOST, PATH, QUERY, REF, PROTOCOL, AUTHORITY, FILE, USERINFO\n" +
+    "key specifies which query parameter to extract\n" +
+    "Examples:\n" +
+    "  > SELECT _FUNC_('http://spark.apache.org/path?query=1', " +
+    "'HOST') FROM src LIMIT 1;\n" + "  'spark.apache.org'\n" +
+    "  > SELECT _FUNC_('http://spark.apache.org/path?query=1', " +
+    "'QUERY') FROM src LIMIT 1;\n" + "  'query=1'\n" +
+    "  > SELECT _FUNC_('http://spark.apache.org/path?query=1', " +
+    "'QUERY', 'query') FROM src LIMIT 1;\n" + "  '1'")
+case class ParseUrl(children: Expression*)
+  extends Expression with ImplicitCastInputTypes with CodegenFallback {
+
+  override def nullable: Boolean = true
+
+  override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType)
+  override def dataType: DataType = StringType
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.size > 3 || children.size < 2) {
+      TypeCheckResult.TypeCheckFailure("parse_url function requires two or three arguments")
+    } else {
+      super[ImplicitCastInputTypes].checkInputDataTypes()
+    }
+  }
+
+  // Maps a part name to the corresponding component of the parsed URL, or null if unknown.
+  private def extractPart(url: URL, partToExtract: String): String = partToExtract match {
+    case "HOST" => url.getHost
+    case "PATH" => url.getPath
+    case "QUERY" => url.getQuery
+    case "REF" => url.getRef
+    case "PROTOCOL" => url.getProtocol
+    case "FILE" => url.getFile
+    case "AUTHORITY" => url.getAuthority
+    case "USERINFO" => url.getUserInfo
+    case _ => null
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val urlStr = children(0).eval(input)
+    val part = children(1).eval(input)
+    if (urlStr == null || part == null) {
+      null
+    } else {
+      try {
+        val url = new URL(urlStr.toString)
+        val partToExtract = part.toString
+        if (children.size == 2) {
+          UTF8String.fromString(extractPart(url, partToExtract))
+        } else { // children.size == 3
+          val key = children(2).eval(input)
+          if (key == null || partToExtract != "QUERY") {
+            null
+          } else {
+            val query = url.getQuery
+            if (query == null) {
+              null
+            } else {
+              // Match "key=value" at the start of the query string or right after a '&'.
+              // Pattern.quote keeps regex metacharacters in the key literal.
+              val pattern = ("(?:^|&)" + Pattern.quote(key.toString) + "=([^&]*)").r
+              pattern.findFirstMatchIn(query) match {
+                case Some(m) => UTF8String.fromString(m.group(1))
+                case None => null
+              }
+            }
+          }
+        }
+      } catch {
+        case _: MalformedURLException => null
+      }
+    }
+  }
+
+  override def prettyName: String = "parse_url"
+}
+
 /**
  * Returns the input formatted according to printf-style format strings
  */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 29bf15bf524b..2987084fd7a7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -702,4 +702,43 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(FindInSet(Literal("abf"), Literal("abc,b,ab,c,def")), 0)
     checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0)
   }
+
+  test("ParseUrl") {
+    def checkParseUrl(expected: String, urlStr: String, partToExtract: String): Unit = {
+      checkEvaluation(
+        ParseUrl(Literal.create(urlStr, StringType), Literal.create(partToExtract, StringType)),
+        expected)
+    }
+    def checkParseUrlWithKey(expected: String, urlStr: String,
+        partToExtract: String, key: String): Unit = {
+      checkEvaluation(
+        ParseUrl(Literal.create(urlStr, StringType), Literal.create(partToExtract, StringType),
+          Literal.create(key, StringType)), expected)
+    }
+
+    checkParseUrl("spark.apache.org", "http://spark.apache.org/path?query=1", "HOST")
+    checkParseUrl("/path", "http://spark.apache.org/path?query=1", "PATH")
+    checkParseUrl("query=1", "http://spark.apache.org/path?query=1", "QUERY")
+    checkParseUrl("Ref", "http://spark.apache.org/path?query=1#Ref", "REF")
+    checkParseUrl("http", "http://spark.apache.org/path?query=1", "PROTOCOL")
+    checkParseUrl("/path?query=1", "http://spark.apache.org/path?query=1", "FILE")
+    checkParseUrl("spark.apache.org:8080", "http://spark.apache.org:8080/path?query=1", "AUTHORITY")
+    checkParseUrl("jian", "http://jian@spark.apache.org/path?query=1", "USERINFO")
+    checkParseUrlWithKey("1", "http://spark.apache.org/path?query=1", "QUERY", "query")
+
+    // Null checking
+    checkParseUrl(null, null, "HOST")
+    checkParseUrl(null, "http://spark.apache.org/path?query=1", null)
+    checkParseUrl(null, null, null)
+    checkParseUrl(null, "test", "HOST")
+    checkParseUrl(null, "http://spark.apache.org/path?query=1", "NO")
+    checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "HOST", "query")
+    checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", "quer")
+    checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", null)
+
+    // Argument count checking
+    assert(ParseUrl(Literal("1")).checkInputDataTypes().isFailure)
+    assert(ParseUrl(Literal("1"), Literal("2"),
+      Literal("3"), Literal("4")).checkInputDataTypes().isFailure)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 1de2d9b5adab..2b4c708b8b21 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -212,6 +212,19 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
       Row("???hi", "hi???", "h", "h"))
   }
 
+  test("string parse_url function") {
+    val df = Seq("http://jian@spark.apache.org/path?query=1#Ref").toDF("url")
Seq[String](("http://jian@spark.apache.org/path?query=1#Ref")) + .toDF("url") + + checkAnswer( + df.selectExpr("parse_url(url, 'HOST')", "parse_url(url, 'PATH')", + "parse_url(url, 'QUERY')", "parse_url(url, 'REF')", + "parse_url(url, 'PROTOCOL')", "parse_url(url, 'FILE')", + "parse_url(url, 'AUTHORITY')", "parse_url(url, 'USERINFO')", + "parse_url(url, 'QUERY', 'query')"), + Row("spark.apache.org", "/path", "query=1", "Ref", + "http", "/path?query=1", "jian@spark.apache.org", "jian", "1")) + } + test("string repeat function") { val df = Seq(("hi", 2)).toDF("a", "b") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index fa560a044b42..ea5729705884 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -240,7 +240,7 @@ private[sql] class HiveSessionCatalog( private val hiveFunctions = Seq( "elt", "hash", "java_method", "histogram_numeric", "map_keys", "map_values", - "parse_url", "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map", + "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map", "xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long", "xpath_number", "xpath_short", "xpath_string",