Skip to content

Commit

Permalink
Support adding vector index to the Database Model (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
JaySon-Huang authored and wd0517 committed Nov 14, 2024
1 parent f36f302 commit b77ea88
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 11 deletions.
9 changes: 5 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,11 @@ jobs:

name: vector-py${{ matrix.python-version }}_django${{ matrix.django-version }}
runs-on: ubuntu-latest
env:
TIDB_HOST: ${{ secrets.SERVRLESS_TEST_TIDB_HOST }}
TIDB_USER: ${{ secrets.SERVRLESS_TEST_TIDB_USER }}
TIDB_PASSWORD: ${{ secrets.SERVRLESS_TEST_TIDB_PASSWORD }}
services:
tidb:
image: wangdi4zm/tind:v8.4.0-vector-index
ports:
- 4000:4000
steps:
- name: Checkout
uses: actions/checkout@v3
Expand Down
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,17 @@ class Test(models.Model):
embedding = VectorField(dimensions=3)
```

You can also add an hnsw index when creating the table, for more information, please refer to the [documentation](https://docs.google.com/document/d/15eAO0xrvEd6_tTxW_zEko4CECwnnSwQg8GGrqK1Caiw).

```python
class Test(models.Model):
embedding = VectorField(dimensions=3)
class Meta:
indexes = [
VectorIndex(L2Distance("embedding"), name='idx_l2'),
]
```

#### Create a record

```python
Expand Down
96 changes: 91 additions & 5 deletions django_tidb/fields/vector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
from django.core import checks
from django import forms
from django.db.models import Field, FloatField, Func, Value
from django.db import models
from django.db.models import Field, FloatField, Func, Value, Index

MAX_DIM_LENGTH = 16000
MIN_DIM_LENGTH = 1
Expand Down Expand Up @@ -132,13 +133,98 @@ def _check_dimensions(self):
return []


class VectorIndex(Index):
"""
Example:
```python
from django.db import models
from django_tidb.fields.vector import VectorField, VectorIndex, CosineDistance
class Document(models.Model):
content = models.TextField()
embedding = VectorField(dimensions=3)
class Meta:
indexes = [
VectorIndex(CosineDistance("embedding"), name='idx_cos'),
]
# Create a document
Document.objects.create(
content="test content",
embedding=[1, 2, 3],
)
# Query with distance
Document.objects.alias(
distance=CosineDistance('embedding', [3, 1, 2])
).filter(distance__lt=5)
```
Note:
Creating a vector index will automatically set the "TiFlash replica" to 1 in TiDB.
If you want to use high-availability columnar storage feature, use raw SQL instead.
"""

def __init__(
self,
*expressions,
name,
) -> None:
super().__init__(*expressions, fields=(), name=name)

def create_sql(self, model, schema_editor, using="", **kwargs):
include = [
model._meta.get_field(field_name).column for field_name in self.include
]
index_expressions = []
for expression in self.expressions:
index_expression = models.indexes.IndexExpression(expression)
index_expression.set_wrapper_classes(schema_editor.connection)
index_expressions.append(index_expression)
expressions = models.indexes.ExpressionList(
*index_expressions
).resolve_expression(
models.sql.query.Query(model, alias_cols=False),
)
fields = None
col_suffixes = None
# TODO: remove the tiflash replica setting statement from sql_template
# after we support `ADD_TIFLASH_ON_DEMAND` in the `CREATE VECTOR INDEX ...`
sql_template = """ALTER TABLE %(table)s SET TIFLASH REPLICA 1;
CREATE VECTOR INDEX %(name)s ON %(table)s%(using)s (%(columns)s)%(extra)s"""
return schema_editor._create_index_sql(
model,
fields=fields,
name=self.name,
using=using,
db_tablespace=self.db_tablespace,
col_suffixes=col_suffixes,
sql=sql_template,
opclasses=self.opclasses,
condition=None,
include=include,
expressions=expressions,
**kwargs,
)


class DistanceBase(Func):
output_field = FloatField()

def __init__(self, expression, vector, **extra):
if not hasattr(vector, "resolve_expression"):
vector = Value(encode_vector(vector))
super().__init__(expression, vector, **extra)
def __init__(self, expression, vector=None, **extra):
"""
expression: the name of a field, or an expression returing a vector
vector: a vector to compare against
"""
expressions = [expression]
# When using the distance function as expression in the vector index
# statement, the `vector` is None
if vector is not None:
if not hasattr(vector, "resolve_expression"):
vector = Value(encode_vector(vector))
expressions.append(vector)
super().__init__(*expressions, **extra)


class L1Distance(DistanceBase):
Expand Down
18 changes: 17 additions & 1 deletion tests/tidb_vector/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from django.db import models

from django_tidb.fields.vector import VectorField
from django_tidb.fields.vector import (
VectorField,
VectorIndex,
CosineDistance,
L2Distance,
)


class Document(models.Model):
Expand All @@ -11,3 +16,14 @@ class Document(models.Model):
class DocumentExplicitDimension(models.Model):
content = models.TextField()
embedding = VectorField(dimensions=3)


class DocumentWithAnnIndex(models.Model):
content = models.TextField()
embedding = VectorField(dimensions=3)

class Meta:
indexes = [
VectorIndex(CosineDistance("embedding"), name="idx_cos"),
VectorIndex(L2Distance("embedding"), name="idx_l2"),
]
6 changes: 5 additions & 1 deletion tests/tidb_vector/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
NegativeInnerProduct,
)

from .models import Document, DocumentExplicitDimension
from .models import Document, DocumentExplicitDimension, DocumentWithAnnIndex


class TiDBVectorFieldTests(TestCase):
Expand Down Expand Up @@ -76,3 +76,7 @@ def test_negative_inner_product(self):

class TiDBVectorFieldExplicitDimensionTests(TiDBVectorFieldTests):
model = DocumentExplicitDimension


class TiDBVectorFieldWithAnnIndexTests(TiDBVectorFieldTests):
model = DocumentWithAnnIndex

0 comments on commit b77ea88

Please sign in to comment.