Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support adding vector index to the Database Model #67

Merged
merged 6 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,11 @@ jobs:

name: vector-py${{ matrix.python-version }}_django${{ matrix.django-version }}
runs-on: ubuntu-latest
env:
TIDB_HOST: ${{ secrets.SERVRLESS_TEST_TIDB_HOST }}
TIDB_USER: ${{ secrets.SERVRLESS_TEST_TIDB_USER }}
TIDB_PASSWORD: ${{ secrets.SERVRLESS_TEST_TIDB_PASSWORD }}
services:
tidb:
image: wangdi4zm/tind:v8.4.0-vector-index
ports:
- 4000:4000
steps:
- name: Checkout
uses: actions/checkout@v3
Expand Down
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,11 @@ You can also add an hnsw index when creating the table, for more information, pl

```python
class Test(models.Model):
# Note:
# - Using comment to add hnsw index is a temporary solution. In the future it will use `CREATE INDEX` syntax.
# - Currently the hnsw index cannot be changed after the table has been created.
# - Only Django >= 4.2 supports `db_comment`.
embedding = VectorField(dimensions=3, db_comment="hnsw(distance=l2)")
embedding = VectorField(dimensions=3)
class Meta:
indexes = [
VectorIndex(L2Distance("embedding"), name='idx_l2'),
]
```

#### Create a record
Expand Down
2 changes: 1 addition & 1 deletion django_tidb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from .patch import monkey_patch

__version__ = "5.0.0"
__version__ = "5.0.1"


monkey_patch()
96 changes: 91 additions & 5 deletions django_tidb/fields/vector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
from django.core import checks
from django import forms
from django.db.models import Field, FloatField, Func, Value
from django.db import models
from django.db.models import Field, FloatField, Func, Value, Index

MAX_DIM_LENGTH = 16000
MIN_DIM_LENGTH = 1
Expand Down Expand Up @@ -132,13 +133,98 @@ def _check_dimensions(self):
return []


class VectorIndex(Index):
"""
Example:
```python
from django.db import models
from django_tidb.fields.vector import VectorField, VectorIndex, CosineDistance

class Document(models.Model):
content = models.TextField()
embedding = VectorField(dimensions=3)
class Meta:
indexes = [
VectorIndex(CosineDistance("embedding"), name='idx_cos'),
]

# Create a document
Document.objects.create(
content="test content",
embedding=[1, 2, 3],
)

# Query with distance
Document.objects.alias(
distance=CosineDistance('embedding', [3, 1, 2])
).filter(distance__lt=5)
```

Note:
Creating a vector index will automatically set the "TiFlash replica" to 1 in TiDB.
If you want to use high-availability columnar storage feature, use raw SQL instead.

"""

def __init__(
self,
*expressions,
name,
) -> None:
super().__init__(*expressions, fields=(), name=name)

def create_sql(self, model, schema_editor, using="", **kwargs):
include = [
model._meta.get_field(field_name).column for field_name in self.include
]
index_expressions = []
for expression in self.expressions:
index_expression = models.indexes.IndexExpression(expression)
index_expression.set_wrapper_classes(schema_editor.connection)
index_expressions.append(index_expression)
expressions = models.indexes.ExpressionList(
*index_expressions
).resolve_expression(
models.sql.query.Query(model, alias_cols=False),
)
fields = None
col_suffixes = None
# TODO: remove the tiflash replica setting statement from sql_template
# after we support `ADD_TIFLASH_ON_DEMAND` in the `CREATE VECTOR INDEX ...`
sql_template = """ALTER TABLE %(table)s SET TIFLASH REPLICA 1;
CREATE VECTOR INDEX %(name)s ON %(table)s%(using)s (%(columns)s)%(extra)s"""
return schema_editor._create_index_sql(
model,
fields=fields,
name=self.name,
using=using,
db_tablespace=self.db_tablespace,
col_suffixes=col_suffixes,
sql=sql_template,
opclasses=self.opclasses,
condition=None,
include=include,
expressions=expressions,
**kwargs,
)


class DistanceBase(Func):
output_field = FloatField()

def __init__(self, expression, vector, **extra):
if not hasattr(vector, "resolve_expression"):
vector = Value(encode_vector(vector))
super().__init__(expression, vector, **extra)
def __init__(self, expression, vector=None, **extra):
"""
expression: the name of a field, or an expression returing a vector
vector: a vector to compare against
"""
expressions = [expression]
# When using the distance function as expression in the vector index
# statement, the `vector` is None
if vector is not None:
if not hasattr(vector, "resolve_expression"):
vector = Value(encode_vector(vector))
expressions.append(vector)
super().__init__(*expressions, **extra)


class L1Distance(DistanceBase):
Expand Down
18 changes: 17 additions & 1 deletion tests/tidb_vector/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from django.db import models

from django_tidb.fields.vector import VectorField
from django_tidb.fields.vector import (
VectorField,
VectorIndex,
CosineDistance,
L2Distance,
)


class Document(models.Model):
Expand All @@ -11,3 +16,14 @@ class Document(models.Model):
class DocumentExplicitDimension(models.Model):
content = models.TextField()
embedding = VectorField(dimensions=3)


class DocumentWithAnnIndex(models.Model):
content = models.TextField()
embedding = VectorField(dimensions=3)

class Meta:
indexes = [
VectorIndex(CosineDistance("embedding"), name="idx_cos"),
VectorIndex(L2Distance("embedding"), name="idx_l2"),
]
6 changes: 5 additions & 1 deletion tests/tidb_vector/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
NegativeInnerProduct,
)

from .models import Document, DocumentExplicitDimension
from .models import Document, DocumentExplicitDimension, DocumentWithAnnIndex


class TiDBVectorFieldTests(TestCase):
Expand Down Expand Up @@ -76,3 +76,7 @@ def test_negative_inner_product(self):

class TiDBVectorFieldExplicitDimensionTests(TiDBVectorFieldTests):
model = DocumentExplicitDimension


class TiDBVectorFieldWithAnnIndexTests(TiDBVectorFieldTests):
model = DocumentWithAnnIndex
Loading