Skip to content

Commit

Permalink
fix: cosine distance (#231)
Browse files Browse the repository at this point in the history
* fix: cosine distance

Signed-off-by: Keming <kemingyang@tensorchord.ai>

* fix sq and pq

Signed-off-by: Keming <kemingyang@tensorchord.ai>

* fix doc

Signed-off-by: Keming <kemingyang@tensorchord.ai>

---------

Signed-off-by: Keming <kemingyang@tensorchord.ai>
  • Loading branch information
kemingy authored Jan 6, 2024
1 parent f30a10c commit 4aafd8c
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 15 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ We support three operators to calculate the distance between two vectors.

- `<->`: squared Euclidean distance, defined as $\Sigma (x_i - y_i) ^ 2$.
- `<#>`: negative dot product, defined as $- \Sigma x_iy_i$.
- `<=>`: negative cosine similarity, defined as $- \frac{\Sigma x_iy_i}{\sqrt{\Sigma x_i^2 \Sigma y_i^2}}$.
- `<=>`: cosine distance, defined as $1 - \frac{\Sigma x_iy_i}{\sqrt{\Sigma x_i^2 \Sigma y_i^2}}$.

```sql
-- call the distance function through operators
Expand All @@ -115,7 +115,7 @@ We support three operators to calculate the distance between two vectors.
SELECT '[1, 2, 3]'::vector <-> '[3, 2, 1]'::vector;
-- negative dot product
SELECT '[1, 2, 3]'::vector <#> '[3, 2, 1]'::vector;
-- negative cosine similarity
-- cosine distance
SELECT '[1, 2, 3]'::vector <=> '[3, 2, 1]'::vector;
```

Expand Down
2 changes: 1 addition & 1 deletion crates/service/src/instance/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub struct Metadata {
}

impl Metadata {
const VERSION: u64 = 2;
const VERSION: u64 = 3;
const SOFT_VERSION: u64 = 1;
}

Expand Down
12 changes: 6 additions & 6 deletions crates/service/src/prelude/global/f16_cos.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ impl G for F16Cos {
type L2 = F16L2;

fn distance(lhs: &[F16], rhs: &[F16]) -> F32 {
super::f16::cosine(lhs, rhs) * (-1.0)
F32(1.0) - super::f16::cosine(lhs, rhs)
}

fn elkan_k_means_normalize(vector: &mut [F16]) {
Expand Down Expand Up @@ -47,7 +47,7 @@ impl G for F16Cos {
x2 += _x * _x;
y2 += _y * _y;
}
xy / (x2 * y2).sqrt() * (-1.0)
F32(1.0) - xy / (x2 * y2).sqrt()
}

#[multiversion::multiversion(targets(
Expand All @@ -73,7 +73,7 @@ impl G for F16Cos {
x2 += _x * _x;
y2 += _y * _y;
}
xy / (x2 * y2).sqrt() * (-1.0)
F32(1.0) - xy / (x2 * y2).sqrt()
}

#[multiversion::multiversion(targets(
Expand Down Expand Up @@ -103,7 +103,7 @@ impl G for F16Cos {
x2 += _x2;
y2 += _y2;
}
xy / (x2 * y2).sqrt() * (-1.0)
F32(1.0) - xy / (x2 * y2).sqrt()
}

#[multiversion::multiversion(targets(
Expand Down Expand Up @@ -134,7 +134,7 @@ impl G for F16Cos {
x2 += _x2;
y2 += _y2;
}
xy / (x2 * y2).sqrt() * (-1.0)
F32(1.0) - xy / (x2 * y2).sqrt()
}

#[multiversion::multiversion(targets(
Expand Down Expand Up @@ -166,7 +166,7 @@ impl G for F16Cos {
x2 += _x2;
y2 += _y2;
}
xy / (x2 * y2).sqrt() * (-1.0)
F32(1.0) - xy / (x2 * y2).sqrt()
}
}

Expand Down
12 changes: 6 additions & 6 deletions crates/service/src/prelude/global/f32_cos.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ impl G for F32Cos {
type L2 = F32L2;

fn distance(lhs: &[F32], rhs: &[F32]) -> F32 {
cosine(lhs, rhs) * (-1.0)
F32(1.0) - cosine(lhs, rhs)
}

fn elkan_k_means_normalize(vector: &mut [F32]) {
Expand Down Expand Up @@ -47,7 +47,7 @@ impl G for F32Cos {
x2 += _x * _x;
y2 += _y * _y;
}
xy / (x2 * y2).sqrt() * (-1.0)
F32(1.0) - xy / (x2 * y2).sqrt()
}

#[multiversion::multiversion(targets(
Expand All @@ -73,7 +73,7 @@ impl G for F32Cos {
x2 += _x * _x;
y2 += _y * _y;
}
xy / (x2 * y2).sqrt() * (-1.0)
F32(1.0) - xy / (x2 * y2).sqrt()
}

#[multiversion::multiversion(targets(
Expand Down Expand Up @@ -103,7 +103,7 @@ impl G for F32Cos {
x2 += _x2;
y2 += _y2;
}
xy / (x2 * y2).sqrt() * (-1.0)
F32(1.0) - xy / (x2 * y2).sqrt()
}

#[multiversion::multiversion(targets(
Expand Down Expand Up @@ -134,7 +134,7 @@ impl G for F32Cos {
x2 += _x2;
y2 += _y2;
}
xy / (x2 * y2).sqrt() * (-1.0)
F32(1.0) - xy / (x2 * y2).sqrt()
}

#[multiversion::multiversion(targets(
Expand Down Expand Up @@ -166,7 +166,7 @@ impl G for F32Cos {
x2 += _x2;
y2 += _y2;
}
xy / (x2 * y2).sqrt() * (-1.0)
F32(1.0) - xy / (x2 * y2).sqrt()
}
}

Expand Down

0 comments on commit 4aafd8c

Please sign in to comment.