FunctionRegistry.scala
@@ -283,6 +283,7 @@ object FunctionRegistry {
expression[StringTrimLeft]("ltrim"),
expression[JsonTuple]("json_tuple"),
expression[FormatString]("printf"),
expression[ParseUrl]("parse_url"),
expression[RegExpExtract]("regexp_extract"),
expression[RegExpReplace]("regexp_replace"),
expression[StringRepeat]("repeat"),
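Registering the expression under the name "parse_url" in FunctionRegistry is what lets the analyzer resolve it as a native function instead of delegating to Hive. A minimal sketch of calling it once this change is in, assuming a Spark 2.x SparkSession named `spark` (the expected values mirror the tests further down):

    // Hypothetical spark-shell check; assumes a SparkSession named `spark`.
    val row = spark.sql(
      """SELECT
        |  parse_url('http://spark.apache.org/path?query=1', 'HOST') AS host,
        |  parse_url('http://spark.apache.org/path?query=1', 'QUERY', 'query') AS q
        |""".stripMargin).head()

    assert(row.getString(0) == "spark.apache.org")  // HOST part
    assert(row.getString(1) == "1")                 // value of the 'query' key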
stringExpressions.scala
@@ -17,15 +17,19 @@

package org.apache.spark.sql.catalyst.expressions

import java.net.{MalformedURLException, URL}
import java.text.{DecimalFormat, DecimalFormatSymbols}
import java.util.{HashMap, Locale, Map => JMap}

import scala.util.matching.Regex

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}

////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines expressions for string operations.
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -611,6 +615,100 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
override def prettyName: String = "rpad"
}

/**
 * Extracts a part from a URL.
 */
@ExpressionDescription(
  usage = "_FUNC_(url, partToExtract[, key]) - extracts a part from a URL",
  extended = "Parts: HOST, PATH, QUERY, REF, PROTOCOL, AUTHORITY, FILE, USERINFO\n"
    + "key specifies which query parameter to extract\n"
    + "Examples:\n"
    + "  > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'HOST') FROM src LIMIT 1;\n"
    + "  'spark.apache.org'\n"
    + "  > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'QUERY') FROM src LIMIT 1;\n"
    + "  'query=1'\n"
    + "  > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'QUERY', 'query') FROM src LIMIT 1;\n"
    + "  '1'")
case class ParseUrl(children: Expression*)
extends Expression with ImplicitCastInputTypes with CodegenFallback {

override def nullable: Boolean = true

  override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType)

  override def dataType: DataType = StringType

override def checkInputDataTypes(): TypeCheckResult = {
if (children.size > 3 || children.size < 2) {
TypeCheckResult.TypeCheckFailure("parse_url function requires two or three arguments")
} else {
super[ImplicitCastInputTypes].checkInputDataTypes()
}
}

override def eval(input: InternalRow): Any = {
val urlStr = children(0).eval(input)
val part = children(1).eval(input)
if (urlStr == null || part == null) {
null
} else if (children.size == 2) {
      try {
        val url = new URL(urlStr.toString)
        part.toString match {
          case "HOST" => UTF8String.fromString(url.getHost)
          case "PATH" => UTF8String.fromString(url.getPath)
          case "QUERY" => UTF8String.fromString(url.getQuery)
          case "REF" => UTF8String.fromString(url.getRef)
          case "PROTOCOL" => UTF8String.fromString(url.getProtocol)
          case "FILE" => UTF8String.fromString(url.getFile)
          case "AUTHORITY" => UTF8String.fromString(url.getAuthority)
          case "USERINFO" => UTF8String.fromString(url.getUserInfo)
          case _ => null
        }
      } catch {
        case _: MalformedURLException => null
      }
    } else { // children.size == 3
      val key = children(2).eval(input)
      if (key == null || part.toString != "QUERY") {
        null
      } else {
        try {
          val url = new URL(urlStr.toString)
          val query = url.getQuery
          if (query == null) {
            null
          } else {
            // Match the key at the start of the query string or right after a '&',
            // capturing everything up to the next '&'. findFirstMatchIn scans for the
            // key anywhere in the query, so multi-parameter queries also work.
            val pattern = new Regex("(&|^)" + key.toString + "=([^&]*)")
            pattern.findFirstMatchIn(query) match {
              case Some(m) => UTF8String.fromString(m.group(2))
              case None => null
            }
          }
        } catch {
          case _: MalformedURLException => null
        }
      }
}
}

override def prettyName: String = "parse_url"
}

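A note on the key-extraction regex in ParseUrl above: a Regex extractor pattern (val pattern(value) = query) must match the entire query string, so it fails whenever the URL carries more than one parameter. A standalone sketch in plain Scala, with no Spark dependency, illustrating the difference:

    import scala.util.matching.Regex

    object ParseUrlRegexCheck extends App {
      val query = "a=1&query=2&b=3"
      val pattern = new Regex("(&|^)query=([^&]*)")

      // Scanning match: finds the key anywhere in the query string.
      println(pattern.findFirstMatchIn(query).map(_.group(2)))  // prints Some(2)

      // A whole-string extractor match would throw scala.MatchError here,
      // because the pattern cannot cover "a=1&" and "&b=3":
      //   val pattern(_, value) = query
    }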
/**
 * Returns the input formatted according to printf-style format strings
*/
StringExpressionsSuite.scala
@@ -702,4 +702,43 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(FindInSet(Literal("abf"), Literal("abc,b,ab,c,def")), 0)
checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0)
}

test("ParseUrl") {
def checkParseUrl(expected: String, urlStr: String, partToExtract: String): Unit = {
checkEvaluation(
ParseUrl(Literal.create(urlStr, StringType), Literal.create(partToExtract, StringType)),
expected)
}
    def checkParseUrlWithKey(
        expected: String, urlStr: String, partToExtract: String, key: String): Unit = {
checkEvaluation(
ParseUrl(Literal.create(urlStr, StringType), Literal.create(partToExtract, StringType),
Literal.create(key, StringType)), expected)
}

checkParseUrl("spark.apache.org", "http://spark.apache.org/path?query=1", "HOST")
checkParseUrl("/path", "http://spark.apache.org/path?query=1", "PATH")
checkParseUrl("query=1", "http://spark.apache.org/path?query=1", "QUERY")
checkParseUrl("Ref", "http://spark.apache.org/path?query=1#Ref", "REF")
checkParseUrl("http", "http://spark.apache.org/path?query=1", "PROTOCOL")
checkParseUrl("/path?query=1", "http://spark.apache.org/path?query=1", "FILE")
checkParseUrl("spark.apache.org:8080", "http://spark.apache.org:8080/path?query=1", "AUTHORITY")
checkParseUrl("jian", "http://jian@spark.apache.org/path?query=1", "USERINFO")
checkParseUrlWithKey("1", "http://spark.apache.org/path?query=1", "QUERY", "query")

// Null checking
checkParseUrl(null, null, "HOST")
checkParseUrl(null, "http://spark.apache.org/path?query=1", null)
checkParseUrl(null, null, null)
checkParseUrl(null, "test", "HOST")
checkParseUrl(null, "http://spark.apache.org/path?query=1", "NO")
checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "HOST", "query")
checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", "quer")
checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", null)

// arguments checking
assert(ParseUrl(Literal("1")).checkInputDataTypes().isFailure)
assert(ParseUrl(Literal("1"), Literal("2"),
Literal("3"), Literal("4")).checkInputDataTypes().isFailure)
}
}
StringFunctionsSuite.scala
@@ -212,6 +212,20 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
Row("???hi", "hi???", "h", "h"))
}

test("string parse_url function") {
    val df = Seq("http://jian@spark.apache.org/path?query=1#Ref").toDF("url")

checkAnswer(
df.selectExpr("parse_url(url, 'HOST')", "parse_url(url, 'PATH')",
"parse_url(url, 'QUERY')", "parse_url(url, 'REF')",
"parse_url(url, 'PROTOCOL')", "parse_url(url, 'FILE')",
"parse_url(url, 'AUTHORITY')", "parse_url(url, 'USERINFO')",
"parse_url(url, 'QUERY', 'query')"),
Row("spark.apache.org", "/path", "query=1", "Ref",
"http", "/path?query=1", "jian@spark.apache.org", "jian", "1"))
}

test("string repeat function") {
val df = Seq(("hi", 2)).toDF("a", "b")

HiveSessionCatalog.scala
@@ -240,7 +240,7 @@ private[sql] class HiveSessionCatalog(
private val hiveFunctions = Seq(
"elt", "hash", "java_method", "histogram_numeric",
"map_keys", "map_values",
"parse_url", "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map",
"percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map",
"xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
"xpath_number", "xpath_short", "xpath_string",

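With "parse_url" dropped from hiveFunctions, the function now resolves through the native FunctionRegistry even in sessions built without Hive support. A minimal sketch of that check, assuming a local Spark 2.x build (the builder options are illustrative only; the expected value mirrors the PATH case in the tests above):

    import org.apache.spark.sql.SparkSession

    // Hypothetical local check: no Hive support is enabled on this session.
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("parse-url-native-check")
      .getOrCreate()

    val path = spark.sql(
      "SELECT parse_url('http://spark.apache.org/path?query=1', 'PATH')"
    ).head().getString(0)
    assert(path == "/path")

    spark.stop()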