Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 63 additions & 4 deletions be/src/vec/functions/array/function_array_distance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,70 @@

namespace doris::vectorized {

// PRAGMA_IMPRECISE_FUNCTION_BEGIN / PRAGMA_IMPRECISE_FUNCTION_END bracket a region of
// function definitions that is compiled with relaxed floating-point rules.
//
// clang newer than major version 10 on x86_64: switch precise float control off for the
// region (the previous state is pushed by BEGIN and popped by END).
#if defined(__x86_64__) && (defined(__clang_major__) && (__clang_major__ > 10))
#define PRAGMA_IMPRECISE_FUNCTION_BEGIN _Pragma("float_control(precise, off, push)")
#define PRAGMA_IMPRECISE_FUNCTION_END _Pragma("float_control(pop)")

// Other compilers defining __GNUC__: push the current options, then enable
// unroll-loops, associative-math and no-signed-zeros until END pops them again.
#elif defined(__GNUC__)
#define PRAGMA_IMPRECISE_FUNCTION_BEGIN \
_Pragma("GCC push_options") \
_Pragma("GCC optimize (\"unroll-loops,associative-math,no-signed-zeros\")")
#define PRAGMA_IMPRECISE_FUNCTION_END _Pragma("GCC pop_options")
#else
// Any other compiler: both macros expand to nothing.
#define PRAGMA_IMPRECISE_FUNCTION_BEGIN
#define PRAGMA_IMPRECISE_FUNCTION_END
#endif

PRAGMA_IMPRECISE_FUNCTION_BEGIN
/// Computes the L1 (Manhattan) distance between two float vectors.
///
/// @param x  first vector; must provide at least `d` readable elements
/// @param y  second vector; must provide at least `d` readable elements
/// @param d  number of elements to compare
/// @return   sum of |x[i] - y[i]| for i in [0, d); 0 when `d` is 0
float L1Distance::distance(const float* x, const float* y, size_t d) {
    float res = 0;
    for (size_t i = 0; i < d; i++) {
        // std::fabs picks the float overload. The unqualified name may bind to the
        // C `double fabs(double)` and promote every element to double and back.
        res += std::fabs(x[i] - y[i]);
    }
    return res;
}

/// Computes the L2 (Euclidean) distance between two float vectors.
///
/// @param x  first vector; must provide at least `d` readable elements
/// @param y  second vector; must provide at least `d` readable elements
/// @param d  number of elements to compare
/// @return   square root of the sum of squared element differences
float L2Distance::distance(const float* x, const float* y, size_t d) {
    float sum_of_squares = 0;
    for (size_t idx = 0; idx < d; ++idx) {
        const float diff = x[idx] - y[idx];
        sum_of_squares += diff * diff;
    }
    return std::sqrt(sum_of_squares);
}

/// Computes the cosine distance (1 - cosine similarity) between two float vectors.
///
/// @param x  first vector; must provide at least `d` readable elements
/// @param y  second vector; must provide at least `d` readable elements
/// @param d  number of elements to compare
/// @return   1 - dot(x, y) / (|x| * |y|); 2 when either vector has a zero squared norm
///           (which includes the case `d == 0`)
float CosineDistance::distance(const float* x, const float* y, size_t d) {
    float dot_prod = 0;
    float squared_x = 0;
    float squared_y = 0;
    for (size_t i = 0; i < d; ++i) {
        dot_prod += x[i] * y[i];
        squared_x += x[i] * x[i];
        squared_y += y[i] * y[i];
    }
    // division by zero check
    if (squared_x == 0 || squared_y == 0) [[unlikely]] {
        return 2.F;
    }
    // Take the root of each squared norm separately. The float product
    // squared_x * squared_y can overflow to inf or underflow to 0 even when both
    // norms are representable and non-zero, which would bypass the guard above.
    // std::sqrt also guarantees the float overload is used.
    return 1 - dot_prod / (std::sqrt(squared_x) * std::sqrt(squared_y));
}

/// Computes the inner (dot) product of two float vectors.
///
/// @param x  first vector; must provide at least `d` readable elements
/// @param y  second vector; must provide at least `d` readable elements
/// @param d  number of elements to multiply
/// @return   sum of x[i] * y[i] for i in [0, d); 0 when `d` is 0
float InnerProduct::distance(const float* x, const float* y, size_t d) {
    float acc = 0.F;
    for (size_t k = 0; k < d; ++k) {
        acc += x[k] * y[k];
    }
    return acc;
}
PRAGMA_IMPRECISE_FUNCTION_END

/// Registers the array distance functions (L1, L2, cosine, inner product) with the
/// given function factory.
///
/// @param factory  factory that receives one registration per distance implementation
void register_function_array_distance(SimpleFunctionFactory& factory) {
    // Each implementation is registered exactly once; registering the same
    // FunctionArrayDistance instantiation a second time is redundant.
    factory.register_function<FunctionArrayDistance<L1Distance>>();
    factory.register_function<FunctionArrayDistance<L2Distance>>();
    factory.register_function<FunctionArrayDistance<CosineDistance>>();
    factory.register_function<FunctionArrayDistance<InnerProduct>>();
}

} // namespace doris::vectorized
88 changes: 19 additions & 69 deletions be/src/vec/functions/array/function_array_distance.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

