Skip to content

Commit

Permalink
Implement mul for vecf16 & veci8.
Browse files Browse the repository at this point in the history
Signed-off-by: my-vegetable-has-exploded <wy1109468038@gmail.com>
  • Loading branch information
my-vegetable-has-exploded committed May 19, 2024
1 parent 29448ca commit 09699cb
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/datatype/operators_vecf16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ fn _vectors_vecf16_operator_minus(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) ->
Vecf16Output::new(Vecf16Borrowed::new(&v))
}

/// Calculate the element-wise multiplication of two f16 vectors.
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_vecf16_operator_mul(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> Vecf16Output {
let n = check_matched_dims(lhs.dims(), rhs.dims());
let mut v = vec![F16::zero(); n];
for i in 0..n {
v[i] = lhs[i] * rhs[i];
}
Vecf16Output::new(Vecf16Borrowed::new(&v))
}

#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_vecf16_operator_lt(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool {
check_matched_dims(lhs.dims(), rhs.dims());
Expand Down
14 changes: 14 additions & 0 deletions src/datatype/operators_veci8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ fn _vectors_veci8_operator_minus(lhs: Veci8Input<'_>, rhs: Veci8Input<'_>) -> Ve
)
}

/// Calculate the element-wise multiplication of two i8 vectors.
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_veci8_operator_mul(lhs: Veci8Input<'_>, rhs: Veci8Input<'_>) -> Veci8Output {
check_matched_dims(lhs.len(), rhs.len());
let data = (0..lhs.len())
.map(|i| lhs.index(i) * rhs.index(i))
.collect::<Vec<_>>();
let (vector, alpha, offset) = veci8::i8_quantization(&data);
let (sum, l2_norm) = veci8::i8_precompute(&vector, alpha, offset);
Veci8Output::new(
Veci8Borrowed::new_checked(lhs.len() as u32, &vector, alpha, offset, sum, l2_norm).unwrap(),
)
}

#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_veci8_operator_lt(lhs: Veci8Input<'_>, rhs: Veci8Input<'_>) -> bool {
check_matched_dims(lhs.len(), rhs.len());
Expand Down
14 changes: 14 additions & 0 deletions src/sql/finalize.sql
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,27 @@ CREATE OPERATOR * (
COMMUTATOR = *
);

CREATE OPERATOR * (
PROCEDURE = _vectors_vecf16_operator_mul,
LEFTARG = vecf16,
RIGHTARG = vecf16,
COMMUTATOR = *
);

CREATE OPERATOR * (
PROCEDURE = _vectors_svecf32_operator_mul,
LEFTARG = svector,
RIGHTARG = svector,
COMMUTATOR = *
);

CREATE OPERATOR * (
PROCEDURE = _vectors_veci8_operator_mul,
LEFTARG = veci8,
RIGHTARG = veci8,
COMMUTATOR = *
);

CREATE OPERATOR & (
PROCEDURE = _vectors_bvecf32_operator_and,
LEFTARG = bvector,
Expand Down
5 changes: 5 additions & 0 deletions tests/sqllogictest/fp16.slt
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,10 @@ SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <=> '[0.5,0.5,0.5]'::vecf16 l
----
10

query I
SELECT '[1,2,3]'::vecf16 * '[4,5,6]'::vecf16;
----
[4, 10, 18]

statement ok
DROP TABLE t;
5 changes: 5 additions & 0 deletions tests/sqllogictest/int8.slt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ SELECT to_veci8(5, 1, 0, '{0,1,2,0,0}');
----
[0, 1, 2, 0, 0]

query I
SELECT '[2,2,2]'::veci8 * '[2,2,2]'::veci8;
----
[4, 4, 4]

statement error Lengths of values and len are not matched.
SELECT to_veci8(5, 1, 0, '{0,1,2,0,0,0}');

Expand Down

0 comments on commit 09699cb

Please sign in to comment.