Skip to content

Commit

Permalink
cleanup + fix typo
Browse files Browse the repository at this point in the history
Signed-off-by: Devan <devandbenz@gmail.com>
  • Loading branch information
devanbenz committed Aug 14, 2024
1 parent 57f3161 commit a0c78de
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions datafusion/functions/src/unicode/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, StringArray};
use arrow::array::{
ArrayAccessor, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait,
};
use arrow::datatypes::DataType;
use hashbrown::HashMap;
use unicode_segmentation::UnicodeSegmentation;

use crate::utils::{make_scalar_function, utf8_to_str_type};
use datafusion_common::cast::as_generic_string_array;
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
Expand Down Expand Up @@ -81,20 +84,20 @@ fn invoke_translate(args: &[ArrayRef]) -> Result<ArrayRef> {
DataType::Utf8View => {
let string_array = args[0].as_string_view();
let from_array = args[1].as_string::<i32>();
let to_array = args[1].as_string::<i32>();
translate::<_, _>(string_array, from_array, to_array)
let to_array = args[2].as_string::<i32>();
translate::<i32, _, _>(string_array, from_array, to_array)
}
DataType::Utf8 => {
let string_array = args[0].as_string::<i32>();
let from_array = args[1].as_string::<i32>();
let to_array = args[1].as_string::<i32>();
translate::<_, _>(string_array, from_array, to_array)
let to_array = args[2].as_string::<i32>();
translate::<i32, _, _>(string_array, from_array, to_array)
}
DataType::LargeUtf8 => {
let string_array = args[0].as_string::<i64>();
let from_array = args[1].as_string::<i64>();
let to_array = args[1].as_string::<i64>();
translate::<_, _>(string_array, from_array, to_array)
let to_array = args[2].as_string::<i64>();
translate::<i64, _, _>(string_array, from_array, to_array)
}
other => {
exec_err!("Unsupported data type {other:?} for function translate")
Expand All @@ -104,18 +107,23 @@ fn invoke_translate(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted.
/// translate('12345', '143', 'ax') = 'a2x5'
fn translate<'a, V, B>(string_array: V, from_array: B, to_array: B) -> Result<ArrayRef>
fn translate<'a, T: OffsetSizeTrait, V, B>(
string_array: V,
from_array: B,
to_array: B,
) -> Result<ArrayRef>
where
V: ArrayAccessor<Item = &'a str>,
B: ArrayAccessor<Item = &'a str>,
{
let string_array_iter = ArrayIter::new(string_array);
let from_array_iter = ArrayIter::new(from_array);
let to_array_iter = ArrayIter::new(to_array);
let string_array_iter = as_generic_string_array::<T>(&string_array)?;
let from_array_iter = as_generic_string_array::<T>(&from_array)?;
let to_array_iter = as_generic_string_array::<T>(&to_array)?;

let result = string_array_iter
.zip(from_array_iter)
.zip(to_array_iter)
.iter()
.zip(from_array_iter.iter())
.zip(to_array_iter.iter())
.map(|((string, from), to)| match (string, from, to) {
(Some(string), Some(from), Some(to)) => {
// create a hashmap of [char, index] to change from O(n) to O(1) for from list
Expand Down Expand Up @@ -144,7 +152,7 @@ where
}
_ => None,
})
.collect::<StringArray>();
.collect::<GenericStringArray<T>>();

Ok(Arc::new(result) as ArrayRef)
}
Expand Down

0 comments on commit a0c78de

Please sign in to comment.