#pragma once

#include <gen_cpp/Types_types.h>

#include "vec/columns/column.h"
#include "vec/columns/column_array.h"
#include "vec/columns/columns_number.h"
Expand All @@ -35,63 +37,42 @@ namespace doris::vectorized {
// Distance implementation published under the name "l1_distance".
class L1Distance {
public:
// Function name carried by this implementation.
static constexpr auto name = "l1_distance";
// Accumulator for the element-wise interface: a running sum kept in double precision.
struct State {
double sum = 0;
};
// Adds |x - y| for one pair of elements to the running sum.
static void accumulate(State& state, double x, double y) { state.sum += fabs(x - y); }
// Returns the accumulated sum unchanged.
static double finalize(const State& state) { return state.sum; }
// Bulk interface over two float buffers and an element count `d`; only declared here.
static float distance(const float* x, const float* y, size_t d);
};

// Distance implementation published under the name "l2_distance".
class L2Distance {
public:
// Function name carried by this implementation.
static constexpr auto name = "l2_distance";
// Accumulator for the element-wise interface: a running sum kept in double precision.
struct State {
double sum = 0;
};
// Adds the squared difference (x - y)^2 for one pair of elements to the running sum.
static void accumulate(State& state, double x, double y) { state.sum += (x - y) * (x - y); }
// Returns the square root of the accumulated sum of squared differences.
static double finalize(const State& state) { return sqrt(state.sum); }
// Bulk interface over two float buffers and an element count `d`; only declared here.
static float distance(const float* x, const float* y, size_t d);
};

// Implementation published under the name "inner_product".
class InnerProduct {
public:
// Function name carried by this implementation.
static constexpr auto name = "inner_product";
// Accumulator for the element-wise interface: a running sum kept in double precision.
struct State {
double sum = 0;
};
// Adds the product x * y for one pair of elements to the running sum.
static void accumulate(State& state, double x, double y) { state.sum += x * y; }
// Returns the accumulated sum of products unchanged.
static double finalize(const State& state) { return state.sum; }
// Bulk interface over two float buffers and an element count `d`; only declared here.
static float distance(const float* x, const float* y, size_t d);
};

// Distance implementation published under the name "cosine_distance".
class CosineDistance {
public:
// Function name carried by this implementation.
static constexpr auto name = "cosine_distance";
// Accumulator for the element-wise interface: the dot product and both squared
// norms, all kept in double precision.
struct State {
double dot_prod = 0;
double squared_x = 0;
double squared_y = 0;
};
// Folds one pair of elements into the dot product and the two squared norms.
static void accumulate(State& state, double x, double y) {
state.dot_prod += x * y;
state.squared_x += x * x;
state.squared_y += y * y;
}
// Returns 1 - dot_prod / sqrt(squared_x * squared_y). There is no zero check here,
// so a zero squared norm makes the divisor zero.
static double finalize(const State& state) {
return 1 - state.dot_prod / sqrt(state.squared_x * state.squared_y);
}

// Bulk interface over two float buffers and an element count `d`; only declared here.
static float distance(const float* x, const float* y, size_t d);
};

template <typename DistanceImpl>
class FunctionArrayDistance : public IFunction {
public:
using ColumnType = ColumnFloat32;

static constexpr auto name = DistanceImpl::name;
String get_name() const override { return name; }
static FunctionPtr create() { return std::make_shared<FunctionArrayDistance<DistanceImpl>>(); }
bool is_variadic() const override { return false; }
size_t get_number_of_arguments() const override { return 2; }
bool use_default_implementation_for_nulls() const override { return false; }
bool use_default_implementation_for_nulls() const override { return true; }

DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
return make_nullable(std::make_shared<DataTypeFloat64>());
return std::make_shared<DataTypeFloat32>();
}

Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
Expand Down Expand Up @@ -121,27 +102,14 @@ class FunctionArrayDistance : public IFunction {
}

// prepare return data
auto dst = ColumnFloat64::create(input_rows_count);
auto dst = ColumnType::create(input_rows_count);
auto& dst_data = dst->get_data();
auto dst_null_column = ColumnUInt8::create(input_rows_count, 0);
auto& dst_null_data = dst_null_column->get_data();

const auto& offsets1 = *arr1.offsets_ptr;
const auto& offsets2 = *arr2.offsets_ptr;
const auto& nested_col1 = assert_cast<const ColumnFloat64*>(arr1.nested_col.get());
const auto& nested_col2 = assert_cast<const ColumnFloat64*>(arr2.nested_col.get());
const auto& nested_col1 = assert_cast<const ColumnType*>(arr1.nested_col.get());
const auto& nested_col2 = assert_cast<const ColumnType*>(arr2.nested_col.get());
for (ssize_t row = 0; row < offsets1.size(); ++row) {
if (arr1.array_nullmap_data && arr1.array_nullmap_data[row]) {
dst_null_data[row] = true;
continue;
}
if (arr2.array_nullmap_data && arr2.array_nullmap_data[row]) {
dst_null_data[row] = true;
continue;
}

dst_null_data[row] = false;

// Calculate actual array sizes for current row.
// For nullable arrays, we cannot compare absolute offset values directly because:
// 1. When a row is null, its offset might equal the previous offset (no elements added)
Expand All @@ -156,29 +124,11 @@ class FunctionArrayDistance : public IFunction {
get_name(), size1, size2);
}

typename DistanceImpl::State st;
for (ssize_t pos = offsets1[row - 1]; pos < offsets1[row]; ++pos) {
// Calculate corresponding position in the second array
ssize_t pos2 = offsets2[row - 1] + (pos - offsets1[row - 1]);
if (arr1.nested_nullmap_data && arr1.nested_nullmap_data[pos]) {
dst_null_data[row] = true;
break;
}
if (arr2.nested_nullmap_data && arr2.nested_nullmap_data[pos2]) {
dst_null_data[row] = true;
break;
}
DistanceImpl::accumulate(st, nested_col1->get_element(pos),
nested_col2->get_element(pos2));
}
if (!dst_null_data[row]) {
dst_data[row] = DistanceImpl::finalize(st);
dst_null_data[row] = std::isnan(dst_data[row]);
}
dst_data[row] = DistanceImpl::distance(
nested_col1->get_data().data() + offsets1[row - 1],
nested_col2->get_data().data() + offsets1[row - 1], size1);
}

