diff --git a/Cargo.lock b/Cargo.lock index c99b1cc728be..40628f9190a4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1641,6 +1641,7 @@ dependencies = [ "common-expression", "common-hashtable", "common-io", + "common-vector", "crc32fast", "criterion", "ctor", @@ -2563,6 +2564,15 @@ dependencies = [ "wiremock", ] +[[package]] +name = "common-vector" +version = "0.1.0" +dependencies = [ + "approx", + "common-exception", + "ndarray", +] + [[package]] name = "compact_str" version = "0.6.1" @@ -6185,6 +6195,15 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40" +[[package]] +name = "matrixmultiply" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "add85d4dd35074e6fedc608f8c8f513a3548619a9024b751949ef0e8e45a4d84" +dependencies = [ + "rawpointer", +] + [[package]] name = "md-5" version = "0.10.5" @@ -6500,6 +6519,19 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fce7b49e1e6d8aa67232ef1c4c936c0af58756eb2db6f65c40bacb39035e7f42" +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", +] + [[package]] name = "nix" version = "0.26.2" @@ -7890,6 +7922,12 @@ dependencies = [ "bitflags", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.7.0" diff --git a/Cargo.toml b/Cargo.toml index 51782ea69355..d469b66b5917 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ members = [ "src/common/tracing", "src/common/storage", "src/common/profile", + "src/common/vector", # Query "src/query/ast", "src/query/codegen", diff --git a/src/common/base/src/lib.rs b/src/common/base/src/lib.rs index 2e9403dae15b..5fda28d4535b 100644 --- a/src/common/base/src/lib.rs +++ b/src/common/base/src/lib.rs @@ -28,5 +28,6 @@ pub mod containers; pub mod mem_allocator; pub mod rangemap; pub mod runtime; + pub use runtime::match_join_handle; pub use runtime::set_alloc_error_hook; diff --git a/src/common/vector/Cargo.toml b/src/common/vector/Cargo.toml new file mode 100644 index 000000000000..22a74248ac50 --- /dev/null +++ b/src/common/vector/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "common-vector" +version = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +publish = { workspace = true } +edition = { workspace = true } + +[lib] +doctest = false +test = false + +[dependencies] # In alphabetical order +common-exception = { path = "../exception" } + +ndarray = "0.15.6" + +[build-dependencies] + +[features] + +[dev-dependencies] +approx = "0.5.1" diff --git a/src/common/vector/src/distance.rs b/src/common/vector/src/distance.rs new file mode 100644 index 000000000000..b3e37e783f74 --- /dev/null +++ b/src/common/vector/src/distance.rs @@ -0,0 +1,34 @@ +// Copyright 2023 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use common_exception::ErrorCode; +use common_exception::Result; +use ndarray::ArrayView; + +pub fn cosine_distance(from: &[f32], to: &[f32]) -> Result { + if from.len() != to.len() { + return Err(ErrorCode::InvalidArgument(format!( + "Vector length not equal: {:} != {:}", + from.len(), + to.len(), + ))); + } + + let a = ArrayView::from(from); + let b = ArrayView::from(to); + let aa_sum = (&a * &a).sum(); + let bb_sum = (&b * &b).sum(); + + Ok((&a * &b).sum() / ((aa_sum).sqrt() * (bb_sum).sqrt())) +} diff --git a/src/common/vector/src/lib.rs b/src/common/vector/src/lib.rs new file mode 100644 index 000000000000..9961bf518bd3 --- /dev/null +++ b/src/common/vector/src/lib.rs @@ -0,0 +1,17 @@ +// Copyright 2023 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod distance; + +pub use distance::cosine_distance; diff --git a/src/common/vector/tests/it/distance.rs b/src/common/vector/tests/it/distance.rs new file mode 100644 index 000000000000..ed0671065e89 --- /dev/null +++ b/src/common/vector/tests/it/distance.rs @@ -0,0 +1,41 @@ +// Copyright 2023 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use common_vector::cosine_distance; + +#[test] +fn test_cosine() { + { + let x: Vec = (1..9).map(|v| v as f32).collect(); + let y: Vec = (100..108).map(|v| v as f32).collect(); + let d = cosine_distance(&x, &y).unwrap(); + // from scipy.spatial.distance.cosine + approx::assert_relative_eq!(d, 0.900_957); + } + + { + let x = vec![3.0, 45.0, 7.0, 2.0, 5.0, 20.0, 13.0, 12.0]; + let y = vec![2.0, 54.0, 13.0, 15.0, 22.0, 34.0, 50.0, 1.0]; + let d = cosine_distance(&x, &y).unwrap(); + // from sklearn.metrics.pairwise import cosine_similarity + approx::assert_relative_eq!(d, 0.873_580_6); + } + + { + let x = vec![3.0, 45.0, 7.0, 2.0, 5.0, 20.0, 13.0, 12.0]; + let y = vec![2.0, 54.0]; + let d = cosine_distance(&x, &y); + assert!(d.is_err()); + } +} diff --git a/src/common/vector/tests/it/main.rs b/src/common/vector/tests/it/main.rs new file mode 100644 index 000000000000..9c3dd7dab600 --- /dev/null +++ b/src/common/vector/tests/it/main.rs @@ -0,0 +1,15 @@ +// Copyright 2023 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod distance; diff --git a/src/query/functions/Cargo.toml b/src/query/functions/Cargo.toml index a3a47931b08b..0c5fe91129d4 100644 --- a/src/query/functions/Cargo.toml +++ b/src/query/functions/Cargo.toml @@ -17,6 +17,7 @@ common-exception = { path = "../../common/exception" } common-expression = { path = "../expression" } common-hashtable = { path = "../../common/hashtable" } common-io = { path = "../../common/io" } +common-vector = { path = "../../common/vector" } jsonb = { workspace = true } # Crates.io dependencies diff --git a/src/query/functions/src/scalars/mod.rs b/src/query/functions/src/scalars/mod.rs index 61b334b72496..7be0bbc6c9c5 100644 --- a/src/query/functions/src/scalars/mod.rs +++ b/src/query/functions/src/scalars/mod.rs @@ -25,6 +25,7 @@ mod map; mod math; mod tuple; mod variant; +mod vector; mod comparison; mod decimal; @@ -55,4 +56,5 @@ pub fn register(registry: &mut FunctionRegistry) { hash::register(registry); other::register(registry); decimal::register(registry); + vector::register(registry); } diff --git a/src/query/functions/src/scalars/vector.rs b/src/query/functions/src/scalars/vector.rs new file mode 100644 index 000000000000..461aed0db453 --- /dev/null +++ b/src/query/functions/src/scalars/vector.rs @@ -0,0 +1,47 @@ +// Copyright 2023 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use common_arrow::arrow::buffer::Buffer; +use common_expression::types::ArrayType; +use common_expression::types::Float32Type; +use common_expression::types::F32; +use common_expression::vectorize_with_builder_2_arg; +use common_expression::FunctionDomain; +use common_expression::FunctionRegistry; +use common_vector::cosine_distance; + +pub fn register(registry: &mut FunctionRegistry) { + registry.register_passthrough_nullable_2_arg::, ArrayType, Float32Type, _, _>( + "cosine_distance", + |_, _| FunctionDomain::MayThrow, + vectorize_with_builder_2_arg::, ArrayType, Float32Type>( + |lhs, rhs, output, ctx| { + let l_f32= + unsafe { std::mem::transmute::, Buffer>(lhs) }; + let r_f32= + unsafe { std::mem::transmute::, Buffer>(rhs) }; + + match cosine_distance(l_f32.as_slice(), r_f32.as_slice()) { + Ok(dist) => { + output.push(F32::from(dist)); + } + Err(err) => { + ctx.set_error(output.len(), err.to_string()); + output.push(F32::from(0.0)); + } + } + } + ), + ); +} diff --git a/src/query/functions/tests/it/scalars/mod.rs b/src/query/functions/tests/it/scalars/mod.rs index d5a8870cd345..3ccd2d400a57 100644 --- a/src/query/functions/tests/it/scalars/mod.rs +++ b/src/query/functions/tests/it/scalars/mod.rs @@ -47,6 +47,7 @@ mod regexp; mod string; mod tuple; mod variant; +mod vector; pub fn run_ast(file: &mut impl Write, text: impl AsRef, columns: &[(&str, Column)]) { let text = text.as_ref(); diff --git a/src/query/functions/tests/it/scalars/testdata/function_list.txt b/src/query/functions/tests/it/scalars/testdata/function_list.txt index c16b0d235f97..1eae53e1aa66 100644 --- a/src/query/functions/tests/it/scalars/testdata/function_list.txt +++ b/src/query/functions/tests/it/scalars/testdata/function_list.txt @@ -1098,6 +1098,8 @@ Functions overloads: 28 contains(Array(T0), T0) :: Boolean 0 cos(Float64) :: Float64 1 cos(Float64 NULL) :: Float64 NULL +0 cosine_distance(Array(Float32), Array(Float32)) :: Float32 +1 cosine_distance(Array(Float32) NULL, Array(Float32) NULL) :: Float32 NULL 0 cot(Float64) :: Float64 1 cot(Float64 NULL) :: Float64 NULL 0 crc32(String) :: UInt32 diff --git a/src/query/functions/tests/it/scalars/testdata/vector.txt b/src/query/functions/tests/it/scalars/testdata/vector.txt new file mode 100644 index 000000000000..b714769ff014 --- /dev/null +++ b/src/query/functions/tests/it/scalars/testdata/vector.txt @@ -0,0 +1,23 @@ +ast : cosine_distance([a], [b]) +raw expr : cosine_distance(array(a::Float32), array(b::Float32)) +checked expr : cosine_distance(array(a), array(b)) +evaluation: ++--------+---------+---------+---------+ +| | a | b | Output | ++--------+---------+---------+---------+ +| Type | Float32 | Float32 | Float32 | +| Domain | {0..=2} | {3..=5} | Unknown | +| Row 0 | 0 | 3 | NaN | +| Row 1 | 1 | 4 | 1 | +| Row 2 | 2 | 5 | 1 | ++--------+---------+---------+---------+ +evaluation (internal): ++--------+----------------------+ +| Column | Data | ++--------+----------------------+ +| a | Float32([0, 1, 2]) | +| b | Float32([3, 4, 5]) | +| Output | Float32([NaN, 1, 1]) | ++--------+----------------------+ + + diff --git a/src/query/functions/tests/it/scalars/vector.rs b/src/query/functions/tests/it/scalars/vector.rs new file mode 100644 index 000000000000..688c30f6ec2b --- /dev/null +++ b/src/query/functions/tests/it/scalars/vector.rs @@ -0,0 +1,36 @@ +// Copyright 2023 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::io::Write; + +use common_expression::types::*; +use common_expression::FromData; +use goldenfile::Mint; + +use super::run_ast; + +#[test] +fn test_vector() { + let mut mint = Mint::new("tests/it/scalars/testdata"); + let file = &mut mint.new_goldenfile("vector.txt").unwrap(); + + test_vector_cosine_distance(file); +} + +fn test_vector_cosine_distance(file: &mut impl Write) { + run_ast(file, "cosine_distance([a], [b])", &[ + ("a", Float32Type::from_data(vec![0f32, 1.0, 2.0])), + ("b", Float32Type::from_data(vec![3f32, 4.0, 5.0])), + ]); +} diff --git a/tests/sqllogictests/suites/query/02_function/02_0063_function_vector b/tests/sqllogictests/suites/query/02_function/02_0063_function_vector new file mode 100644 index 000000000000..e32a6538e4c1 --- /dev/null +++ b/tests/sqllogictests/suites/query/02_function/02_0063_function_vector @@ -0,0 +1,5 @@ +# From sklearn.metrics.pairwise import cosine_similarity +query F +select cosine_distance([3.0, 45.0, 7.0, 2.0, 5.0, 20.0, 13.0, 12.0], [2.0, 54.0, 13.0, 15.0, 22.0, 34.0, 50.0, 1.0]) as sim +---- +0.8735807