diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 98a6da485e..6d1d2c1dc5 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -23,7 +23,7 @@ use crate::{
     spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
     spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
     spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkContains, SparkDateDiff,
-    SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace,
+    SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace, SparkUrlEncode,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -198,6 +198,7 @@ fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
         Arc::new(ScalarUDF::new_from_impl(SparkMakeDate::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkSizeFunc::default())),
+        Arc::new(ScalarUDF::new_from_impl(SparkUrlEncode::default())),
     ]
 }
 
diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs
index abdd0cc89b..de012e27fa 100644
--- a/native/spark-expr/src/string_funcs/mod.rs
+++ b/native/spark-expr/src/string_funcs/mod.rs
@@ -18,7 +18,9 @@
 mod contains;
 mod string_space;
 mod substring;
+mod url_encode;
 
 pub use contains::SparkContains;
 pub use string_space::SparkStringSpace;
 pub use substring::SubstringExpr;
+pub use url_encode::SparkUrlEncode;
diff --git a/native/spark-expr/src/string_funcs/url_encode.rs b/native/spark-expr/src/string_funcs/url_encode.rs
new file mode 100644
index 0000000000..765bcd2f00
--- /dev/null
+++ b/native/spark-expr/src/string_funcs/url_encode.rs
@@ -0,0 +1,373 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{
+    Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, OffsetSizeTrait,
+};
+use arrow::datatypes::DataType;
+use datafusion::common::{exec_err, internal_datafusion_err, Result, ScalarValue};
+use datafusion::logical_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use std::{any::Any, sync::Arc};
+
+/// Spark-compatible URL encoding following application/x-www-form-urlencoded format.
+/// This matches Java's URLEncoder.encode behavior used by Spark's UrlCodec.encode.
+///
+/// Key behaviors:
+/// - Spaces are encoded as '+' (not '%20')
+/// - Alphanumeric characters (a-z, A-Z, 0-9) are not encoded
+/// - Special characters '.', '-', '*', '_' are not encoded
+/// - All other characters are percent-encoded using UTF-8 bytes
+///
+/// The result keeps the string width of the input: `Utf8` yields `Utf8`, `LargeUtf8` yields
+/// `LargeUtf8`, and dictionary inputs keep their key type.
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkUrlEncode {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for SparkUrlEncode {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkUrlEncode {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+            aliases: vec![],
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkUrlEncode {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "url_encode"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        match arg_types {
+            [arg_type] => Ok(encoded_data_type(arg_type)),
+            _ => exec_err!(
+                "url_encode expects exactly one argument, got {}",
+                arg_types.len()
+            ),
+        }
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let arg_count = args.args.len();
+        let args: [ColumnarValue; 1] = args.args.try_into().map_err(|_| {
+            internal_datafusion_err!("url_encode expects exactly one argument, got {arg_count}")
+        })?;
+        spark_url_encode(&args)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+/// Returns the data type `url_encode` produces for an argument of type `input_type`.
+///
+/// This has to agree with `url_encode_array` and `url_encode_scalar`; otherwise the type
+/// promised at planning time differs from the type of the values produced at runtime.
+fn encoded_data_type(input_type: &DataType) -> DataType {
+    match input_type {
+        DataType::LargeUtf8 => DataType::LargeUtf8,
+        DataType::Dictionary(key_type, value_type) => {
+            DataType::Dictionary(key_type.clone(), Box::new(encoded_data_type(value_type)))
+        }
+        _ => DataType::Utf8,
+    }
+}
+
+/// Applies Spark-compatible URL encoding to a single array or scalar argument.
+pub fn spark_url_encode(args: &[ColumnarValue; 1]) -> Result<ColumnarValue> {
+    match args {
+        [ColumnarValue::Array(array)] => {
+            let result = url_encode_array(array.as_ref())?;
+            Ok(ColumnarValue::Array(result))
+        }
+        [ColumnarValue::Scalar(scalar)] => {
+            let result = url_encode_scalar(scalar)?;
+            Ok(ColumnarValue::Scalar(result))
+        }
+    }
+}
+
+/// Encodes a `Utf8`, `LargeUtf8` or dictionary-encoded string array.
+fn url_encode_array(input: &dyn Array) -> Result<ArrayRef> {
+    match input.data_type() {
+        DataType::Utf8 => url_encode_string_array(input.as_string::<i32>()),
+        DataType::LargeUtf8 => url_encode_string_array(input.as_string::<i64>()),
+        DataType::Dictionary(_, _) => {
+            // Encode every distinct dictionary value once and keep the existing keys. Going
+            // through `AnyDictionaryArray` supports every key type, not only `Int32`.
+            let dict = input.as_any_dictionary_opt().ok_or_else(|| {
+                internal_datafusion_err!("url_encode input is not a dictionary array")
+            })?;
+            let values = url_encode_array(dict.values())?;
+            Ok(dict.with_values(values))
+        }
+        other => exec_err!("Unsupported input type for function 'url_encode': {other:?}"),
+    }
+}
+
+/// Encodes a string scalar, keeping its string width; `Null` becomes a null `Utf8`.
+fn url_encode_scalar(scalar: &ScalarValue) -> Result<ScalarValue> {
+    match scalar {
+        ScalarValue::Utf8(value) => {
+            let encoded = value.as_deref().map(url_encode_string);
+            Ok(ScalarValue::Utf8(encoded))
+        }
+        ScalarValue::LargeUtf8(value) => {
+            let encoded = value.as_deref().map(url_encode_string);
+            Ok(ScalarValue::LargeUtf8(encoded))
+        }
+        ScalarValue::Dictionary(key_type, value) => {
+            let encoded = url_encode_scalar(value)?;
+            Ok(ScalarValue::Dictionary(key_type.clone(), Box::new(encoded)))
+        }
+        ScalarValue::Null => Ok(ScalarValue::Utf8(None)),
+        other => exec_err!("Unsupported data type {other:?} for function `url_encode`"),
+    }
+}
+
+/// Encodes every non-null value of `input`, preserving nulls and the offset width.
+///
+/// Returns an error instead of panicking when the encoded data no longer fits the offset
+/// type, which is possible because encoding can triple the size of a value.
+fn url_encode_string_array<OffsetSize: OffsetSizeTrait>(
+    input: &GenericStringArray<OffsetSize>,
+) -> Result<ArrayRef> {
+    // Size the value buffer from the offsets so that a sliced input is not over-allocated.
+    let offsets = input.value_offsets();
+    let input_bytes = match (offsets.first(), offsets.last()) {
+        (Some(start), Some(end)) => end.as_usize() - start.as_usize(),
+        _ => 0,
+    };
+    let mut builder = GenericStringBuilder::<OffsetSize>::with_capacity(input.len(), input_bytes);
+    // Reused for every row so that encoding does not allocate once per value.
+    let mut encoded = String::new();
+    let mut total_bytes = 0usize;
+
+    // Iterating the typed array, rather than copying its raw validity buffer, keeps nulls
+    // aligned with their rows when `input` is a slice with a non-zero offset.
+    for value in input.iter() {
+        match value {
+            Some(value) => {
+                encoded.clear();
+                url_encode_into(value, &mut encoded);
+                total_bytes += encoded.len();
+                if OffsetSize::from_usize(total_bytes).is_none() {
+                    return exec_err!("url_encode result is too large for the output string array");
+                }
+                builder.append_value(&encoded);
+            }
+            None => builder.append_null(),
+        }
+    }
+
+    Ok(Arc::new(builder.finish()))
+}
+
+/// Returns the exact number of bytes that URL-encoding `s` produces.
+fn url_encode_length(s: &str) -> usize {
+    s.bytes()
+        // A space becomes '+', so only the other encoded bytes expand to `%XX`.
+        .map(|byte| if byte != b' ' && should_encode(byte) { 3 } else { 1 })
+        .sum()
+}
+
+/// Returns the URL-encoded form of `s` as a new string.
+fn url_encode_string(s: &str) -> String {
+    let mut encoded = String::with_capacity(url_encode_length(s));
+    url_encode_into(s, &mut encoded);
+    encoded
+}
+
+/// Appends the URL-encoded form of `s` to `out`.
+fn url_encode_into(s: &str, out: &mut String) {
+    for byte in s.bytes() {
+        match byte {
+            b' ' => out.push('+'),
+            _ if should_encode(byte) => {
+                out.push('%');
+                out.push(char::from(HEX_BYTES[usize::from(byte >> 4)]));
+                out.push(char::from(HEX_BYTES[usize::from(byte & 0x0F)]));
+            }
+            // Only ASCII bytes reach this arm, so converting the byte to a char is lossless.
+            _ => out.push(char::from(byte)),
+        }
+    }
+}
+
+const HEX_BYTES: [u8; 16] = *b"0123456789ABCDEF";
+
+/// Check if a byte should be encoded
+/// Returns true for every byte that is not copied to the output as-is
+fn should_encode(byte: u8) -> bool {
+    // Characters that Java's URLEncoder leaves untouched:
+    // - Alphanumeric: A-Z, a-z, 0-9
+    // - Special: '.', '-', '*', '_'
+    // Note: '~' is unreserved in RFC 3986 but IS encoded by Java URLEncoder, while '*' is
+    // reserved in RFC 3986 but is NOT encoded by it.
+    !matches!(byte,
+        b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' |
+        b'.' | b'-' | b'*' | b'_'
+    )
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{DictionaryArray, LargeStringArray, StringArray};
+    use arrow::datatypes::Int8Type;
+    use datafusion::common::cast::as_string_array;
+
+    #[test]
+    fn test_url_encode_basic() {
+        assert_eq!(url_encode_string("Hello World"), "Hello+World");
+        assert_eq!(url_encode_string("foo=bar"), "foo%3Dbar");
+        assert_eq!(url_encode_string("a+b"), "a%2Bb");
+        assert_eq!(url_encode_string(""), "");
+    }
+
+    #[test]
+    fn test_url_encode_special_chars() {
+        assert_eq!(url_encode_string("?"), "%3F");
+        assert_eq!(url_encode_string("&"), "%26");
+        assert_eq!(url_encode_string("="), "%3D");
+        assert_eq!(url_encode_string("#"), "%23");
+        assert_eq!(url_encode_string("/"), "%2F");
+        assert_eq!(url_encode_string("%"), "%25");
+    }
+
+    #[test]
+    fn test_url_encode_unreserved_chars() {
+        // These should NOT be encoded
+        assert_eq!(url_encode_string("abc123"), "abc123");
+        assert_eq!(url_encode_string("ABC"), "ABC");
+        assert_eq!(url_encode_string("."), ".");
+        assert_eq!(url_encode_string("-"), "-");
+        assert_eq!(url_encode_string("*"), "*");
+        assert_eq!(url_encode_string("_"), "_");
+    }
+
+    #[test]
+    fn test_url_encode_unicode() {
+        // UTF-8 multi-byte characters should be percent-encoded
+        assert_eq!(url_encode_string("cafe\u{0301}"), "cafe%CC%81");
+        assert_eq!(url_encode_string("\u{00e9}"), "%C3%A9"); // é as single char
+    }
+
+    #[test]
+    fn test_url_encode_array() {
+        let input = StringArray::from(vec![Some("Hello World"), Some("foo=bar"), None, Some("")]);
+        let args = ColumnarValue::Array(Arc::new(input));
+        match spark_url_encode(&[args]) {
+            Ok(ColumnarValue::Array(result)) => {
+                let actual = as_string_array(&result).unwrap();
+                assert_eq!(actual.value(0), "Hello+World");
+                assert_eq!(actual.value(1), "foo%3Dbar");
+                assert!(actual.is_null(2));
+                assert_eq!(actual.value(3), "");
+            }
+            _ => unreachable!(),
+        }
+    }
+
+    #[test]
+    fn test_url_encode_scalar() {
+        let scalar = ScalarValue::Utf8(Some("Hello World".to_string()));
+        let result = url_encode_scalar(&scalar).unwrap();
+        assert_eq!(result, ScalarValue::Utf8(Some("Hello+World".to_string())));
+
+        let null_scalar = ScalarValue::Utf8(None);
+        let null_result = url_encode_scalar(&null_scalar).unwrap();
+        assert_eq!(null_result, ScalarValue::Utf8(None));
+    }
+
+    #[test]
+    fn test_url_encode_tilde() {
+        // ~ is unreserved in RFC 3986 but Java URLEncoder encodes it
+        assert_eq!(url_encode_string("~"), "%7E");
+    }
+
+    #[test]
+    fn test_url_encode_sliced_array() {
+        // Nulls must stay aligned with their rows when the input has a non-zero offset.
+        let input = StringArray::from(vec![None, Some("a b"), Some("c=d"), None]);
+        let sliced = input.slice(1, 3);
+        let result = url_encode_array(&sliced).unwrap();
+        let actual = as_string_array(&result).unwrap();
+        assert_eq!(actual.len(), 3);
+        assert_eq!(actual.value(0), "a+b");
+        assert_eq!(actual.value(1), "c%3Dd");
+        assert!(actual.is_null(2));
+    }
+
+    #[test]
+    fn test_url_encode_large_utf8() {
+        // LargeUtf8 input must produce the LargeUtf8 type that `return_type` promises.
+        let udf = SparkUrlEncode::new();
+        let planned = udf.return_type(&[DataType::LargeUtf8]).unwrap();
+        assert_eq!(planned, DataType::LargeUtf8);
+
+        let input = LargeStringArray::from(vec![Some("a b"), None]);
+        let result = url_encode_array(&input).unwrap();
+        assert_eq!(result.data_type(), &planned);
+        let actual = result.as_string::<i64>();
+        assert_eq!(actual.value(0), "a+b");
+        assert!(actual.is_null(1));
+
+        let scalar = ScalarValue::LargeUtf8(Some("a b".to_string()));
+        let expected = ScalarValue::LargeUtf8(Some("a+b".to_string()));
+        assert_eq!(url_encode_scalar(&scalar).unwrap(), expected);
+    }
+
+    #[test]
+    fn test_url_encode_dictionary_with_int8_keys() {
+        // Dictionary keys other than Int32 must be encoded instead of panicking.
+        let input: DictionaryArray<Int8Type> =
+            vec![Some("a b"), None, Some("a b")].into_iter().collect();
+        let planned = SparkUrlEncode::new()
+            .return_type(&[input.data_type().clone()])
+            .unwrap();
+        let result = url_encode_array(&input).unwrap();
+        assert_eq!(result.data_type(), &planned);
+
+        let dict = result.as_dictionary::<Int8Type>();
+        assert!(dict.is_null(1));
+        assert_eq!(dict.keys().value(0), dict.keys().value(2));
+        let values = dict.values().as_string::<i32>();
+        assert_eq!(values.len(), 1);
+        assert_eq!(values.value(0), "a+b");
+    }
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/statics.scala b/spark/src/main/scala/org/apache/comet/serde/statics.scala
index 0737644ab9..f698965322 100644
--- a/spark/src/main/scala/org/apache/comet/serde/statics.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/statics.scala
@@ -19,11 +19,12 @@
 
 package org.apache.comet.serde
 
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UrlCodec}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
 
 object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
 
@@ -34,7 +35,28 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
       : Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] =
     Map(
       ("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction(
-        "read_side_padding"))
+        "read_side_padding"),
+      ("encode", UrlCodec.getClass) -> CometUrlEncode)
+
+  object CometUrlEncode extends CometExpressionSerde[StaticInvoke] {
+    override def convert(
+        expr: StaticInvoke,
+        inputs: Seq[Attribute],
+        binding: Boolean): Option[ExprOuterClass.Expr] = {
+      // StaticInvoke for url_encode may include a second child (the UTF-8 Charset object),
+      // which is not needed by the Rust backend — it always assumes UTF-8.
+      // We only convert the first child (the string data).
+      expr.children match {
+        case Seq(dataToEncode, _*) =>
+          val childExpr = exprToProtoInternal(dataToEncode, inputs, binding)
+          val optExpr = scalarFunctionExprToProto("url_encode", childExpr)
+          optExprWithInfo(optExpr, expr, dataToEncode)
+        case _ =>
+          withInfo(expr, "url_encode expected at least 1 argument but found none")
+          None
+      }
+    }
+  }
 
   override def convert(
       expr: StaticInvoke,
diff --git a/spark/src/test/resources/sql-tests/expressions/string/url_encode.sql b/spark/src/test/resources/sql-tests/expressions/string/url_encode.sql
new file mode 100644
index 0000000000..dad356cae2
--- /dev/null
+++ b/spark/src/test/resources/sql-tests/expressions/string/url_encode.sql
@@ -0,0 +1,59 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements.  See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership.  The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License.  You may obtain a copy of the License at
+--
+--   http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing,
+-- software distributed under the License is distributed on an
+-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+-- KIND, either express or implied.  See the License for the
+-- specific language governing permissions and limitations
+-- under the License.
+
+-- Note: UrlEncode is a RuntimeReplaceable expression that delegates to UrlCodec.encode
+
+-- ConfigMatrix: parquet.enable.dictionary=false,true
+
+statement
+CREATE TABLE test_url_encode(url_string string) USING parquet
+
+statement
+INSERT INTO test_url_encode VALUES ('Hello World'), ('foo=bar&baz=qux'), ('https://example.com/path?query=value'), (''), (NULL), ('no encoding needed'), ('100%'), ('a+b=c'), ('special!@#$%^&*()'), ('space test')
+
+query
+SELECT url_string, url_encode(url_string) FROM test_url_encode
+
+query
+SELECT url_encode('Hello World'), url_encode('foo=bar'), url_encode(''), url_encode(NULL)
+
+query
+SELECT url_encode('?'), url_encode('&'), url_encode('='), url_encode('#'), url_encode('/')
+
+query
+SELECT url_encode('%'), url_encode('100%'), url_encode('already%20encoded')
+
+query
+SELECT url_encode(' '), url_encode('+'), url_encode('a b+c')
+
+statement
+CREATE TABLE test_url_encode_unicode(url_string string) USING parquet
+
+statement
+INSERT INTO test_url_encode_unicode VALUES ('café'), ('hello世界'), ('日本語'), ('emoji😀test'), ('తెలుగు'), (NULL)
+
+query
+SELECT url_string, url_encode(url_string) FROM test_url_encode_unicode
+
+query
+SELECT url_encode('Hello World!'), url_encode('email@example.com'), url_encode('price=$100')
+
+query
+SELECT url_encode(':'), url_encode('/'), url_encode('?'), url_encode('#'), url_encode('['), url_encode(']'), url_encode('@')
+
+query
+SELECT url_encode('!'), url_encode('$'), url_encode('&'), url_encode("'"), url_encode('('), url_encode(')'), url_encode('*'), url_encode('+'), url_encode(','), url_encode(';'), url_encode('=')
\ No newline at end of file