block.replace_by_position(
result, ColumnNullable::create(std::move(dst), std::move(dst_null_column)));
block.replace_by_position(result, std::move(dst));
return Status::OK();
}

Expand All @@ -190,7 +140,7 @@ class FunctionArrayDistance : public IFunction {
}
auto nested_type =
remove_nullable(assert_cast<const DataTypeArray&>(*array_type).get_nested_type());
return WhichDataType(nested_type).is_float64();
return WhichDataType(nested_type).is_float32();
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
Expand All @@ -35,11 +35,11 @@
* cosine_distance function
*/
public class CosineDistance extends ScalarFunction implements ExplicitlyCastableSignature,
BinaryExpression, AlwaysNullable {
BinaryExpression, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE)
.args(ArrayType.of(DoubleType.INSTANCE), ArrayType.of(DoubleType.INSTANCE))
FunctionSignature.ret(FloatType.INSTANCE)
.args(ArrayType.of(FloatType.INSTANCE), ArrayType.of(FloatType.INSTANCE))
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
Expand All @@ -35,11 +35,11 @@
* inner_product function
*/
public class InnerProduct extends ScalarFunction implements ExplicitlyCastableSignature,
BinaryExpression, AlwaysNullable {
BinaryExpression, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE)
.args(ArrayType.of(DoubleType.INSTANCE), ArrayType.of(DoubleType.INSTANCE))
FunctionSignature.ret(FloatType.INSTANCE)
.args(ArrayType.of(FloatType.INSTANCE), ArrayType.of(FloatType.INSTANCE))
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
Expand All @@ -35,11 +35,11 @@
* l1_distance function
*/
public class L1Distance extends ScalarFunction implements ExplicitlyCastableSignature,
BinaryExpression, AlwaysNullable {
BinaryExpression, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE)
.args(ArrayType.of(DoubleType.INSTANCE), ArrayType.of(DoubleType.INSTANCE))
FunctionSignature.ret(FloatType.INSTANCE)
.args(ArrayType.of(FloatType.INSTANCE), ArrayType.of(FloatType.INSTANCE))
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
Expand All @@ -35,11 +35,11 @@
* l2_distance function
*/
public class L2Distance extends ScalarFunction implements ExplicitlyCastableSignature,
BinaryExpression, AlwaysNullable {
BinaryExpression, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE)
.args(ArrayType.of(DoubleType.INSTANCE), ArrayType.of(DoubleType.INSTANCE))
FunctionSignature.ret(FloatType.INSTANCE)
.args(ArrayType.of(FloatType.INSTANCE), ArrayType.of(FloatType.INSTANCE))
);

/**
Expand Down
Loading
Loading