1+ // Licensed to the Apache Software Foundation (ASF) under one
2+ // or more contributor license agreements. See the NOTICE file
3+ // distributed with this work for additional information
4+ // regarding copyright ownership. The ASF licenses this file
5+ // to you under the Apache License, Version 2.0 (the
6+ // "License"); you may not use this file except in compliance
7+ // with the License. You may obtain a copy of the License at
8+ //
9+ // http://www.apache.org/licenses/LICENSE-2.0
10+ //
11+ // Unless required by applicable law or agreed to in writing,
12+ // software distributed under the License is distributed on an
13+ // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+ // KIND, either express or implied. See the License for the
15+ // specific language governing permissions and limitations
16+ // under the License.
17+
18+ //! Utilities for casting scalar literals to different data types
19+ //!
20+ //! This module contains functions for casting ScalarValue literals
21+ //! to different data types, originally extracted from the optimizer's
22+ //! unwrap_cast module to be shared between logical and physical layers.
23+
24+ use std:: cmp:: Ordering ;
25+
26+ use arrow:: datatypes:: {
27+ DataType , TimeUnit , MAX_DECIMAL128_FOR_EACH_PRECISION ,
28+ MIN_DECIMAL128_FOR_EACH_PRECISION ,
29+ } ;
30+ use arrow:: temporal_conversions:: { MICROSECONDS , MILLISECONDS , NANOSECONDS } ;
31+
32+ use crate :: ScalarValue ;
33+
34+ /// Convert a literal value from one data type to another
35+ pub fn try_cast_literal_to_type (
36+ lit_value : & ScalarValue ,
37+ target_type : & DataType ,
38+ ) -> Option < ScalarValue > {
39+ let lit_data_type = lit_value. data_type ( ) ;
40+ if !is_supported_type ( & lit_data_type) || !is_supported_type ( target_type) {
41+ return None ;
42+ }
43+ if lit_value. is_null ( ) {
44+ // null value can be cast to any type of null value
45+ return ScalarValue :: try_from ( target_type) . ok ( ) ;
46+ }
47+ try_cast_numeric_literal ( lit_value, target_type)
48+ . or_else ( || try_cast_string_literal ( lit_value, target_type) )
49+ . or_else ( || try_cast_dictionary ( lit_value, target_type) )
50+ . or_else ( || try_cast_binary ( lit_value, target_type) )
51+ }
52+
53+ /// Returns true if unwrap_cast_in_comparison supports this data type
54+ pub fn is_supported_type ( data_type : & DataType ) -> bool {
55+ is_supported_numeric_type ( data_type)
56+ || is_supported_string_type ( data_type)
57+ || is_supported_dictionary_type ( data_type)
58+ || is_supported_binary_type ( data_type)
59+ }
60+
61+ /// Returns true if unwrap_cast_in_comparison support this numeric type
62+ pub fn is_supported_numeric_type ( data_type : & DataType ) -> bool {
63+ matches ! (
64+ data_type,
65+ DataType :: UInt8
66+ | DataType :: UInt16
67+ | DataType :: UInt32
68+ | DataType :: UInt64
69+ | DataType :: Int8
70+ | DataType :: Int16
71+ | DataType :: Int32
72+ | DataType :: Int64
73+ | DataType :: Decimal128 ( _, _)
74+ | DataType :: Timestamp ( _, _)
75+ )
76+ }
77+
78+ /// Returns true if unwrap_cast_in_comparison supports casting this value as a string
79+ pub fn is_supported_string_type ( data_type : & DataType ) -> bool {
80+ matches ! (
81+ data_type,
82+ DataType :: Utf8 | DataType :: LargeUtf8 | DataType :: Utf8View
83+ )
84+ }
85+
86+ /// Returns true if unwrap_cast_in_comparison supports casting this value as a dictionary
87+ pub fn is_supported_dictionary_type ( data_type : & DataType ) -> bool {
88+ matches ! ( data_type,
89+ DataType :: Dictionary ( _, inner) if is_supported_type( inner) )
90+ }
91+
92+ pub fn is_supported_binary_type ( data_type : & DataType ) -> bool {
93+ matches ! ( data_type, DataType :: Binary | DataType :: FixedSizeBinary ( _) )
94+ }
95+
96+ /// Convert a numeric value from one numeric data type to another
97+ pub fn try_cast_numeric_literal (
98+ lit_value : & ScalarValue ,
99+ target_type : & DataType ,
100+ ) -> Option < ScalarValue > {
101+ let lit_data_type = lit_value. data_type ( ) ;
102+ if !is_supported_numeric_type ( & lit_data_type)
103+ || !is_supported_numeric_type ( target_type)
104+ {
105+ return None ;
106+ }
107+
108+ let mul = match target_type {
109+ DataType :: UInt8
110+ | DataType :: UInt16
111+ | DataType :: UInt32
112+ | DataType :: UInt64
113+ | DataType :: Int8
114+ | DataType :: Int16
115+ | DataType :: Int32
116+ | DataType :: Int64 => 1_i128 ,
117+ DataType :: Timestamp ( _, _) => 1_i128 ,
118+ DataType :: Decimal128 ( _, scale) => 10_i128 . pow ( * scale as u32 ) ,
119+ _ => return None ,
120+ } ;
121+ let ( target_min, target_max) = match target_type {
122+ DataType :: UInt8 => ( u8:: MIN as i128 , u8:: MAX as i128 ) ,
123+ DataType :: UInt16 => ( u16:: MIN as i128 , u16:: MAX as i128 ) ,
124+ DataType :: UInt32 => ( u32:: MIN as i128 , u32:: MAX as i128 ) ,
125+ DataType :: UInt64 => ( u64:: MIN as i128 , u64:: MAX as i128 ) ,
126+ DataType :: Int8 => ( i8:: MIN as i128 , i8:: MAX as i128 ) ,
127+ DataType :: Int16 => ( i16:: MIN as i128 , i16:: MAX as i128 ) ,
128+ DataType :: Int32 => ( i32:: MIN as i128 , i32:: MAX as i128 ) ,
129+ DataType :: Int64 => ( i64:: MIN as i128 , i64:: MAX as i128 ) ,
130+ DataType :: Timestamp ( _, _) => ( i64:: MIN as i128 , i64:: MAX as i128 ) ,
131+ DataType :: Decimal128 ( precision, _) => (
132+ // Different precision for decimal128 can store different range of value.
133+ // For example, the precision is 3, the max of value is `999` and the min
134+ // value is `-999`
135+ MIN_DECIMAL128_FOR_EACH_PRECISION [ * precision as usize ] ,
136+ MAX_DECIMAL128_FOR_EACH_PRECISION [ * precision as usize ] ,
137+ ) ,
138+ _ => return None ,
139+ } ;
140+ let lit_value_target_type = match lit_value {
141+ ScalarValue :: Int8 ( Some ( v) ) => ( * v as i128 ) . checked_mul ( mul) ,
142+ ScalarValue :: Int16 ( Some ( v) ) => ( * v as i128 ) . checked_mul ( mul) ,
143+ ScalarValue :: Int32 ( Some ( v) ) => ( * v as i128 ) . checked_mul ( mul) ,
144+ ScalarValue :: Int64 ( Some ( v) ) => ( * v as i128 ) . checked_mul ( mul) ,
145+ ScalarValue :: UInt8 ( Some ( v) ) => ( * v as i128 ) . checked_mul ( mul) ,
146+ ScalarValue :: UInt16 ( Some ( v) ) => ( * v as i128 ) . checked_mul ( mul) ,
147+ ScalarValue :: UInt32 ( Some ( v) ) => ( * v as i128 ) . checked_mul ( mul) ,
148+ ScalarValue :: UInt64 ( Some ( v) ) => ( * v as i128 ) . checked_mul ( mul) ,
149+ ScalarValue :: TimestampSecond ( Some ( v) , _) => ( * v as i128 ) . checked_mul ( mul) ,
150+ ScalarValue :: TimestampMillisecond ( Some ( v) , _) => ( * v as i128 ) . checked_mul ( mul) ,
151+ ScalarValue :: TimestampMicrosecond ( Some ( v) , _) => ( * v as i128 ) . checked_mul ( mul) ,
152+ ScalarValue :: TimestampNanosecond ( Some ( v) , _) => ( * v as i128 ) . checked_mul ( mul) ,
153+ ScalarValue :: Decimal128 ( Some ( v) , _, scale) => {
154+ let lit_scale_mul = 10_i128 . pow ( * scale as u32 ) ;
155+ if mul >= lit_scale_mul {
156+ // Example:
157+ // lit is decimal(123,3,2)
158+ // target type is decimal(5,3)
159+ // the lit can be converted to the decimal(1230,5,3)
160+ ( * v) . checked_mul ( mul / lit_scale_mul)
161+ } else if ( * v) % ( lit_scale_mul / mul) == 0 {
162+ // Example:
163+ // lit is decimal(123000,10,3)
164+ // target type is int32: the lit can be converted to INT32(123)
165+ // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
166+ Some ( * v / ( lit_scale_mul / mul) )
167+ } else {
168+ // can't convert the lit decimal to the target data type
169+ None
170+ }
171+ }
172+ _ => None ,
173+ } ;
174+
175+ match lit_value_target_type {
176+ None => None ,
177+ Some ( value) => {
178+ if value >= target_min && value <= target_max {
179+ // the value casted from lit to the target type is in the range of target type.
180+ // return the target type of scalar value
181+ let result_scalar = match target_type {
182+ DataType :: Int8 => ScalarValue :: Int8 ( Some ( value as i8 ) ) ,
183+ DataType :: Int16 => ScalarValue :: Int16 ( Some ( value as i16 ) ) ,
184+ DataType :: Int32 => ScalarValue :: Int32 ( Some ( value as i32 ) ) ,
185+ DataType :: Int64 => ScalarValue :: Int64 ( Some ( value as i64 ) ) ,
186+ DataType :: UInt8 => ScalarValue :: UInt8 ( Some ( value as u8 ) ) ,
187+ DataType :: UInt16 => ScalarValue :: UInt16 ( Some ( value as u16 ) ) ,
188+ DataType :: UInt32 => ScalarValue :: UInt32 ( Some ( value as u32 ) ) ,
189+ DataType :: UInt64 => ScalarValue :: UInt64 ( Some ( value as u64 ) ) ,
190+ DataType :: Timestamp ( TimeUnit :: Second , tz) => {
191+ let value = cast_between_timestamp (
192+ & lit_data_type,
193+ & DataType :: Timestamp ( TimeUnit :: Second , tz. clone ( ) ) ,
194+ value,
195+ ) ;
196+ ScalarValue :: TimestampSecond ( value, tz. clone ( ) )
197+ }
198+ DataType :: Timestamp ( TimeUnit :: Millisecond , tz) => {
199+ let value = cast_between_timestamp (
200+ & lit_data_type,
201+ & DataType :: Timestamp ( TimeUnit :: Millisecond , tz. clone ( ) ) ,
202+ value,
203+ ) ;
204+ ScalarValue :: TimestampMillisecond ( value, tz. clone ( ) )
205+ }
206+ DataType :: Timestamp ( TimeUnit :: Microsecond , tz) => {
207+ let value = cast_between_timestamp (
208+ & lit_data_type,
209+ & DataType :: Timestamp ( TimeUnit :: Microsecond , tz. clone ( ) ) ,
210+ value,
211+ ) ;
212+ ScalarValue :: TimestampMicrosecond ( value, tz. clone ( ) )
213+ }
214+ DataType :: Timestamp ( TimeUnit :: Nanosecond , tz) => {
215+ let value = cast_between_timestamp (
216+ & lit_data_type,
217+ & DataType :: Timestamp ( TimeUnit :: Nanosecond , tz. clone ( ) ) ,
218+ value,
219+ ) ;
220+ ScalarValue :: TimestampNanosecond ( value, tz. clone ( ) )
221+ }
222+ DataType :: Decimal128 ( p, s) => {
223+ ScalarValue :: Decimal128 ( Some ( value) , * p, * s)
224+ }
225+ _ => {
226+ return None ;
227+ }
228+ } ;
229+ Some ( result_scalar)
230+ } else {
231+ None
232+ }
233+ }
234+ }
235+ }
236+
237+ pub fn try_cast_string_literal (
238+ lit_value : & ScalarValue ,
239+ target_type : & DataType ,
240+ ) -> Option < ScalarValue > {
241+ let string_value = lit_value. try_as_str ( ) ?. map ( |s| s. to_string ( ) ) ;
242+ let scalar_value = match target_type {
243+ DataType :: Utf8 => ScalarValue :: Utf8 ( string_value) ,
244+ DataType :: LargeUtf8 => ScalarValue :: LargeUtf8 ( string_value) ,
245+ DataType :: Utf8View => ScalarValue :: Utf8View ( string_value) ,
246+ _ => return None ,
247+ } ;
248+ Some ( scalar_value)
249+ }
250+
251+ /// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary
252+ pub fn try_cast_dictionary (
253+ lit_value : & ScalarValue ,
254+ target_type : & DataType ,
255+ ) -> Option < ScalarValue > {
256+ let lit_value_type = lit_value. data_type ( ) ;
257+ let result_scalar = match ( lit_value, target_type) {
258+ // Unwrap dictionary when inner type matches target type
259+ ( ScalarValue :: Dictionary ( _, inner_value) , _)
260+ if inner_value. data_type ( ) == * target_type =>
261+ {
262+ ( * * inner_value) . clone ( )
263+ }
264+ // Wrap type when target type is dictionary
265+ ( _, DataType :: Dictionary ( index_type, inner_type) )
266+ if * * inner_type == lit_value_type =>
267+ {
268+ ScalarValue :: Dictionary ( index_type. clone ( ) , Box :: new ( lit_value. clone ( ) ) )
269+ }
270+ _ => {
271+ return None ;
272+ }
273+ } ;
274+ Some ( result_scalar)
275+ }
276+
277+ /// Cast a timestamp value from one unit to another
278+ pub fn cast_between_timestamp ( from : & DataType , to : & DataType , value : i128 ) -> Option < i64 > {
279+ let value = value as i64 ;
280+ let from_scale = match from {
281+ DataType :: Timestamp ( TimeUnit :: Second , _) => 1 ,
282+ DataType :: Timestamp ( TimeUnit :: Millisecond , _) => MILLISECONDS ,
283+ DataType :: Timestamp ( TimeUnit :: Microsecond , _) => MICROSECONDS ,
284+ DataType :: Timestamp ( TimeUnit :: Nanosecond , _) => NANOSECONDS ,
285+ _ => return Some ( value) ,
286+ } ;
287+
288+ let to_scale = match to {
289+ DataType :: Timestamp ( TimeUnit :: Second , _) => 1 ,
290+ DataType :: Timestamp ( TimeUnit :: Millisecond , _) => MILLISECONDS ,
291+ DataType :: Timestamp ( TimeUnit :: Microsecond , _) => MICROSECONDS ,
292+ DataType :: Timestamp ( TimeUnit :: Nanosecond , _) => NANOSECONDS ,
293+ _ => return Some ( value) ,
294+ } ;
295+
296+ match from_scale. cmp ( & to_scale) {
297+ Ordering :: Less => value. checked_mul ( to_scale / from_scale) ,
298+ Ordering :: Greater => Some ( value / ( from_scale / to_scale) ) ,
299+ Ordering :: Equal => Some ( value) ,
300+ }
301+ }
302+
303+ pub fn try_cast_binary (
304+ lit_value : & ScalarValue ,
305+ target_type : & DataType ,
306+ ) -> Option < ScalarValue > {
307+ match ( lit_value, target_type) {
308+ ( ScalarValue :: Binary ( Some ( v) ) , DataType :: FixedSizeBinary ( n) )
309+ if v. len ( ) == * n as usize =>
310+ {
311+ Some ( ScalarValue :: FixedSizeBinary ( * n, Some ( v. clone ( ) ) ) )
312+ }
313+ _ => None ,
314+ }
315+ }
0 commit comments