diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7b4ea9b..ce58bcb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -91,10 +91,11 @@ jobs: name: vector-py${{ matrix.python-version }}_django${{ matrix.django-version }} runs-on: ubuntu-latest - env: - TIDB_HOST: ${{ secrets.SERVRLESS_TEST_TIDB_HOST }} - TIDB_USER: ${{ secrets.SERVRLESS_TEST_TIDB_USER }} - TIDB_PASSWORD: ${{ secrets.SERVRLESS_TEST_TIDB_PASSWORD }} + services: + tidb: + image: wangdi4zm/tind:v8.4.0-vector-index + ports: + - 4000:4000 steps: - name: Checkout uses: actions/checkout@v3 diff --git a/README.md b/README.md index 64100cc..1379c55 100644 --- a/README.md +++ b/README.md @@ -164,11 +164,11 @@ You can also add an hnsw index when creating the table, for more information, pl ```python class Test(models.Model): - # Note: - # - Using comment to add hnsw index is a temporary solution. In the future it will use `CREATE INDEX` syntax. - # - Currently the hnsw index cannot be changed after the table has been created. - # - Only Django >= 4.2 supports `db_comment`. - embedding = VectorField(dimensions=3, db_comment="hnsw(distance=l2)") + embedding = VectorField(dimensions=3) + class Meta: + indexes = [ + VectorIndex(L2Distance("embedding"), name='idx_l2'), + ] ``` #### Create a record diff --git a/django_tidb/__init__.py b/django_tidb/__init__.py index 35bb02e..853f7e7 100644 --- a/django_tidb/__init__.py +++ b/django_tidb/__init__.py @@ -16,7 +16,7 @@ from .patch import monkey_patch -__version__ = "5.0.0" +__version__ = "5.0.1" monkey_patch() diff --git a/django_tidb/fields/vector.py b/django_tidb/fields/vector.py index 1448406..260b2e3 100644 --- a/django_tidb/fields/vector.py +++ b/django_tidb/fields/vector.py @@ -1,7 +1,8 @@ import numpy as np from django.core import checks from django import forms -from django.db.models import Field, FloatField, Func, Value +from django.db import models +from django.db.models import Field, FloatField, Func, Value, Index MAX_DIM_LENGTH = 16000 MIN_DIM_LENGTH = 1 @@ -132,13 +133,98 @@ def _check_dimensions(self): return [] +class VectorIndex(Index): + """ + Example: + ```python + from django.db import models + from django_tidb.fields.vector import VectorField, VectorIndex, CosineDistance + + class Document(models.Model): + content = models.TextField() + embedding = VectorField(dimensions=3) + class Meta: + indexes = [ + VectorIndex(CosineDistance("embedding"), name='idx_cos'), + ] + + # Create a document + Document.objects.create( + content="test content", + embedding=[1, 2, 3], + ) + + # Query with distance + Document.objects.alias( + distance=CosineDistance('embedding', [3, 1, 2]) + ).filter(distance__lt=5) + ``` + + Note: + Creating a vector index will automatically set the "TiFlash replica" to 1 in TiDB. + If you want to use high-availability columnar storage feature, use raw SQL instead. + + """ + + def __init__( + self, + *expressions, + name, + ) -> None: + super().__init__(*expressions, fields=(), name=name) + + def create_sql(self, model, schema_editor, using="", **kwargs): + include = [ + model._meta.get_field(field_name).column for field_name in self.include + ] + index_expressions = [] + for expression in self.expressions: + index_expression = models.indexes.IndexExpression(expression) + index_expression.set_wrapper_classes(schema_editor.connection) + index_expressions.append(index_expression) + expressions = models.indexes.ExpressionList( + *index_expressions + ).resolve_expression( + models.sql.query.Query(model, alias_cols=False), + ) + fields = None + col_suffixes = None + # TODO: remove the tiflash replica setting statement from sql_template + # after we support `ADD_TIFLASH_ON_DEMAND` in the `CREATE VECTOR INDEX ...` + sql_template = """ALTER TABLE %(table)s SET TIFLASH REPLICA 1; + CREATE VECTOR INDEX %(name)s ON %(table)s%(using)s (%(columns)s)%(extra)s""" + return schema_editor._create_index_sql( + model, + fields=fields, + name=self.name, + using=using, + db_tablespace=self.db_tablespace, + col_suffixes=col_suffixes, + sql=sql_template, + opclasses=self.opclasses, + condition=None, + include=include, + expressions=expressions, + **kwargs, + ) + + class DistanceBase(Func): output_field = FloatField() - def __init__(self, expression, vector, **extra): - if not hasattr(vector, "resolve_expression"): - vector = Value(encode_vector(vector)) - super().__init__(expression, vector, **extra) + def __init__(self, expression, vector=None, **extra): + """ + expression: the name of a field, or an expression returing a vector + vector: a vector to compare against + """ + expressions = [expression] + # When using the distance function as expression in the vector index + # statement, the `vector` is None + if vector is not None: + if not hasattr(vector, "resolve_expression"): + vector = Value(encode_vector(vector)) + expressions.append(vector) + super().__init__(*expressions, **extra) class L1Distance(DistanceBase): diff --git a/tests/tidb_vector/models.py b/tests/tidb_vector/models.py index 94d0b85..c34b87f 100644 --- a/tests/tidb_vector/models.py +++ b/tests/tidb_vector/models.py @@ -1,6 +1,11 @@ from django.db import models -from django_tidb.fields.vector import VectorField +from django_tidb.fields.vector import ( + VectorField, + VectorIndex, + CosineDistance, + L2Distance, +) class Document(models.Model): @@ -11,3 +16,14 @@ class Document(models.Model): class DocumentExplicitDimension(models.Model): content = models.TextField() embedding = VectorField(dimensions=3) + + +class DocumentWithAnnIndex(models.Model): + content = models.TextField() + embedding = VectorField(dimensions=3) + + class Meta: + indexes = [ + VectorIndex(CosineDistance("embedding"), name="idx_cos"), + VectorIndex(L2Distance("embedding"), name="idx_l2"), + ] diff --git a/tests/tidb_vector/test_vector.py b/tests/tidb_vector/test_vector.py index 7af70d3..f931de2 100644 --- a/tests/tidb_vector/test_vector.py +++ b/tests/tidb_vector/test_vector.py @@ -9,7 +9,7 @@ NegativeInnerProduct, ) -from .models import Document, DocumentExplicitDimension +from .models import Document, DocumentExplicitDimension, DocumentWithAnnIndex class TiDBVectorFieldTests(TestCase): @@ -76,3 +76,7 @@ def test_negative_inner_product(self): class TiDBVectorFieldExplicitDimensionTests(TiDBVectorFieldTests): model = DocumentExplicitDimension + + +class TiDBVectorFieldWithAnnIndexTests(TiDBVectorFieldTests): + model = DocumentWithAnnIndex