diff --git a/Cargo.lock b/Cargo.lock
index 6d54d234e023d..08198cc49b72c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2666,6 +2666,7 @@ dependencies = [
  "datafusion-functions",
  "datafusion-functions-nested",
  "log",
+ "percent-encoding",
  "rand 0.9.2",
  "sha1",
  "url",
diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml
index f39b6992c326b..09959db41fe60 100644
--- a/datafusion/spark/Cargo.toml
+++ b/datafusion/spark/Cargo.toml
@@ -50,6 +50,7 @@ datafusion-expr = { workspace = true }
 datafusion-functions = { workspace = true, features = ["crypto_expressions"] }
 datafusion-functions-nested = { workspace = true }
 log = { workspace = true }
+percent-encoding = "2.3.2"
 rand = { workspace = true }
 sha1 = "0.10"
 url = { workspace = true }
diff --git a/datafusion/spark/src/function/url/mod.rs b/datafusion/spark/src/function/url/mod.rs
index 82bf8a9e09616..657655429ebaa 100644
--- a/datafusion/spark/src/function/url/mod.rs
+++ b/datafusion/spark/src/function/url/mod.rs
@@ -21,9 +21,15 @@ use std::sync::Arc;
 
 pub mod parse_url;
 pub mod try_parse_url;
+pub mod try_url_decode;
+pub mod url_decode;
+pub mod url_encode;
 
 make_udf_function!(parse_url::ParseUrl, parse_url);
 make_udf_function!(try_parse_url::TryParseUrl, try_parse_url);
+make_udf_function!(try_url_decode::TryUrlDecode, try_url_decode);
+make_udf_function!(url_decode::UrlDecode, url_decode);
+make_udf_function!(url_encode::UrlEncode, url_encode);
 
 pub mod expr_fn {
     use datafusion_functions::export_functions;
@@ -38,8 +44,17 @@ pub mod expr_fn {
         "Same as parse_url but returns NULL if an invalid URL is provided.",
         args
     ));
+    export_functions!((url_decode, "Decodes a URL-encoded string in ‘application/x-www-form-urlencoded’ format to its original format.", args));
+    export_functions!((try_url_decode, "Same as url_decode but returns NULL if an invalid URL-encoded string is provided.", args));
+    export_functions!((url_encode, "Encodes a string into a URL-encoded string in ‘application/x-www-form-urlencoded’ format.", args));
 }
 
 pub fn functions() -> Vec<Arc<ScalarUDF>> {
-    vec![parse_url(), try_parse_url()]
+    vec![
+        parse_url(),
+        try_parse_url(),
+        try_url_decode(),
+        url_decode(),
+        url_encode(),
+    ]
 }
diff --git a/datafusion/spark/src/function/url/try_url_decode.rs b/datafusion/spark/src/function/url/try_url_decode.rs
new file mode 100644
index 0000000000000..61440e7ff05a0
--- /dev/null
+++ b/datafusion/spark/src/function/url/try_url_decode.rs
@@ -0,0 +1,109 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
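+
+// `try_url_decode` mirrors `url_decode` but maps any decoding error to NULL
+// instead of returning an execution error (see `spark_try_url_decode` below).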
+
+use std::any::Any;
+
+use arrow::array::ArrayRef;
+use arrow::datatypes::DataType;
+
+use datafusion_common::Result;
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use datafusion_functions::utils::make_scalar_function;
+
+use crate::function::url::url_decode::{spark_handled_url_decode, UrlDecode};
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct TryUrlDecode {
+    signature: Signature,
+    url_decoder: UrlDecode,
+}
+
+impl Default for TryUrlDecode {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl TryUrlDecode {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::string(1, Volatility::Immutable),
+            url_decoder: UrlDecode::new(),
+        }
+    }
+}
+
+impl ScalarUDFImpl for TryUrlDecode {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "try_url_decode"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        self.url_decoder.return_type(arg_types)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let ScalarFunctionArgs { args, .. } = args;
+        make_scalar_function(spark_try_url_decode, vec![])(&args)
+    }
+}
+
+fn spark_try_url_decode(args: &[ArrayRef]) -> Result<ArrayRef> {
+    spark_handled_url_decode(args, |x| match x {
+        Err(_) => Ok(None),
+        result => result,
+    })
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use arrow::array::StringArray;
+    use datafusion_common::{cast::as_string_array, Result};
+
+    use super::*;
+
+    #[test]
+    fn test_try_decode_error_handled() -> Result<()> {
+        let input = Arc::new(StringArray::from(vec![
+            Some("http%3A%2F%2spark.apache.org"), // '%2s' is not a valid percent-encoded sequence
+            // Valid cases
+            Some("https%3A%2F%2Fspark.apache.org"),
+            None,
+        ]));
+
+        let expected =
+            StringArray::from(vec![None, Some("https://spark.apache.org"), None]);
+
+        let result = spark_try_url_decode(&[input as ArrayRef])?;
+        let result = as_string_array(&result)?;
+
+        assert_eq!(&expected, result);
+        Ok(())
+    }
+}
diff --git a/datafusion/spark/src/function/url/url_decode.rs b/datafusion/spark/src/function/url/url_decode.rs
new file mode 100644
index 0000000000000..520588bc19e9c
--- /dev/null
+++ b/datafusion/spark/src/function/url/url_decode.rs
@@ -0,0 +1,259 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
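+
+// `url_decode` decodes application/x-www-form-urlencoded input, validating the
+// percent-encoding first so that malformed sequences such as "%2s" raise an
+// error instead of being silently ignored.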
+
+use std::any::Any;
+use std::borrow::Cow;
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray};
+use arrow::datatypes::DataType;
+use datafusion_common::cast::{
+    as_large_string_array, as_string_array, as_string_view_array,
+};
+use datafusion_common::{exec_datafusion_err, exec_err, plan_err, Result};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use datafusion_functions::utils::make_scalar_function;
+use percent_encoding::percent_decode;
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct UrlDecode {
+    signature: Signature,
+}
+
+impl Default for UrlDecode {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl UrlDecode {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::string(1, Volatility::Immutable),
+        }
+    }
+
+    /// Decodes a URL-encoded string from application/x-www-form-urlencoded format.
+    ///
+    /// Although `url::form_urlencoded` supports decoding, it does not return an
+    /// error when the string is malformed. For example, "%2s" is not a valid
+    /// percent-encoding, but the `decode` function from `url::form_urlencoded`
+    /// silently ignores it instead of returning an error. This function
+    /// reproduces the same decoding process, plus an extra validation step
+    /// (see `validate_percent_encoding`).
+    ///
+    /// # Arguments
+    ///
+    /// * `value` - The URL-encoded string to decode
+    ///
+    /// # Returns
+    ///
+    /// * `Ok(String)` - The decoded string
+    /// * `Err(DataFusionError)` - If the input is malformed or contains invalid UTF-8
+    ///
+    fn decode(value: &str) -> Result<String> {
+        // Check if the string has valid percent encoding
+        Self::validate_percent_encoding(value)?;
+
+        let replaced = Self::replace_plus(value.as_bytes());
+        percent_decode(&replaced)
+            .decode_utf8()
+            .map_err(|e| exec_datafusion_err!("Invalid UTF-8 sequence: {e}"))
+            .map(|parsed| parsed.into_owned())
+    }
+
+    /// Replace b'+' with b' ', following the application/x-www-form-urlencoded
+    /// convention of encoding spaces as '+'.
+    fn replace_plus(input: &[u8]) -> Cow<'_, [u8]> {
+        match input.iter().position(|&b| b == b'+') {
+            None => Cow::Borrowed(input),
+            Some(first_position) => {
+                let mut replaced = input.to_owned();
+                replaced[first_position] = b' ';
+                for byte in &mut replaced[first_position + 1..] {
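+                    // Bytes before `first_position` are known to contain no '+',
+                    // so only the tail of the copied buffer needs rewriting.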
+                    if *byte == b'+' {
+                        *byte = b' ';
+                    }
+                }
+                Cow::Owned(replaced)
+            }
+        }
+    }
+
+    /// Validate percent-encoding of the string
+    fn validate_percent_encoding(value: &str) -> Result<()> {
+        let bytes = value.as_bytes();
+        let mut i = 0;
+
+        while i < bytes.len() {
+            if bytes[i] == b'%' {
+                // Check if we have at least 2 more characters
+                if i + 2 >= bytes.len() {
+                    return exec_err!(
+                        "Invalid percent-encoding: incomplete sequence at position {}",
+                        i
+                    );
+                }
+
+                let hex1 = bytes[i + 1];
+                let hex2 = bytes[i + 2];
+
+                if !hex1.is_ascii_hexdigit() || !hex2.is_ascii_hexdigit() {
+                    return exec_err!(
+                        "Invalid percent-encoding: invalid hex sequence '%{}{}' at position {}",
+                        hex1 as char,
+                        hex2 as char,
+                        i
+                    );
+                }
+                i += 3;
+            } else {
+                i += 1;
+            }
+        }
+        Ok(())
+    }
+}
+
+impl ScalarUDFImpl for UrlDecode {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "url_decode"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        if arg_types.len() != 1 {
+            return plan_err!(
+                "{} expects 1 argument, but got {}",
+                self.name(),
+                arg_types.len()
+            );
+        }
+        // As the type signature is already checked, we can safely return the type of the first argument
+        Ok(arg_types[0].clone())
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let ScalarFunctionArgs { args, .. } = args;
+        make_scalar_function(spark_url_decode, vec![])(&args)
+    }
+}
+
+/// Core implementation of URL decoding function.
+///
+/// # Arguments
+///
+/// * `args` - A slice containing exactly one ArrayRef with the URL-encoded strings to decode
+///
+/// # Returns
+///
+/// * `Ok(ArrayRef)` - A new array of the same type containing decoded strings
+/// * `Err(DataFusionError)` - If validation fails or invalid arguments are provided
+///
+fn spark_url_decode(args: &[ArrayRef]) -> Result<ArrayRef> {
+    spark_handled_url_decode(args, |x| x)
+}
+
+pub fn spark_handled_url_decode(
+    args: &[ArrayRef],
+    err_handle_fn: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
+) -> Result<ArrayRef> {
+    if args.len() != 1 {
+        return exec_err!("`url_decode` expects 1 argument");
+    }
+
+    match &args[0].data_type() {
+        DataType::Utf8 => as_string_array(&args[0])?
+            .iter()
+            .map(|x| x.map(UrlDecode::decode).transpose())
+            .map(&err_handle_fn)
+            .collect::<Result<StringArray>>()
+            .map(|array| Arc::new(array) as ArrayRef),
+        DataType::LargeUtf8 => as_large_string_array(&args[0])?
+            .iter()
+            .map(|x| x.map(UrlDecode::decode).transpose())
+            .map(&err_handle_fn)
+            .collect::<Result<LargeStringArray>>()
+            .map(|array| Arc::new(array) as ArrayRef),
+        DataType::Utf8View => as_string_view_array(&args[0])?
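+            // Same element-wise decode as the arms above; results flow through
+            // the caller-supplied error handler and are collected into a
+            // StringViewArray.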
+            .iter()
+            .map(|x| x.map(UrlDecode::decode).transpose())
+            .map(&err_handle_fn)
+            .collect::<Result<StringViewArray>>()
+            .map(|array| Arc::new(array) as ArrayRef),
+        other => exec_err!("`url_decode`: Expr must be STRING, got {other:?}"),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow::array::StringArray;
+    use datafusion_common::Result;
+
+    use super::*;
+
+    #[test]
+    fn test_decode() -> Result<()> {
+        let input = Arc::new(StringArray::from(vec![
+            Some("https%3A%2F%2Fspark.apache.org"),
+            Some("inva+lid://user:pass@host/file\\;param?query\\;p2"),
+            Some("inva lid://user:pass@host/file\\;param?query\\;p2"),
+            Some("%7E%21%40%23%24%25%5E%26%2A%28%29%5F%2B"),
+            Some("%E4%BD%A0%E5%A5%BD"),
+            Some(""),
+            None,
+        ]));
+        let expected = StringArray::from(vec![
+            Some("https://spark.apache.org"),
+            Some("inva lid://user:pass@host/file\\;param?query\\;p2"),
+            Some("inva lid://user:pass@host/file\\;param?query\\;p2"),
+            Some("~!@#$%^&*()_+"),
+            Some("你好"),
+            Some(""),
+            None,
+        ]);
+
+        let result = spark_url_decode(&[input as ArrayRef])?;
+        let result = as_string_array(&result)?;
+
+        assert_eq!(&expected, result);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_decode_error() -> Result<()> {
+        let input = Arc::new(StringArray::from(vec![
+            Some("http%3A%2F%2spark.apache.org"), // '%2s' is not a valid percent-encoded sequence
+            // Valid cases
+            Some("https%3A%2F%2Fspark.apache.org"),
+            None,
+        ]));
+
+        let result = spark_url_decode(&[input]);
+        assert!(result.is_err_and(|e| e.to_string().contains("Invalid percent-encoding")));
+
+        Ok(())
+    }
+}
diff --git a/datafusion/spark/src/function/url/url_encode.rs b/datafusion/spark/src/function/url/url_encode.rs
new file mode 100644
index 0000000000000..9b37f0ac6a740
--- /dev/null
+++ b/datafusion/spark/src/function/url/url_encode.rs
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray};
+use arrow::datatypes::DataType;
+use datafusion_common::cast::{
+    as_large_string_array, as_string_array, as_string_view_array,
+};
+use datafusion_common::{exec_err, plan_err, Result};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use datafusion_functions::utils::make_scalar_function;
+use url::form_urlencoded::byte_serialize;
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct UrlEncode {
+    signature: Signature,
+}
+
+impl Default for UrlEncode {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl UrlEncode {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::string(1, Volatility::Immutable),
+        }
+    }
+
+    /// Encode a string to application/x-www-form-urlencoded format.
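+    ///
+    /// Spaces are serialized as `+` and bytes outside the format's unreserved
+    /// set are percent-encoded, delegating to `url::form_urlencoded::byte_serialize`.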
+    ///
+    /// # Arguments
+    ///
+    /// * `value` - The string to encode
+    ///
+    /// # Returns
+    ///
+    /// * `Ok(String)` - The encoded string
+    ///
+    fn encode(value: &str) -> Result<String> {
+        Ok(byte_serialize(value.as_bytes()).collect::<String>())
+    }
+}
+
+impl ScalarUDFImpl for UrlEncode {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "url_encode"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        if arg_types.len() != 1 {
+            return plan_err!(
+                "{} expects 1 argument, but got {}",
+                self.name(),
+                arg_types.len()
+            );
+        }
+        // As the type signature is already checked, we can safely return the type of the first argument
+        Ok(arg_types[0].clone())
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let ScalarFunctionArgs { args, .. } = args;
+        make_scalar_function(spark_url_encode, vec![])(&args)
+    }
+}
+
+/// Core implementation of URL encoding function.
+///
+/// # Arguments
+///
+/// * `args` - A slice containing exactly one ArrayRef with the strings to encode
+///
+/// # Returns
+///
+/// * `Ok(ArrayRef)` - A new array of the same type containing encoded strings
+/// * `Err(DataFusionError)` - If invalid arguments are provided
+///
+fn spark_url_encode(args: &[ArrayRef]) -> Result<ArrayRef> {
+    if args.len() != 1 {
+        return exec_err!("`url_encode` expects 1 argument");
+    }
+
+    match &args[0].data_type() {
+        DataType::Utf8 => as_string_array(&args[0])?
+            .iter()
+            .map(|x| x.map(UrlEncode::encode).transpose())
+            .collect::<Result<StringArray>>()
+            .map(|array| Arc::new(array) as ArrayRef),
+        DataType::LargeUtf8 => as_large_string_array(&args[0])?
+            .iter()
+            .map(|x| x.map(UrlEncode::encode).transpose())
+            .collect::<Result<LargeStringArray>>()
+            .map(|array| Arc::new(array) as ArrayRef),
+        DataType::Utf8View => as_string_view_array(&args[0])?
+            .iter()
+            .map(|x| x.map(UrlEncode::encode).transpose())
+            .collect::<Result<StringViewArray>>()
+            .map(|array| Arc::new(array) as ArrayRef),
+        other => exec_err!("`url_encode`: Expr must be STRING, got {other:?}"),
+    }
+}
diff --git a/datafusion/sqllogictest/test_files/spark/url/try_url_decode.slt b/datafusion/sqllogictest/test_files/spark/url/try_url_decode.slt
new file mode 100644
index 0000000000000..559c77af97e9a
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/spark/url/try_url_decode.slt
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
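+
+# Tests for Spark's try_url_decode: decoding behaves like url_decode, but
+# malformed percent-encoded input yields NULL instead of an error.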
+
+query T
+SELECT try_url_decode('https%3A%2F%2Fspark.apache.org');
+----
+https://spark.apache.org
+
+# Test with LargeUtf8
+query T
+SELECT try_url_decode(arrow_cast('https%3A%2F%2Fspark.apache.org', 'LargeUtf8'));
+----
+https://spark.apache.org
+
+# Test with Utf8View
+query T
+SELECT try_url_decode(arrow_cast('https%3A%2F%2Fspark.apache.org', 'Utf8View'));
+----
+https://spark.apache.org
+
+# Non-ASCII string
+query T
+SELECT try_url_decode('%E4%BD%A0%E5%A5%BD')
+----
+你好
+
+# Empty string
+query T
+SELECT try_url_decode('');
+----
+(empty)
+
+# Null value
+query T
+SELECT try_url_decode(NULL::string);
+----
+NULL
+
+# Roundtrip with url_encode
+query T
+SELECT try_url_decode(url_encode('Spark SQL ~!@#$%^&*()'));
+----
+Spark SQL ~!@#$%^&*()
+
+# Plus replacement
+query T
+SELECT try_url_decode('Spark+SQL%21');
+----
+Spark SQL!
+
+# Invalid percent-encoding is handled and returns NULL
+query T
+SELECT try_url_decode('https%3%2F%2Fspark.apache.org'::string);
+----
+NULL
diff --git a/datafusion/sqllogictest/test_files/spark/url/url_decode.slt b/datafusion/sqllogictest/test_files/spark/url/url_decode.slt
index fa5028b647dc3..61399aa0ef2e7 100644
--- a/datafusion/sqllogictest/test_files/spark/url/url_decode.slt
+++ b/datafusion/sqllogictest/test_files/spark/url/url_decode.slt
@@ -15,13 +15,53 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# This file was originally created by a porting script from:
-# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function
-# This file is part of the implementation of the datafusion-spark function library.
-# For more information, please see:
-# https://github.com/apache/datafusion/issues/15914
-
-## Original Query: SELECT url_decode('https%3A%2F%2Fspark.apache.org');
-## PySpark 3.5.5 Result: {'url_decode(https%3A%2F%2Fspark.apache.org)': 'https://spark.apache.org', 'typeof(url_decode(https%3A%2F%2Fspark.apache.org))': 'string', 'typeof(https%3A%2F%2Fspark.apache.org)': 'string'}
-#query
-#SELECT url_decode('https%3A%2F%2Fspark.apache.org'::string);
+query T
+SELECT url_decode('https%3A%2F%2Fspark.apache.org');
+----
+https://spark.apache.org
+
+# Test with LargeUtf8
+query T
+SELECT url_decode(arrow_cast('https%3A%2F%2Fspark.apache.org', 'LargeUtf8'));
+----
+https://spark.apache.org
+
+# Test with Utf8View
+query T
+SELECT url_decode(arrow_cast('https%3A%2F%2Fspark.apache.org', 'Utf8View'));
+----
+https://spark.apache.org
+
+# Non-ASCII string
+query T
+SELECT url_decode('%E4%BD%A0%E5%A5%BD')
+----
+你好
+
+# Empty string
+query T
+SELECT url_decode('');
+----
+(empty)
+
+# Null value
+query T
+SELECT url_decode(NULL::string);
+----
+NULL
+
+# Roundtrip with url_encode
+query T
+SELECT url_decode(url_encode('Spark SQL ~!@#$%^&*()'));
+----
+Spark SQL ~!@#$%^&*()
+
+# Plus replacement
+query T
+SELECT url_decode('Spark+SQL%21');
+----
+Spark SQL!
+
+# Invalid percent encoding case
+query error DataFusion error: Execution error: Invalid percent\-encoding: invalid hex sequence '%3%' at position 5
+SELECT url_decode('https%3%2F%2Fspark.apache.org'::string);
diff --git a/datafusion/sqllogictest/test_files/spark/url/url_encode.slt b/datafusion/sqllogictest/test_files/spark/url/url_encode.slt
index 6aef87dcb4c0f..3d7a42f19384b 100644
--- a/datafusion/sqllogictest/test_files/spark/url/url_encode.slt
+++ b/datafusion/sqllogictest/test_files/spark/url/url_encode.slt
@@ -15,13 +15,19 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# This file was originally created by a porting script from:
-# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function
-# This file is part of the implementation of the datafusion-spark function library.
-# For more information, please see:
-# https://github.com/apache/datafusion/issues/15914
+query T
+SELECT url_encode('https://spark.apache.org');
+----
+https%3A%2F%2Fspark.apache.org
 
-## Original Query: SELECT url_encode('https://spark.apache.org');
-## PySpark 3.5.5 Result: {'url_encode(https://spark.apache.org)': 'https%3A%2F%2Fspark.apache.org', 'typeof(url_encode(https://spark.apache.org))': 'string', 'typeof(https://spark.apache.org)': 'string'}
-#query
-#SELECT url_encode('https://spark.apache.org'::string);
+# Test with LargeUtf8
+query T
+SELECT url_encode(arrow_cast('https://spark.apache.org', 'LargeUtf8'));
+----
+https%3A%2F%2Fspark.apache.org
+
+# Test with Utf8View
+query T
+SELECT url_encode(arrow_cast('https://spark.apache.org', 'Utf8View'));
+----
+https%3A%2F%2Fspark.apache.org
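+
+# Supplementary cases mirroring url_decode.slt (assumes url_encode passes empty
+# strings and NULL through unchanged)
+
+# Empty string
+query T
+SELECT url_encode('');
+----
+(empty)
+
+# Null value
+query T
+SELECT url_encode(NULL::string);
+----
+NULL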