From b50c33620bb876f3213914fd41612de9945861cf Mon Sep 17 00:00:00 2001 From: Zhenchi Date: Wed, 13 Nov 2024 10:29:44 +0000 Subject: [PATCH 1/4] feat: add distance functions Signed-off-by: Zhenchi --- Cargo.lock | 24 +- src/common/function/Cargo.toml | 1 + src/common/function/src/function_registry.rs | 4 + src/common/function/src/scalars.rs | 1 + src/common/function/src/scalars/vector.rs | 31 ++ .../function/src/scalars/vector/distance.rs | 469 ++++++++++++++++++ src/datatypes/src/value.rs | 6 +- .../common/types/vector/vector.result | 174 +++++++ .../standalone/common/types/vector/vector.sql | 44 ++ 9 files changed, 744 insertions(+), 10 deletions(-) create mode 100644 src/common/function/src/scalars/vector.rs create mode 100644 src/common/function/src/scalars/vector/distance.rs diff --git a/Cargo.lock b/Cargo.lock index a544da4087af..ac66148f625b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1041,7 +1041,7 @@ dependencies = [ "bitflags 2.6.0", "cexpr", "clang-sys", - "itertools 0.10.5", + "itertools 0.12.1", "lazy_static", "lazycell", "log", @@ -2087,6 +2087,7 @@ dependencies = [ "serde", "serde_json", "session", + "simsimd", "snafu 0.8.5", "sql", "statrs", @@ -5079,7 +5080,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.4.10", + "socket2 0.5.7", "tokio", "tower-service", "tracing", @@ -6069,7 +6070,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -8810,7 +8811,7 @@ checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ "bytes", "heck 0.5.0", - "itertools 0.10.5", + "itertools 0.12.1", "log", "multimap", "once_cell", @@ -8862,7 +8863,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" dependencies = [ "anyhow", - "itertools 0.10.5", + "itertools 0.12.1", "proc-macro2", "quote", "syn 2.0.79", @@ -9024,7 +9025,7 @@ dependencies = [ "indoc", "libc", "memoffset 0.9.1", - "parking_lot 0.11.2", + "parking_lot 0.12.3", "portable-atomic", "pyo3-build-config", "pyo3-ffi", @@ -11186,6 +11187,15 @@ dependencies = [ "time", ] +[[package]] +name = "simsimd" +version = "4.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efc843bc8f12d9c8e6b734a0fe8918fc497b42f6ae0f347dbfdad5b5138ab9b4" +dependencies = [ + "cc", +] + [[package]] name = "siphasher" version = "0.3.11" @@ -13969,7 +13979,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] diff --git a/src/common/function/Cargo.toml b/src/common/function/Cargo.toml index 6c1ecc2d381e..cb876b352dd9 100644 --- a/src/common/function/Cargo.toml +++ b/src/common/function/Cargo.toml @@ -41,6 +41,7 @@ s2 = { version = "0.0.12", optional = true } serde.workspace = true serde_json.workspace = true session.workspace = true +simsimd = "4" snafu.workspace = true sql.workspace = true statrs = "0.16" diff --git a/src/common/function/src/function_registry.rs b/src/common/function/src/function_registry.rs index 46af3b761072..04d68a93d85e 100644 --- a/src/common/function/src/function_registry.rs +++ b/src/common/function/src/function_registry.rs @@ -27,6 +27,7 @@ use crate::scalars::matches::MatchesFunction; use crate::scalars::math::MathFunction; use crate::scalars::numpy::NumpyFunction; use crate::scalars::timestamp::TimestampFunction; +use crate::scalars::vector::VectorFunction; use crate::system::SystemFunction; use crate::table::TableFunction; @@ -120,6 +121,9 @@ pub static FUNCTION_REGISTRY: Lazy> = Lazy::new(|| { // Json related functions JsonFunction::register(&function_registry); + // Vector related functions + VectorFunction::register(&function_registry); + // Geo functions #[cfg(feature = "geo")] crate::scalars::geo::GeoFunctions::register(&function_registry); diff --git a/src/common/function/src/scalars.rs b/src/common/function/src/scalars.rs index f60cf2b0d98b..52a238273d99 100644 --- a/src/common/function/src/scalars.rs +++ b/src/common/function/src/scalars.rs @@ -21,6 +21,7 @@ pub mod json; pub mod matches; pub mod math; pub mod numpy; +pub mod vector; #[cfg(test)] pub(crate) mod test; diff --git a/src/common/function/src/scalars/vector.rs b/src/common/function/src/scalars/vector.rs new file mode 100644 index 000000000000..67b812fd09f0 --- /dev/null +++ b/src/common/function/src/scalars/vector.rs @@ -0,0 +1,31 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod distance; + +use std::sync::Arc; + +use distance::{CosDistanceFunction, DotProductFunction, L2SqDistanceFunction}; + +use crate::function_registry::FunctionRegistry; + +pub(crate) struct VectorFunction; + +impl VectorFunction { + pub fn register(registry: &FunctionRegistry) { + registry.register(Arc::new(CosDistanceFunction)); + registry.register(Arc::new(DotProductFunction)); + registry.register(Arc::new(L2SqDistanceFunction)); + } +} diff --git a/src/common/function/src/scalars/vector/distance.rs b/src/common/function/src/scalars/vector/distance.rs new file mode 100644 index 000000000000..8bf39d62c684 --- /dev/null +++ b/src/common/function/src/scalars/vector/distance.rs @@ -0,0 +1,469 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::borrow::Cow; +use std::fmt::Display; +use std::sync::Arc; + +use common_query::error::{InvalidFuncArgsSnafu, Result}; +use common_query::prelude::Signature; +use datatypes::prelude::ConcreteDataType; +use datatypes::scalars::ScalarVectorBuilder; +use datatypes::value::ValueRef; +use datatypes::vectors::{Float32VectorBuilder, MutableVector, Vector, VectorRef}; +use snafu::ensure; + +use crate::function::{Function, FunctionContext}; +use crate::helper; + +macro_rules! define_distance_function { + ($StructName:ident, $display_name:expr, $similarity_method:ident) => { + + /// A function calculates the distance between two vectors. + + #[derive(Debug, Clone, Default)] + pub struct $StructName; + + impl Function for $StructName { + fn name(&self) -> &str { + $display_name + } + + fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result { + Ok(ConcreteDataType::float32_datatype()) + } + + fn signature(&self) -> Signature { + helper::one_of_sigs2( + vec![ + ConcreteDataType::string_datatype(), + ConcreteDataType::binary_datatype(), + ], + vec![ + ConcreteDataType::string_datatype(), + ConcreteDataType::binary_datatype(), + ], + ) + } + + fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + ensure!( + columns.len() == 2, + InvalidFuncArgsSnafu { + err_msg: format!( + "The length of the args is not correct, expect exactly two, have: {}", + columns.len() + ), + } + ); + let arg0 = &columns[0]; + let arg1 = &columns[1]; + + let size = arg0.len(); + let mut result = Float32VectorBuilder::with_capacity(size); + if size == 0 { + return Ok(result.to_vector()); + } + + let arg0_const = parse_if_constant_string(arg0)?; + let arg1_const = parse_if_constant_string(arg1)?; + + for i in 0..size { + let vec0 = match arg0_const.as_ref() { + Some(a) => Some(Cow::Borrowed(a.as_slice())), + None => as_vector(arg0.get_ref(i))?, + }; + let vec1 = match arg1_const.as_ref() { + Some(b) => Some(Cow::Borrowed(b.as_slice())), + None => as_vector(arg1.get_ref(i))?, + }; + + if let (Some(vec0), Some(vec1)) = (vec0, vec1) { + ensure!( + vec0.len() == vec1.len(), + InvalidFuncArgsSnafu { + err_msg: format!( + "The length of the vectors must match to calculate distance, have: {} vs {}", + vec0.len(), + vec1.len() + ), + } + ); + + let f = ::$similarity_method; + // Safe: checked if the length of the vectors match + let d = f(vec0.as_ref(), vec1.as_ref()).unwrap(); + result.push(Some(d as f32)); + } else { + result.push_null(); + } + } + + return Ok(result.to_vector()); + } + } + + impl Display for $StructName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", $display_name.to_ascii_uppercase()) + } + } + } +} + +define_distance_function!(CosDistanceFunction, "cos_distance", cos); +define_distance_function!(L2SqDistanceFunction, "l2sq_distance", l2sq); +define_distance_function!(DotProductFunction, "dot_product", dot); + +/// Parse a vector value if the value is a constant string. +fn parse_if_constant_string(arg: &Arc) -> Result>> { + if !arg.is_const() { + return Ok(None); + } + if arg.data_type() != ConcreteDataType::string_datatype() { + return Ok(None); + } + arg.get_ref(0) + .as_string() + .unwrap() // Safe: checked if it is a string + .map(parse_vector_from_string) + .transpose() +} + +/// Convert a value to a vector value. +/// Supported data types are binary and string. +fn as_vector(arg: ValueRef<'_>) -> Result>> { + match arg.data_type() { + ConcreteDataType::Binary(_) => arg + .as_binary() + .unwrap() // Safe: checked if it is a binary + .map(|bytes| Ok(Cow::Borrowed(binary_as_vector(bytes)?))) + .transpose(), + ConcreteDataType::String(_) => arg + .as_string() + .unwrap() // Safe: checked if it is a string + .map(|s| Ok(Cow::Owned(parse_vector_from_string(s)?))) + .transpose(), + ConcreteDataType::Null(_) => Ok(None), + _ => InvalidFuncArgsSnafu { + err_msg: format!("Unsupported data type: {:?}", arg.data_type()), + } + .fail(), + } +} + +/// Convert a u8 slice to a vector value. +fn binary_as_vector(bytes: &[u8]) -> Result<&[f32]> { + if bytes.len() % 4 != 0 { + return InvalidFuncArgsSnafu { + err_msg: format!("Invalid binary length of vector: {}", bytes.len()), + } + .fail(); + } + + unsafe { + let num_floats = bytes.len() / 4; + let floats: &[f32] = std::slice::from_raw_parts(bytes.as_ptr() as *const f32, num_floats); + Ok(floats) + } +} + +/// Parse a string to a vector value. +/// Valid inputs are strings like "[1.0, 2.0, 3.0]". +fn parse_vector_from_string(s: &str) -> Result> { + let trimmed = s.trim(); + if !trimmed.starts_with('[') || !trimmed.ends_with(']') { + return InvalidFuncArgsSnafu { + err_msg: format!( + "Failed to parse {s} to Vector value: not properly enclosed in brackets" + ), + } + .fail(); + } + let content = trimmed[1..trimmed.len() - 1].trim(); + if content.is_empty() { + return Ok(Vec::new()); + } + + content + .split(',') + .map(|s| s.trim().parse::()) + .collect::>() + .map_err(|e| { + InvalidFuncArgsSnafu { + err_msg: format!("Failed to parse {s} to Vector value: {e}"), + } + .build() + }) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use datatypes::vectors::{BinaryVector, ConstantVector, StringVector}; + + use super::*; + + #[test] + fn test_distance_string_string() { + let funcs = [ + Box::new(CosDistanceFunction {}) as Box, + Box::new(L2SqDistanceFunction {}) as Box, + Box::new(DotProductFunction {}) as Box, + ]; + + for func in funcs { + let vec1 = Arc::new(StringVector::from(vec![ + Some("[0.0, 1.0]"), + Some("[1.0, 0.0]"), + None, + Some("[1.0, 0.0]"), + ])) as VectorRef; + let vec2 = Arc::new(StringVector::from(vec![ + Some("[0.0, 1.0]"), + Some("[0.0, 1.0]"), + Some("[0.0, 1.0]"), + None, + ])) as VectorRef; + + let result = func + .eval(FunctionContext::default(), &[vec1.clone(), vec2.clone()]) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(result.get(2).is_null()); + assert!(result.get(3).is_null()); + + let result = func + .eval(FunctionContext::default(), &[vec2, vec1]) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(result.get(2).is_null()); + assert!(result.get(3).is_null()); + } + } + + #[test] + fn test_distance_binary_binary() { + let funcs = [ + Box::new(CosDistanceFunction {}) as Box, + Box::new(L2SqDistanceFunction {}) as Box, + Box::new(DotProductFunction {}) as Box, + ]; + + for func in funcs { + let vec1 = Arc::new(BinaryVector::from(vec![ + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + Some(vec![0, 0, 128, 63, 0, 0, 0, 0]), + None, + Some(vec![0, 0, 128, 63, 0, 0, 0, 0]), + ])) as VectorRef; + let vec2 = Arc::new(BinaryVector::from(vec![ + // [0.0, 1.0] + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + None, + ])) as VectorRef; + + let result = func + .eval(FunctionContext::default(), &[vec1.clone(), vec2.clone()]) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(result.get(2).is_null()); + assert!(result.get(3).is_null()); + + let result = func + .eval(FunctionContext::default(), &[vec2, vec1]) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(result.get(2).is_null()); + assert!(result.get(3).is_null()); + } + } + + #[test] + fn test_distance_string_binary() { + let funcs = [ + Box::new(CosDistanceFunction {}) as Box, + Box::new(L2SqDistanceFunction {}) as Box, + Box::new(DotProductFunction {}) as Box, + ]; + + for func in funcs { + let vec1 = Arc::new(StringVector::from(vec![ + Some("[0.0, 1.0]"), + Some("[1.0, 0.0]"), + None, + Some("[1.0, 0.0]"), + ])) as VectorRef; + let vec2 = Arc::new(BinaryVector::from(vec![ + // [0.0, 1.0] + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + None, + ])) as VectorRef; + + let result = func + .eval(FunctionContext::default(), &[vec1.clone(), vec2.clone()]) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(result.get(2).is_null()); + assert!(result.get(3).is_null()); + + let result = func + .eval(FunctionContext::default(), &[vec2, vec1]) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(result.get(2).is_null()); + assert!(result.get(3).is_null()); + } + } + + #[test] + fn test_distance_const_string() { + let funcs = [ + Box::new(CosDistanceFunction {}) as Box, + Box::new(L2SqDistanceFunction {}) as Box, + Box::new(DotProductFunction {}) as Box, + ]; + + for func in funcs { + let const_str = Arc::new(ConstantVector::new( + Arc::new(StringVector::from(vec!["[0.0, 1.0]"])), + 4, + )); + + let vec1 = Arc::new(StringVector::from(vec![ + Some("[0.0, 1.0]"), + Some("[1.0, 0.0]"), + None, + Some("[1.0, 0.0]"), + ])) as VectorRef; + let vec2 = Arc::new(BinaryVector::from(vec![ + // [0.0, 1.0] + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + None, + ])) as VectorRef; + + let result = func + .eval( + FunctionContext::default(), + &[const_str.clone(), vec1.clone()], + ) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(result.get(2).is_null()); + assert!(!result.get(3).is_null()); + + let result = func + .eval( + FunctionContext::default(), + &[vec1.clone(), const_str.clone()], + ) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(result.get(2).is_null()); + assert!(!result.get(3).is_null()); + + let result = func + .eval( + FunctionContext::default(), + &[const_str.clone(), vec2.clone()], + ) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(!result.get(2).is_null()); + assert!(result.get(3).is_null()); + + let result = func + .eval( + FunctionContext::default(), + &[vec2.clone(), const_str.clone()], + ) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(!result.get(2).is_null()); + assert!(result.get(3).is_null()); + } + } + + #[test] + fn test_invalid_vector_length() { + let funcs = [ + Box::new(CosDistanceFunction {}) as Box, + Box::new(L2SqDistanceFunction {}) as Box, + Box::new(DotProductFunction {}) as Box, + ]; + + for func in funcs { + let vec1 = Arc::new(StringVector::from(vec!["[1.0]"])) as VectorRef; + let vec2 = Arc::new(StringVector::from(vec!["[1.0, 1.0]"])) as VectorRef; + let result = func.eval(FunctionContext::default(), &[vec1, vec2]); + assert!(result.is_err()); + + let vec1 = Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63]])) as VectorRef; + let vec2 = + Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63, 0, 0, 0, 64]])) as VectorRef; + let result = func.eval(FunctionContext::default(), &[vec1, vec2]); + assert!(result.is_err()); + } + } + + #[test] + fn test_parse_vector_from_string() { + let result = parse_vector_from_string("[1.0, 2.0, 3.0]").unwrap(); + assert_eq!(result, vec![1.0, 2.0, 3.0]); + + let result = parse_vector_from_string("[]").unwrap(); + assert_eq!(result, Vec::::new()); + + let result = parse_vector_from_string("[1.0, a, 3.0]"); + assert!(result.is_err()); + } + + #[test] + fn test_binary_as_vector() { + let bytes = [0, 0, 128, 63]; + let result = binary_as_vector(&bytes).unwrap(); + assert_eq!(result, &[1.0]); + + let invalid_bytes = [0, 0, 128]; + let result = binary_as_vector(&invalid_bytes); + assert!(result.is_err()); + } +} diff --git a/src/datatypes/src/value.rs b/src/datatypes/src/value.rs index d0f36139ca4a..b57b364cf3e5 100644 --- a/src/datatypes/src/value.rs +++ b/src/datatypes/src/value.rs @@ -1089,7 +1089,7 @@ macro_rules! impl_as_for_value_ref { }; } -impl ValueRef<'_> { +impl<'a> ValueRef<'a> { define_data_type_func!(ValueRef); /// Returns true if this is null. @@ -1098,12 +1098,12 @@ impl ValueRef<'_> { } /// Cast itself to binary slice. - pub fn as_binary(&self) -> Result> { + pub fn as_binary(&self) -> Result> { impl_as_for_value_ref!(self, Binary) } /// Cast itself to string slice. - pub fn as_string(&self) -> Result> { + pub fn as_string(&self) -> Result> { impl_as_for_value_ref!(self, String) } diff --git a/tests/cases/standalone/common/types/vector/vector.result b/tests/cases/standalone/common/types/vector/vector.result index d9b5a2e61e70..5685d5abde81 100644 --- a/tests/cases/standalone/common/types/vector/vector.result +++ b/tests/cases/standalone/common/types/vector/vector.result @@ -31,6 +31,180 @@ SELECT * FROM t; | 1970-01-01 00:00:00.003000 | "[7,8,9]" | +----------------------------+-----------+ +SELECT cos_distance(v, '[0.0, 0.0, 0.0]') FROM t; + ++-------------------------------------------+ +| cos_distance(t.v,Utf8("[0.0, 0.0, 0.0]")) | ++-------------------------------------------+ +| 1.0 | +| 1.0 | +| 1.0 | ++-------------------------------------------+ + +SELECT *, cos_distance(v, '[0.0, 0.0, 0.0]') as d FROM t ORDER BY d; + ++-------------------------+--------------------------+-----+ +| ts | v | d | ++-------------------------+--------------------------+-----+ +| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 1.0 | +| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 1.0 | +| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 1.0 | ++-------------------------+--------------------------+-----+ + +SELECT cos_distance('[1.0, 2.0, 3.0]', v) FROM t; + ++-------------------------------------------+ +| cos_distance(Utf8("[1.0, 2.0, 3.0]"),t.v) | ++-------------------------------------------+ +| 5.9604645e-8 | +| 0.025368154 | +| 0.04058808 | ++-------------------------------------------+ + +SELECT *, cos_distance('[1.0, 2.0, 3.0]', v) as d FROM t ORDER BY d; + ++-------------------------+--------------------------+--------------+ +| ts | v | d | ++-------------------------+--------------------------+--------------+ +| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 5.9604645e-8 | +| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 0.025368154 | +| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.04058808 | ++-------------------------+--------------------------+--------------+ + +SELECT cos_distance(v, v) FROM t; + ++-----------------------+ +| cos_distance(t.v,t.v) | ++-----------------------+ +| 5.9604645e-8 | +| 0.0 | +| 5.9604645e-8 | ++-----------------------+ + +SELECT cos_distance(v, '[1.0]') FROM t; + +Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1 + +SELECT cos_distance(v, 1.0) FROM t; + +Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2 + +SELECT l2sq_distance(v, '[0.0, 0.0, 0.0]') FROM t; + ++--------------------------------------------+ +| l2sq_distance(t.v,Utf8("[0.0, 0.0, 0.0]")) | ++--------------------------------------------+ +| 14.0 | +| 77.0 | +| 194.0 | ++--------------------------------------------+ + +SELECT *, l2sq_distance(v, '[0.0, 0.0, 0.0]') as d FROM t ORDER BY d; + ++-------------------------+--------------------------+-------+ +| ts | v | d | ++-------------------------+--------------------------+-------+ +| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 14.0 | +| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 77.0 | +| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 194.0 | ++-------------------------+--------------------------+-------+ + +SELECT l2sq_distance('[1.0, 2.0, 3.0]', v) FROM t; + ++--------------------------------------------+ +| l2sq_distance(Utf8("[1.0, 2.0, 3.0]"),t.v) | ++--------------------------------------------+ +| 0.0 | +| 27.0 | +| 108.0 | ++--------------------------------------------+ + +SELECT *, l2sq_distance('[1.0, 2.0, 3.0]', v) as d FROM t ORDER BY d; + ++-------------------------+--------------------------+-------+ +| ts | v | d | ++-------------------------+--------------------------+-------+ +| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 0.0 | +| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 27.0 | +| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 108.0 | ++-------------------------+--------------------------+-------+ + +SELECT l2sq_distance(v, v) FROM t; + ++------------------------+ +| l2sq_distance(t.v,t.v) | ++------------------------+ +| 0.0 | +| 0.0 | +| 0.0 | ++------------------------+ + +SELECT l2sq_distance(v, '[1.0]') FROM t; + +Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1 + +SELECT l2sq_distance(v, 1.0) FROM t; + +Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2 + +SELECT dot_product(v, '[0.0, 0.0, 0.0]') FROM t; + ++-------------------------------------------+ +| dot_product(t.v,Utf8("[0.0, 0.0, 0.0]")) | ++-------------------------------------------+ +| 0.0 | +| 0.0 | +| 0.0 | ++-------------------------------------------+ + +SELECT *, dot_product(v, '[0.0, 0.0, 0.0]') as d FROM t ORDER BY d; + ++-------------------------+--------------------------+-----+ +| ts | v | d | ++-------------------------+--------------------------+-----+ +| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 0.0 | +| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 0.0 | +| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.0 | ++-------------------------+--------------------------+-----+ + +SELECT dot_product('[1.0, 2.0, 3.0]', v) FROM t; + ++-------------------------------------------+ +| dot_product(Utf8("[1.0, 2.0, 3.0]"),t.v) | ++-------------------------------------------+ +| 14.0 | +| 32.0 | +| 50.0 | ++-------------------------------------------+ + +SELECT *, dot_product('[1.0, 2.0, 3.0]', v) as d FROM t ORDER BY d; + ++-------------------------+--------------------------+------+ +| ts | v | d | ++-------------------------+--------------------------+------+ +| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 14.0 | +| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 32.0 | +| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 50.0 | ++-------------------------+--------------------------+------+ + +SELECT dot_product(v, v) FROM t; + ++-----------------------+ +| dot_product(t.v,t.v) | ++-----------------------+ +| 14.0 | +| 77.0 | +| 194.0 | ++-----------------------+ + +SELECT dot_product(v, '[1.0]') FROM t; + +Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1 + +SELECT dot_product(v, 1.0) FROM t; + +Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2 + -- Unexpected dimension -- INSERT INTO t VALUES (4, "[1.0]"); diff --git a/tests/cases/standalone/common/types/vector/vector.sql b/tests/cases/standalone/common/types/vector/vector.sql index 376f356aaa66..d6fdd2cf19d6 100644 --- a/tests/cases/standalone/common/types/vector/vector.sql +++ b/tests/cases/standalone/common/types/vector/vector.sql @@ -11,6 +11,50 @@ SELECT * FROM t; -- SQLNESS PROTOCOL POSTGRES SELECT * FROM t; +SELECT cos_distance(v, '[0.0, 0.0, 0.0]') FROM t; + +SELECT *, cos_distance(v, '[0.0, 0.0, 0.0]') as d FROM t ORDER BY d; + +SELECT cos_distance('[1.0, 2.0, 3.0]', v) FROM t; + +SELECT *, cos_distance('[1.0, 2.0, 3.0]', v) as d FROM t ORDER BY d; + +SELECT cos_distance(v, v) FROM t; + +SELECT cos_distance(v, '[1.0]') FROM t; + +SELECT cos_distance(v, 1.0) FROM t; + + +SELECT l2sq_distance(v, '[0.0, 0.0, 0.0]') FROM t; + +SELECT *, l2sq_distance(v, '[0.0, 0.0, 0.0]') as d FROM t ORDER BY d; + +SELECT l2sq_distance('[1.0, 2.0, 3.0]', v) FROM t; + +SELECT *, l2sq_distance('[1.0, 2.0, 3.0]', v) as d FROM t ORDER BY d; + +SELECT l2sq_distance(v, v) FROM t; + +SELECT l2sq_distance(v, '[1.0]') FROM t; + +SELECT l2sq_distance(v, 1.0) FROM t; + + +SELECT dot_product(v, '[0.0, 0.0, 0.0]') FROM t; + +SELECT *, dot_product(v, '[0.0, 0.0, 0.0]') as d FROM t ORDER BY d; + +SELECT dot_product('[1.0, 2.0, 3.0]', v) FROM t; + +SELECT *, dot_product('[1.0, 2.0, 3.0]', v) as d FROM t ORDER BY d; + +SELECT dot_product(v, v) FROM t; + +SELECT dot_product(v, '[1.0]') FROM t; + +SELECT dot_product(v, 1.0) FROM t; + -- Unexpected dimension -- INSERT INTO t VALUES (4, "[1.0]"); From 7e975032a3a43bb4217a55b899e7cf3a4178aa49 Mon Sep 17 00:00:00 2001 From: Zhenchi Date: Wed, 13 Nov 2024 10:40:08 +0000 Subject: [PATCH 2/4] fix: f64 instead Signed-off-by: Zhenchi --- .../function/src/scalars/vector/distance.rs | 8 +-- .../common/types/vector/vector.result | 60 +++++++++---------- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/src/common/function/src/scalars/vector/distance.rs b/src/common/function/src/scalars/vector/distance.rs index 8bf39d62c684..b217d0e66a3c 100644 --- a/src/common/function/src/scalars/vector/distance.rs +++ b/src/common/function/src/scalars/vector/distance.rs @@ -21,7 +21,7 @@ use common_query::prelude::Signature; use datatypes::prelude::ConcreteDataType; use datatypes::scalars::ScalarVectorBuilder; use datatypes::value::ValueRef; -use datatypes::vectors::{Float32VectorBuilder, MutableVector, Vector, VectorRef}; +use datatypes::vectors::{Float64VectorBuilder, MutableVector, Vector, VectorRef}; use snafu::ensure; use crate::function::{Function, FunctionContext}; @@ -41,7 +41,7 @@ macro_rules! define_distance_function { } fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result { - Ok(ConcreteDataType::float32_datatype()) + Ok(ConcreteDataType::float64_datatype()) } fn signature(&self) -> Signature { @@ -71,7 +71,7 @@ macro_rules! define_distance_function { let arg1 = &columns[1]; let size = arg0.len(); - let mut result = Float32VectorBuilder::with_capacity(size); + let mut result = Float64VectorBuilder::with_capacity(size); if size == 0 { return Ok(result.to_vector()); } @@ -104,7 +104,7 @@ macro_rules! define_distance_function { let f = ::$similarity_method; // Safe: checked if the length of the vectors match let d = f(vec0.as_ref(), vec1.as_ref()).unwrap(); - result.push(Some(d as f32)); + result.push(Some(d)); } else { result.push_null(); } diff --git a/tests/cases/standalone/common/types/vector/vector.result b/tests/cases/standalone/common/types/vector/vector.result index 5685d5abde81..daa3a6ea5d51 100644 --- a/tests/cases/standalone/common/types/vector/vector.result +++ b/tests/cases/standalone/common/types/vector/vector.result @@ -56,29 +56,29 @@ SELECT cos_distance('[1.0, 2.0, 3.0]', v) FROM t; +-------------------------------------------+ | cos_distance(Utf8("[1.0, 2.0, 3.0]"),t.v) | +-------------------------------------------+ -| 5.9604645e-8 | -| 0.025368154 | -| 0.04058808 | +| 5.960464477539063e-8 | +| 0.025368154048919678 | +| 0.04058808088302612 | +-------------------------------------------+ SELECT *, cos_distance('[1.0, 2.0, 3.0]', v) as d FROM t ORDER BY d; -+-------------------------+--------------------------+--------------+ -| ts | v | d | -+-------------------------+--------------------------+--------------+ -| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 5.9604645e-8 | -| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 0.025368154 | -| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.04058808 | -+-------------------------+--------------------------+--------------+ ++-------------------------+--------------------------+----------------------+ +| ts | v | d | ++-------------------------+--------------------------+----------------------+ +| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 5.960464477539063e-8 | +| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 0.025368154048919678 | +| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.04058808088302612 | ++-------------------------+--------------------------+----------------------+ SELECT cos_distance(v, v) FROM t; +-----------------------+ | cos_distance(t.v,t.v) | +-----------------------+ -| 5.9604645e-8 | +| 5.960464477539063e-8 | | 0.0 | -| 5.9604645e-8 | +| 5.960464477539063e-8 | +-----------------------+ SELECT cos_distance(v, '[1.0]') FROM t; @@ -149,13 +149,13 @@ Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 d SELECT dot_product(v, '[0.0, 0.0, 0.0]') FROM t; -+-------------------------------------------+ ++------------------------------------------+ | dot_product(t.v,Utf8("[0.0, 0.0, 0.0]")) | -+-------------------------------------------+ -| 0.0 | -| 0.0 | -| 0.0 | -+-------------------------------------------+ ++------------------------------------------+ +| 0.0 | +| 0.0 | +| 0.0 | ++------------------------------------------+ SELECT *, dot_product(v, '[0.0, 0.0, 0.0]') as d FROM t ORDER BY d; @@ -169,13 +169,13 @@ SELECT *, dot_product(v, '[0.0, 0.0, 0.0]') as d FROM t ORDER BY d; SELECT dot_product('[1.0, 2.0, 3.0]', v) FROM t; -+-------------------------------------------+ ++------------------------------------------+ | dot_product(Utf8("[1.0, 2.0, 3.0]"),t.v) | -+-------------------------------------------+ -| 14.0 | -| 32.0 | -| 50.0 | -+-------------------------------------------+ ++------------------------------------------+ +| 14.0 | +| 32.0 | +| 50.0 | ++------------------------------------------+ SELECT *, dot_product('[1.0, 2.0, 3.0]', v) as d FROM t ORDER BY d; @@ -189,13 +189,13 @@ SELECT *, dot_product('[1.0, 2.0, 3.0]', v) as d FROM t ORDER BY d; SELECT dot_product(v, v) FROM t; -+-----------------------+ ++----------------------+ | dot_product(t.v,t.v) | -+-----------------------+ -| 14.0 | -| 77.0 | -| 194.0 | -+-----------------------+ ++----------------------+ +| 14.0 | +| 77.0 | +| 194.0 | ++----------------------+ SELECT dot_product(v, '[1.0]') FROM t; From 223e183ad58fc1a4565f27f16008f26539399500 Mon Sep 17 00:00:00 2001 From: Zhenchi Date: Thu, 14 Nov 2024 06:33:43 +0000 Subject: [PATCH 3/4] address comments Signed-off-by: Zhenchi --- .../function/src/scalars/vector/distance.rs | 12 +- .../common/types/vector/vector.result | 176 +++++++++--------- .../standalone/common/types/vector/vector.sql | 37 ++-- 3 files changed, 118 insertions(+), 107 deletions(-) diff --git a/src/common/function/src/scalars/vector/distance.rs b/src/common/function/src/scalars/vector/distance.rs index b217d0e66a3c..c1259c229821 100644 --- a/src/common/function/src/scalars/vector/distance.rs +++ b/src/common/function/src/scalars/vector/distance.rs @@ -137,7 +137,7 @@ fn parse_if_constant_string(arg: &Arc) -> Result>> { arg.get_ref(0) .as_string() .unwrap() // Safe: checked if it is a string - .map(parse_vector_from_string) + .map(parse_f32_vector_from_string) .transpose() } @@ -153,7 +153,7 @@ fn as_vector(arg: ValueRef<'_>) -> Result>> { ConcreteDataType::String(_) => arg .as_string() .unwrap() // Safe: checked if it is a string - .map(|s| Ok(Cow::Owned(parse_vector_from_string(s)?))) + .map(|s| Ok(Cow::Owned(parse_f32_vector_from_string(s)?))) .transpose(), ConcreteDataType::Null(_) => Ok(None), _ => InvalidFuncArgsSnafu { @@ -181,7 +181,7 @@ fn binary_as_vector(bytes: &[u8]) -> Result<&[f32]> { /// Parse a string to a vector value. /// Valid inputs are strings like "[1.0, 2.0, 3.0]". -fn parse_vector_from_string(s: &str) -> Result> { +fn parse_f32_vector_from_string(s: &str) -> Result> { let trimmed = s.trim(); if !trimmed.starts_with('[') || !trimmed.ends_with(']') { return InvalidFuncArgsSnafu { @@ -446,13 +446,13 @@ mod tests { #[test] fn test_parse_vector_from_string() { - let result = parse_vector_from_string("[1.0, 2.0, 3.0]").unwrap(); + let result = parse_f32_vector_from_string("[1.0, 2.0, 3.0]").unwrap(); assert_eq!(result, vec![1.0, 2.0, 3.0]); - let result = parse_vector_from_string("[]").unwrap(); + let result = parse_f32_vector_from_string("[]").unwrap(); assert_eq!(result, Vec::::new()); - let result = parse_vector_from_string("[1.0, a, 3.0]"); + let result = parse_f32_vector_from_string("[1.0, a, 3.0]"); assert!(result.is_err()); } diff --git a/tests/cases/standalone/common/types/vector/vector.result b/tests/cases/standalone/common/types/vector/vector.result index daa3a6ea5d51..c8ae6a909cae 100644 --- a/tests/cases/standalone/common/types/vector/vector.result +++ b/tests/cases/standalone/common/types/vector/vector.result @@ -31,17 +31,17 @@ SELECT * FROM t; | 1970-01-01 00:00:00.003000 | "[7,8,9]" | +----------------------------+-----------+ -SELECT cos_distance(v, '[0.0, 0.0, 0.0]') FROM t; +SELECT round(cos_distance(v, '[0.0, 0.0, 0.0]'), 14) FROM t; -+-------------------------------------------+ -| cos_distance(t.v,Utf8("[0.0, 0.0, 0.0]")) | -+-------------------------------------------+ -| 1.0 | -| 1.0 | -| 1.0 | -+-------------------------------------------+ ++------------------------------------------------------------+ +| round(cos_distance(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(14)) | ++------------------------------------------------------------+ +| 1.0 | +| 1.0 | +| 1.0 | ++------------------------------------------------------------+ -SELECT *, cos_distance(v, '[0.0, 0.0, 0.0]') as d FROM t ORDER BY d; +SELECT *, round(cos_distance(v, '[0.0, 0.0, 0.0]'), 14) as d FROM t ORDER BY d; +-------------------------+--------------------------+-----+ | ts | v | d | @@ -51,55 +51,57 @@ SELECT *, cos_distance(v, '[0.0, 0.0, 0.0]') as d FROM t ORDER BY d; | 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 1.0 | +-------------------------+--------------------------+-----+ -SELECT cos_distance('[1.0, 2.0, 3.0]', v) FROM t; +SELECT round(cos_distance('[1.0, 2.0, 3.0]', v), 14) FROM t; -+-------------------------------------------+ -| cos_distance(Utf8("[1.0, 2.0, 3.0]"),t.v) | -+-------------------------------------------+ -| 5.960464477539063e-8 | -| 0.025368154048919678 | -| 0.04058808088302612 | -+-------------------------------------------+ ++------------------------------------------------------------+ +| round(cos_distance(Utf8("[1.0, 2.0, 3.0]"),t.v),Int64(14)) | ++------------------------------------------------------------+ +| 5.960464e-8 | +| 0.02536815404892 | +| 0.04058808088303 | ++------------------------------------------------------------+ -SELECT *, cos_distance('[1.0, 2.0, 3.0]', v) as d FROM t ORDER BY d; +SELECT *, round(cos_distance('[1.0, 2.0, 3.0]', v), 14) as d FROM t ORDER BY d; -+-------------------------+--------------------------+----------------------+ -| ts | v | d | -+-------------------------+--------------------------+----------------------+ -| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 5.960464477539063e-8 | -| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 0.025368154048919678 | -| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.04058808088302612 | -+-------------------------+--------------------------+----------------------+ ++-------------------------+--------------------------+------------------+ +| ts | v | d | ++-------------------------+--------------------------+------------------+ +| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 5.960464e-8 | +| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 0.02536815404892 | +| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.04058808088303 | ++-------------------------+--------------------------+------------------+ -SELECT cos_distance(v, v) FROM t; +SELECT round(cos_distance(v, v), 14) FROM t; -+-----------------------+ -| cos_distance(t.v,t.v) | -+-----------------------+ -| 5.960464477539063e-8 | -| 0.0 | -| 5.960464477539063e-8 | -+-----------------------+ ++----------------------------------------+ +| round(cos_distance(t.v,t.v),Int64(14)) | ++----------------------------------------+ +| 5.960464e-8 | +| 0.0 | +| 5.960464e-8 | ++----------------------------------------+ +-- Unexpected dimension -- SELECT cos_distance(v, '[1.0]') FROM t; Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1 +-- Invalid type -- SELECT cos_distance(v, 1.0) FROM t; Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2 -SELECT l2sq_distance(v, '[0.0, 0.0, 0.0]') FROM t; +SELECT round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 14) FROM t; -+--------------------------------------------+ -| l2sq_distance(t.v,Utf8("[0.0, 0.0, 0.0]")) | -+--------------------------------------------+ -| 14.0 | -| 77.0 | -| 194.0 | -+--------------------------------------------+ ++-------------------------------------------------------------+ +| round(l2sq_distance(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(14)) | ++-------------------------------------------------------------+ +| 14.0 | +| 77.0 | +| 194.0 | ++-------------------------------------------------------------+ -SELECT *, l2sq_distance(v, '[0.0, 0.0, 0.0]') as d FROM t ORDER BY d; +SELECT *, round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 14) as d FROM t ORDER BY d; +-------------------------+--------------------------+-------+ | ts | v | d | @@ -109,17 +111,17 @@ SELECT *, l2sq_distance(v, '[0.0, 0.0, 0.0]') as d FROM t ORDER BY d; | 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 194.0 | +-------------------------+--------------------------+-------+ -SELECT l2sq_distance('[1.0, 2.0, 3.0]', v) FROM t; +SELECT round(l2sq_distance('[1.0, 2.0, 3.0]', v), 14) FROM t; -+--------------------------------------------+ -| l2sq_distance(Utf8("[1.0, 2.0, 3.0]"),t.v) | -+--------------------------------------------+ -| 0.0 | -| 27.0 | -| 108.0 | -+--------------------------------------------+ ++-------------------------------------------------------------+ +| round(l2sq_distance(Utf8("[1.0, 2.0, 3.0]"),t.v),Int64(14)) | ++-------------------------------------------------------------+ +| 0.0 | +| 27.0 | +| 108.0 | ++-------------------------------------------------------------+ -SELECT *, l2sq_distance('[1.0, 2.0, 3.0]', v) as d FROM t ORDER BY d; +SELECT *, round(l2sq_distance('[1.0, 2.0, 3.0]', v), 14) as d FROM t ORDER BY d; +-------------------------+--------------------------+-------+ | ts | v | d | @@ -129,35 +131,37 @@ SELECT *, l2sq_distance('[1.0, 2.0, 3.0]', v) as d FROM t ORDER BY d; | 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 108.0 | +-------------------------+--------------------------+-------+ -SELECT l2sq_distance(v, v) FROM t; +SELECT round(l2sq_distance(v, v), 14) FROM t; -+------------------------+ -| l2sq_distance(t.v,t.v) | -+------------------------+ -| 0.0 | -| 0.0 | -| 0.0 | -+------------------------+ ++-----------------------------------------+ +| round(l2sq_distance(t.v,t.v),Int64(14)) | ++-----------------------------------------+ +| 0.0 | +| 0.0 | +| 0.0 | ++-----------------------------------------+ +-- Unexpected dimension -- SELECT l2sq_distance(v, '[1.0]') FROM t; Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1 +-- Invalid type -- SELECT l2sq_distance(v, 1.0) FROM t; Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2 -SELECT dot_product(v, '[0.0, 0.0, 0.0]') FROM t; +SELECT round(dot_product(v, '[0.0, 0.0, 0.0]'), 14) FROM t; -+------------------------------------------+ -| dot_product(t.v,Utf8("[0.0, 0.0, 0.0]")) | -+------------------------------------------+ -| 0.0 | -| 0.0 | -| 0.0 | -+------------------------------------------+ ++-----------------------------------------------------------+ +| round(dot_product(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(14)) | ++-----------------------------------------------------------+ +| 0.0 | +| 0.0 | +| 0.0 | ++-----------------------------------------------------------+ -SELECT *, dot_product(v, '[0.0, 0.0, 0.0]') as d FROM t ORDER BY d; +SELECT *, round(dot_product(v, '[0.0, 0.0, 0.0]'), 14) as d FROM t ORDER BY d; +-------------------------+--------------------------+-----+ | ts | v | d | @@ -167,17 +171,17 @@ SELECT *, dot_product(v, '[0.0, 0.0, 0.0]') as d FROM t ORDER BY d; | 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.0 | +-------------------------+--------------------------+-----+ -SELECT dot_product('[1.0, 2.0, 3.0]', v) FROM t; +SELECT round(dot_product('[1.0, 2.0, 3.0]', v), 14) FROM t; -+------------------------------------------+ -| dot_product(Utf8("[1.0, 2.0, 3.0]"),t.v) | -+------------------------------------------+ -| 14.0 | -| 32.0 | -| 50.0 | -+------------------------------------------+ ++-----------------------------------------------------------+ +| round(dot_product(Utf8("[1.0, 2.0, 3.0]"),t.v),Int64(14)) | ++-----------------------------------------------------------+ +| 14.0 | +| 32.0 | +| 50.0 | ++-----------------------------------------------------------+ -SELECT *, dot_product('[1.0, 2.0, 3.0]', v) as d FROM t ORDER BY d; +SELECT *, round(dot_product('[1.0, 2.0, 3.0]', v), 14) as d FROM t ORDER BY d; +-------------------------+--------------------------+------+ | ts | v | d | @@ -187,20 +191,22 @@ SELECT *, dot_product('[1.0, 2.0, 3.0]', v) as d FROM t ORDER BY d; | 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 50.0 | +-------------------------+--------------------------+------+ -SELECT dot_product(v, v) FROM t; +SELECT round(dot_product(v, v), 14) FROM t; -+----------------------+ -| dot_product(t.v,t.v) | -+----------------------+ -| 14.0 | -| 77.0 | -| 194.0 | -+----------------------+ ++---------------------------------------+ +| round(dot_product(t.v,t.v),Int64(14)) | ++---------------------------------------+ +| 14.0 | +| 77.0 | +| 194.0 | ++---------------------------------------+ +-- Unexpected dimension -- SELECT dot_product(v, '[1.0]') FROM t; Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1 +-- Invalid type -- SELECT dot_product(v, 1.0) FROM t; Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2 diff --git a/tests/cases/standalone/common/types/vector/vector.sql b/tests/cases/standalone/common/types/vector/vector.sql index d6fdd2cf19d6..25b4d1bbe2aa 100644 --- a/tests/cases/standalone/common/types/vector/vector.sql +++ b/tests/cases/standalone/common/types/vector/vector.sql @@ -11,48 +11,53 @@ SELECT * FROM t; -- SQLNESS PROTOCOL POSTGRES SELECT * FROM t; -SELECT cos_distance(v, '[0.0, 0.0, 0.0]') FROM t; +SELECT round(cos_distance(v, '[0.0, 0.0, 0.0]'), 14) FROM t; -SELECT *, cos_distance(v, '[0.0, 0.0, 0.0]') as d FROM t ORDER BY d; +SELECT *, round(cos_distance(v, '[0.0, 0.0, 0.0]'), 14) as d FROM t ORDER BY d; -SELECT cos_distance('[1.0, 2.0, 3.0]', v) FROM t; +SELECT round(cos_distance('[1.0, 2.0, 3.0]', v), 14) FROM t; -SELECT *, cos_distance('[1.0, 2.0, 3.0]', v) as d FROM t ORDER BY d; +SELECT *, round(cos_distance('[1.0, 2.0, 3.0]', v), 14) as d FROM t ORDER BY d; -SELECT cos_distance(v, v) FROM t; +SELECT round(cos_distance(v, v), 14) FROM t; +-- Unexpected dimension -- SELECT cos_distance(v, '[1.0]') FROM t; +-- Invalid type -- SELECT cos_distance(v, 1.0) FROM t; +SELECT round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 14) FROM t; -SELECT l2sq_distance(v, '[0.0, 0.0, 0.0]') FROM t; - -SELECT *, l2sq_distance(v, '[0.0, 0.0, 0.0]') as d FROM t ORDER BY d; +SELECT *, round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 14) as d FROM t ORDER BY d; -SELECT l2sq_distance('[1.0, 2.0, 3.0]', v) FROM t; +SELECT round(l2sq_distance('[1.0, 2.0, 3.0]', v), 14) FROM t; -SELECT *, l2sq_distance('[1.0, 2.0, 3.0]', v) as d FROM t ORDER BY d; +SELECT *, round(l2sq_distance('[1.0, 2.0, 3.0]', v), 14) as d FROM t ORDER BY d; -SELECT l2sq_distance(v, v) FROM t; +SELECT round(l2sq_distance(v, v), 14) FROM t; +-- Unexpected dimension -- SELECT l2sq_distance(v, '[1.0]') FROM t; +-- Invalid type -- SELECT l2sq_distance(v, 1.0) FROM t; -SELECT dot_product(v, '[0.0, 0.0, 0.0]') FROM t; +SELECT round(dot_product(v, '[0.0, 0.0, 0.0]'), 14) FROM t; -SELECT *, dot_product(v, '[0.0, 0.0, 0.0]') as d FROM t ORDER BY d; +SELECT *, round(dot_product(v, '[0.0, 0.0, 0.0]'), 14) as d FROM t ORDER BY d; -SELECT dot_product('[1.0, 2.0, 3.0]', v) FROM t; +SELECT round(dot_product('[1.0, 2.0, 3.0]', v), 14) FROM t; -SELECT *, dot_product('[1.0, 2.0, 3.0]', v) as d FROM t ORDER BY d; +SELECT *, round(dot_product('[1.0, 2.0, 3.0]', v), 14) as d FROM t ORDER BY d; -SELECT dot_product(v, v) FROM t; +SELECT round(dot_product(v, v), 14) FROM t; +-- Unexpected dimension -- SELECT dot_product(v, '[1.0]') FROM t; +-- Invalid type -- SELECT dot_product(v, 1.0) FROM t; -- Unexpected dimension -- From ac7b0174d662bbcf047cc2344d3b2538b8b34b6b Mon Sep 17 00:00:00 2001 From: Zhenchi Date: Thu, 14 Nov 2024 09:02:05 +0000 Subject: [PATCH 4/4] tiny adjust Signed-off-by: Zhenchi --- .../common/types/vector/vector.result | 188 +++++++++--------- .../standalone/common/types/vector/vector.sql | 30 +-- 2 files changed, 109 insertions(+), 109 deletions(-) diff --git a/tests/cases/standalone/common/types/vector/vector.result b/tests/cases/standalone/common/types/vector/vector.result index c8ae6a909cae..ee9bbf45af25 100644 --- a/tests/cases/standalone/common/types/vector/vector.result +++ b/tests/cases/standalone/common/types/vector/vector.result @@ -31,17 +31,17 @@ SELECT * FROM t; | 1970-01-01 00:00:00.003000 | "[7,8,9]" | +----------------------------+-----------+ -SELECT round(cos_distance(v, '[0.0, 0.0, 0.0]'), 14) FROM t; +SELECT round(cos_distance(v, '[0.0, 0.0, 0.0]'), 4) FROM t; -+------------------------------------------------------------+ -| round(cos_distance(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(14)) | -+------------------------------------------------------------+ -| 1.0 | -| 1.0 | -| 1.0 | -+------------------------------------------------------------+ ++-----------------------------------------------------------+ +| round(cos_distance(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(4)) | ++-----------------------------------------------------------+ +| 1.0 | +| 1.0 | +| 1.0 | ++-----------------------------------------------------------+ -SELECT *, round(cos_distance(v, '[0.0, 0.0, 0.0]'), 14) as d FROM t ORDER BY d; +SELECT *, round(cos_distance(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d; +-------------------------+--------------------------+-----+ | ts | v | d | @@ -51,35 +51,35 @@ SELECT *, round(cos_distance(v, '[0.0, 0.0, 0.0]'), 14) as d FROM t ORDER BY d; | 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 1.0 | +-------------------------+--------------------------+-----+ -SELECT round(cos_distance('[1.0, 2.0, 3.0]', v), 14) FROM t; +SELECT round(cos_distance('[7.0, 8.0, 9.0]', v), 4) FROM t; -+------------------------------------------------------------+ -| round(cos_distance(Utf8("[1.0, 2.0, 3.0]"),t.v),Int64(14)) | -+------------------------------------------------------------+ -| 5.960464e-8 | -| 0.02536815404892 | -| 0.04058808088303 | -+------------------------------------------------------------+ ++-----------------------------------------------------------+ +| round(cos_distance(Utf8("[7.0, 8.0, 9.0]"),t.v),Int64(4)) | ++-----------------------------------------------------------+ +| 0.0406 | +| 0.0018 | +| 0.0 | ++-----------------------------------------------------------+ -SELECT *, round(cos_distance('[1.0, 2.0, 3.0]', v), 14) as d FROM t ORDER BY d; +SELECT *, round(cos_distance('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d; -+-------------------------+--------------------------+------------------+ -| ts | v | d | -+-------------------------+--------------------------+------------------+ -| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 5.960464e-8 | -| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 0.02536815404892 | -| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.04058808088303 | -+-------------------------+--------------------------+------------------+ ++-------------------------+--------------------------+--------+ +| ts | v | d | ++-------------------------+--------------------------+--------+ +| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.0 | +| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 0.0018 | +| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 0.0406 | ++-------------------------+--------------------------+--------+ -SELECT round(cos_distance(v, v), 14) FROM t; +SELECT round(cos_distance(v, v), 4) FROM t; -+----------------------------------------+ -| round(cos_distance(t.v,t.v),Int64(14)) | -+----------------------------------------+ -| 5.960464e-8 | -| 0.0 | -| 5.960464e-8 | -+----------------------------------------+ ++---------------------------------------+ +| round(cos_distance(t.v,t.v),Int64(4)) | ++---------------------------------------+ +| 0.0 | +| 0.0 | +| 0.0 | ++---------------------------------------+ -- Unexpected dimension -- SELECT cos_distance(v, '[1.0]') FROM t; @@ -91,17 +91,17 @@ SELECT cos_distance(v, 1.0) FROM t; Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2 -SELECT round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 14) FROM t; +SELECT round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 4) FROM t; -+-------------------------------------------------------------+ -| round(l2sq_distance(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(14)) | -+-------------------------------------------------------------+ -| 14.0 | -| 77.0 | -| 194.0 | -+-------------------------------------------------------------+ ++------------------------------------------------------------+ +| round(l2sq_distance(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(4)) | ++------------------------------------------------------------+ +| 14.0 | +| 77.0 | +| 194.0 | ++------------------------------------------------------------+ -SELECT *, round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 14) as d FROM t ORDER BY d; +SELECT *, round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d; +-------------------------+--------------------------+-------+ | ts | v | d | @@ -111,35 +111,35 @@ SELECT *, round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 14) as d FROM t ORDER BY d; | 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 194.0 | +-------------------------+--------------------------+-------+ -SELECT round(l2sq_distance('[1.0, 2.0, 3.0]', v), 14) FROM t; +SELECT round(l2sq_distance('[7.0, 8.0, 9.0]', v), 4) FROM t; -+-------------------------------------------------------------+ -| round(l2sq_distance(Utf8("[1.0, 2.0, 3.0]"),t.v),Int64(14)) | -+-------------------------------------------------------------+ -| 0.0 | -| 27.0 | -| 108.0 | -+-------------------------------------------------------------+ ++------------------------------------------------------------+ +| round(l2sq_distance(Utf8("[7.0, 8.0, 9.0]"),t.v),Int64(4)) | ++------------------------------------------------------------+ +| 108.0 | +| 27.0 | +| 0.0 | ++------------------------------------------------------------+ -SELECT *, round(l2sq_distance('[1.0, 2.0, 3.0]', v), 14) as d FROM t ORDER BY d; +SELECT *, round(l2sq_distance('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d; +-------------------------+--------------------------+-------+ | ts | v | d | +-------------------------+--------------------------+-------+ -| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 0.0 | +| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.0 | | 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 27.0 | -| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 108.0 | +| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 108.0 | +-------------------------+--------------------------+-------+ -SELECT round(l2sq_distance(v, v), 14) FROM t; +SELECT round(l2sq_distance(v, v), 4) FROM t; -+-----------------------------------------+ -| round(l2sq_distance(t.v,t.v),Int64(14)) | -+-----------------------------------------+ -| 0.0 | -| 0.0 | -| 0.0 | -+-----------------------------------------+ ++----------------------------------------+ +| round(l2sq_distance(t.v,t.v),Int64(4)) | ++----------------------------------------+ +| 0.0 | +| 0.0 | +| 0.0 | ++----------------------------------------+ -- Unexpected dimension -- SELECT l2sq_distance(v, '[1.0]') FROM t; @@ -151,17 +151,17 @@ SELECT l2sq_distance(v, 1.0) FROM t; Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2 -SELECT round(dot_product(v, '[0.0, 0.0, 0.0]'), 14) FROM t; +SELECT round(dot_product(v, '[0.0, 0.0, 0.0]'), 4) FROM t; -+-----------------------------------------------------------+ -| round(dot_product(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(14)) | -+-----------------------------------------------------------+ -| 0.0 | -| 0.0 | -| 0.0 | -+-----------------------------------------------------------+ ++----------------------------------------------------------+ +| round(dot_product(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(4)) | ++----------------------------------------------------------+ +| 0.0 | +| 0.0 | +| 0.0 | ++----------------------------------------------------------+ -SELECT *, round(dot_product(v, '[0.0, 0.0, 0.0]'), 14) as d FROM t ORDER BY d; +SELECT *, round(dot_product(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d; +-------------------------+--------------------------+-----+ | ts | v | d | @@ -171,35 +171,35 @@ SELECT *, round(dot_product(v, '[0.0, 0.0, 0.0]'), 14) as d FROM t ORDER BY d; | 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.0 | +-------------------------+--------------------------+-----+ -SELECT round(dot_product('[1.0, 2.0, 3.0]', v), 14) FROM t; +SELECT round(dot_product('[7.0, 8.0, 9.0]', v), 4) FROM t; -+-----------------------------------------------------------+ -| round(dot_product(Utf8("[1.0, 2.0, 3.0]"),t.v),Int64(14)) | -+-----------------------------------------------------------+ -| 14.0 | -| 32.0 | -| 50.0 | -+-----------------------------------------------------------+ ++----------------------------------------------------------+ +| round(dot_product(Utf8("[7.0, 8.0, 9.0]"),t.v),Int64(4)) | ++----------------------------------------------------------+ +| 50.0 | +| 122.0 | +| 194.0 | ++----------------------------------------------------------+ -SELECT *, round(dot_product('[1.0, 2.0, 3.0]', v), 14) as d FROM t ORDER BY d; +SELECT *, round(dot_product('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d; -+-------------------------+--------------------------+------+ -| ts | v | d | -+-------------------------+--------------------------+------+ -| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 14.0 | -| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 32.0 | -| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 50.0 | -+-------------------------+--------------------------+------+ ++-------------------------+--------------------------+-------+ +| ts | v | d | ++-------------------------+--------------------------+-------+ +| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 50.0 | +| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 122.0 | +| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 194.0 | ++-------------------------+--------------------------+-------+ -SELECT round(dot_product(v, v), 14) FROM t; +SELECT round(dot_product(v, v), 4) FROM t; -+---------------------------------------+ -| round(dot_product(t.v,t.v),Int64(14)) | -+---------------------------------------+ -| 14.0 | -| 77.0 | -| 194.0 | -+---------------------------------------+ ++--------------------------------------+ +| round(dot_product(t.v,t.v),Int64(4)) | ++--------------------------------------+ +| 14.0 | +| 77.0 | +| 194.0 | ++--------------------------------------+ -- Unexpected dimension -- SELECT dot_product(v, '[1.0]') FROM t; diff --git a/tests/cases/standalone/common/types/vector/vector.sql b/tests/cases/standalone/common/types/vector/vector.sql index 25b4d1bbe2aa..cea3ef406c63 100644 --- a/tests/cases/standalone/common/types/vector/vector.sql +++ b/tests/cases/standalone/common/types/vector/vector.sql @@ -11,15 +11,15 @@ SELECT * FROM t; -- SQLNESS PROTOCOL POSTGRES SELECT * FROM t; -SELECT round(cos_distance(v, '[0.0, 0.0, 0.0]'), 14) FROM t; +SELECT round(cos_distance(v, '[0.0, 0.0, 0.0]'), 4) FROM t; -SELECT *, round(cos_distance(v, '[0.0, 0.0, 0.0]'), 14) as d FROM t ORDER BY d; +SELECT *, round(cos_distance(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d; -SELECT round(cos_distance('[1.0, 2.0, 3.0]', v), 14) FROM t; +SELECT round(cos_distance('[7.0, 8.0, 9.0]', v), 4) FROM t; -SELECT *, round(cos_distance('[1.0, 2.0, 3.0]', v), 14) as d FROM t ORDER BY d; +SELECT *, round(cos_distance('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d; -SELECT round(cos_distance(v, v), 14) FROM t; +SELECT round(cos_distance(v, v), 4) FROM t; -- Unexpected dimension -- SELECT cos_distance(v, '[1.0]') FROM t; @@ -27,15 +27,15 @@ SELECT cos_distance(v, '[1.0]') FROM t; -- Invalid type -- SELECT cos_distance(v, 1.0) FROM t; -SELECT round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 14) FROM t; +SELECT round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 4) FROM t; -SELECT *, round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 14) as d FROM t ORDER BY d; +SELECT *, round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d; -SELECT round(l2sq_distance('[1.0, 2.0, 3.0]', v), 14) FROM t; +SELECT round(l2sq_distance('[7.0, 8.0, 9.0]', v), 4) FROM t; -SELECT *, round(l2sq_distance('[1.0, 2.0, 3.0]', v), 14) as d FROM t ORDER BY d; +SELECT *, round(l2sq_distance('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d; -SELECT round(l2sq_distance(v, v), 14) FROM t; +SELECT round(l2sq_distance(v, v), 4) FROM t; -- Unexpected dimension -- SELECT l2sq_distance(v, '[1.0]') FROM t; @@ -44,15 +44,15 @@ SELECT l2sq_distance(v, '[1.0]') FROM t; SELECT l2sq_distance(v, 1.0) FROM t; -SELECT round(dot_product(v, '[0.0, 0.0, 0.0]'), 14) FROM t; +SELECT round(dot_product(v, '[0.0, 0.0, 0.0]'), 4) FROM t; -SELECT *, round(dot_product(v, '[0.0, 0.0, 0.0]'), 14) as d FROM t ORDER BY d; +SELECT *, round(dot_product(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d; -SELECT round(dot_product('[1.0, 2.0, 3.0]', v), 14) FROM t; +SELECT round(dot_product('[7.0, 8.0, 9.0]', v), 4) FROM t; -SELECT *, round(dot_product('[1.0, 2.0, 3.0]', v), 14) as d FROM t ORDER BY d; +SELECT *, round(dot_product('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d; -SELECT round(dot_product(v, v), 14) FROM t; +SELECT round(dot_product(v, v), 4) FROM t; -- Unexpected dimension -- SELECT dot_product(v, '[1.0]') FROM t;