diff --git a/.github/actions/python/action.yaml b/.github/actions/python/action.yaml index 592475ba108..86fa97f1c6f 100644 --- a/.github/actions/python/action.yaml +++ b/.github/actions/python/action.yaml @@ -4,19 +4,59 @@ inputs: python-version: description: "Python version to use" required: false - default: "3.8" + default: "3.9" runs: using: "composite" steps: + - name: Set up Python 3.9 for protos + uses: actions/setup-python@v5 + with: + python-version: "3.9" + cache: "pip" + cache-dependency-path: "requirements*.txt" + - name: Install proto dependencies + run: | + python -m pip install grpcio==1.58.0 grpcio-tools==1.58.0 + shell: bash + - name: Generate Proto Files + if: runner.os != 'Windows' + run: make -C idl proto_python + shell: bash + - name: Generate Proto Files (Windows) + if: runner.os == 'Windows' + run: cd idl && make proto_python + shell: cmd + - name: Uninstall proto dependencies + run: | + python -m pip uninstall -y grpcio grpcio-tools + shell: bash - name: Set up Python ${{ inputs.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ inputs.python-version }} cache: "pip" cache-dependency-path: "requirements*.txt" - - name: Install test dependencies - run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt + - name: Install dependencies + run: | + python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt shell: bash + - name: Install protobuf compiler (protoc) - Linux + if: runner.os != 'Windows' + run: | + sudo apt-get update + sudo apt-get install -y wget unzip + wget https://github.com/protocolbuffers/protobuf/releases/download/v28.2/protoc-28.2-linux-x86_64.zip + sudo unzip protoc-28.2-linux-x86_64.zip -d /usr/local/ + sudo rm protoc-28.2-linux-x86_64.zip + shell: bash + - name: Install protobuf compiler (protoc) - Windows + if: runner.os == 'Windows' + run: | + Invoke-WebRequest -Uri https://github.com/protocolbuffers/protobuf/releases/download/v28.2/protoc-28.2-win64.zip -OutFile protoc.zip + Expand-Archive -Path protoc.zip -DestinationPath C:\protoc + echo "C:\protoc\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + Remove-Item protoc.zip + shell: pwsh - name: Upgrade SQLite run: python bin/windows_upgrade_sqlite.py shell: bash diff --git a/.github/workflows/_python-tests.yml b/.github/workflows/_python-tests.yml index 1d8b4f94c84..60f015e8743 100644 --- a/.github/workflows/_python-tests.yml +++ b/.github/workflows/_python-tests.yml @@ -6,7 +6,7 @@ on: python_versions: description: 'Python versions to test (as json array)' required: false - default: '["3.8"]' + default: '["3.9"]' type: string property_testing_preset: description: 'Property testing preset' @@ -31,6 +31,7 @@ jobs: "chromadb/test/property/test_embeddings.py", "chromadb/test/property/test_filtering.py", "chromadb/test/property/test_persist.py", + "chromadb/test/property/test_sysdb.py", "chromadb/test/property/test_restart_persist.py"] include: - test-globs: "chromadb/test/property/test_embeddings.py" @@ -61,11 +62,12 @@ jobs: "chromadb/test/test_cli.py", "chromadb/test/auth/test_simple_rbac_authz.py", "chromadb/test/property/test_collections.py", - "chromadb/test/property/test_collections_with_database_tenant.py", + "chromadb/test/property/test_collections_with_database_tenant.py", "chromadb/test/property/test_cross_version_persist.py", "chromadb/test/property/test_embeddings.py", "chromadb/test/property/test_filtering.py", - "chromadb/test/property/test_persist.py"] + 
"chromadb/test/property/test_persist.py", + "chromadb/test/property/test_sysdb.py"] include: - platform: depot-ubuntu-22.04 env-file: compose-env.linux @@ -92,12 +94,14 @@ jobs: platform: ["depot-ubuntu-22.04-16"] test-globs: ["chromadb/test/db/test_system.py", "chromadb/test/api/test_collection.py", + "chromadb/test/api/test_limit_offset.py", "chromadb/test/property/test_collections.py", "chromadb/test/property/test_add.py", "chromadb/test/property/test_filtering.py", "chromadb/test/property/test_embeddings.py", "chromadb/test/property/test_collections_with_database_tenant.py", "chromadb/test/property/test_collections_with_database_tenant_overwrite.py", + "chromadb/test/property/test_sysdb.py", "chromadb/test/ingest/test_producer_consumer.py", "chromadb/test/segment/distributed/test_memberlist_provider.py", "chromadb/test/test_logservice.py", diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index dade99b3462..93cdca3d939 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -114,7 +114,7 @@ jobs: uses: actions/checkout@v4 - uses: ./.github/actions/python with: - python-version: "3.12" + python-version: "3.11" - name: Setup Rust uses: ./.github/actions/rust - name: Run pre-commit diff --git a/.github/workflows/release-chromadb.yml b/.github/workflows/release-chromadb.yml index 52152ef187f..bce16af60d3 100644 --- a/.github/workflows/release-chromadb.yml +++ b/.github/workflows/release-chromadb.yml @@ -40,7 +40,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.8' + python-version: '3.9' - name: Install setuptools_scm run: python -m pip install setuptools_scm - name: Get Release Version @@ -50,7 +50,7 @@ jobs: python-tests: uses: ./.github/workflows/_python-tests.yml with: - python_versions: '["3.8", "3.9", "3.10", "3.11", "3.12"]' + python_versions: '["3.9", "3.10", "3.11", "3.12"]' property_testing_preset: 'normal' javascript-client-tests: diff --git a/Cargo.lock b/Cargo.lock index 9e862373306..4b26a0807a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1364,8 +1364,8 @@ dependencies = [ [[package]] name = "chromadb" -version = "1.0.2" -source = "git+https://github.com/rescrv/chromadb-rs?rev=e364e35c34c660d4e8e862436ea600ddc2f46a1e#e364e35c34c660d4e8e862436ea600ddc2f46a1e" +version = "1.1.0" +source = "git+https://github.com/rescrv/chromadb-rs?rev=e9a8fb2e8b8edf7acfb1accf10166720a6bbbd33#e9a8fb2e8b8edf7acfb1accf10166720a6bbbd33" dependencies = [ "anyhow", "async-trait", @@ -6792,6 +6792,7 @@ dependencies = [ "figment", "flatbuffers", "futures", + "indicatif", "k8s-openapi", "kube", "murmur3", diff --git a/Dockerfile b/Dockerfile index c3fc98a8083..2084b4d6a54 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,14 +1,26 @@ FROM python:3.11-slim-bookworm AS builder ARG REBUILD_HNSWLIB +ARG PROTOBUF_VERSION=28.2 RUN apt-get update --fix-missing && apt-get install -y --fix-missing \ build-essential \ gcc \ g++ \ cmake \ - autoconf && \ + autoconf \ + python3-dev \ + unzip \ + curl \ + make && \ rm -rf /var/lib/apt/lists/* && \ mkdir /install +# Install specific Protobuf compiler (v28.2) +RUN curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOBUF_VERSION}/protoc-${PROTOBUF_VERSION}-linux-x86_64.zip && \ + unzip protoc-${PROTOBUF_VERSION}-linux-x86_64.zip -d /usr/local/ && \ + rm protoc-${PROTOBUF_VERSION}-linux-x86_64.zip && \ + chmod +x /usr/local/bin/protoc && \ + protoc --version # Verify installed version + WORKDIR /install COPY ./requirements.txt requirements.txt @@ -16,19 +28,32 @@ COPY 
./requirements.txt requirements.txt RUN --mount=type=cache,target=/root/.cache/pip pip install --upgrade --prefix="/install" -r requirements.txt RUN --mount=type=cache,target=/root/.cache/pip if [ "$REBUILD_HNSWLIB" = "true" ]; then pip install --no-binary :all: --force-reinstall --prefix="/install" chroma-hnswlib; fi +# Install gRPC tools for Python with fixed version +RUN pip install grpcio==1.58.0 grpcio-tools==1.58.0 + +# Copy source files to build Protobufs +COPY ./ /chroma + +# Generate Protobufs +WORKDIR /chroma +RUN make -C idl proto_python + FROM python:3.11-slim-bookworm AS final +# Create working directory RUN mkdir /chroma WORKDIR /chroma +# Copy entrypoint COPY ./bin/docker_entrypoint.sh /docker_entrypoint.sh RUN apt-get update --fix-missing && apt-get install -y curl && \ chmod +x /docker_entrypoint.sh && \ rm -rf /var/lib/apt/lists/* +# Copy built dependencies and generated Protobufs COPY --from=builder /install /usr/local -COPY ./ /chroma +COPY --from=builder /chroma /chroma ENV CHROMA_HOST_ADDR="0.0.0.0" ENV CHROMA_HOST_PORT=8000 diff --git a/Tiltfile b/Tiltfile index ddc48167b28..e92366580e9 100644 --- a/Tiltfile +++ b/Tiltfile @@ -145,4 +145,4 @@ k8s_resource('prometheus', resource_deps=['k8s_setup'], labels=["observability"] k8s_resource('otel-collector', resource_deps=['k8s_setup'], labels=["observability"]) # Local S3 -k8s_resource('minio-deployment', resource_deps=['k8s_setup'], labels=["debug"], port_forwards='9000:9000') +k8s_resource('minio-deployment', resource_deps=['k8s_setup'], labels=["debug"], port_forwards=['9000:9000', '9005:9005']) diff --git a/chromadb/__init__.py b/chromadb/__init__.py index c8e968c632a..ea079ba9b8e 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -45,7 +45,7 @@ __settings = Settings() -__version__ = "0.5.23" +__version__ = "0.6.0" # Workaround to deal with Colab's old sqlite3 version diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 045f16507f6..b25d6550182 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -580,7 +580,7 @@ def _get( } ) - coll = self._get_collection(collection_id) + scan = self._scan(collection_id) # TODO: Replace with unified validation if where is not None: @@ -619,7 +619,7 @@ def _get( return self._executor.get( GetPlan( - Scan(coll), + scan, Filter(ids, where, where_document), Limit(offset or 0, limit), Projection( @@ -676,7 +676,7 @@ def _delete( """ ) - coll = self._get_collection(collection_id) + scan = self._scan(collection_id) self._quota_enforcer.enforce( action=Action.DELETE, @@ -690,7 +690,7 @@ def _delete( if (where or where_document) or not ids: ids_to_delete = self._executor.get( - GetPlan(Scan(coll), Filter(ids, where, where_document)) + GetPlan(scan, Filter(ids, where, where_document)) )["ids"] else: ids_to_delete = ids @@ -701,7 +701,7 @@ def _delete( records_to_submit = list( _records(operation=t.Operation.DELETE, ids=ids_to_delete) ) - self._validate_embedding_record_set(coll, records_to_submit) + self._validate_embedding_record_set(scan.collection, records_to_submit) self._producer.submit_embeddings(collection_id, records_to_submit) self._product_telemetry_client.capture( @@ -726,8 +726,7 @@ def _count( database: str = DEFAULT_DATABASE, ) -> int: add_attributes_to_current_span({"collection_id": str(collection_id)}) - coll = self._get_collection(collection_id) - return self._executor.count(CountPlan(Scan(coll))) + return self._executor.count(CountPlan(self._scan(collection_id))) @trace_method("SegmentAPI._query", 
OpenTelemetryGranularity.OPERATION) # We retry on version mismatch errors because the version of the collection @@ -785,9 +784,9 @@ def _query( if where_document is not None: validate_where_document(where_document) - coll = self._get_collection(collection_id) + scan = self._scan(collection_id) for embedding in query_embeddings: - self._validate_dimension(coll, len(embedding), update=False) + self._validate_dimension(scan.collection, len(embedding), update=False) self._quota_enforcer.enforce( action=Action.QUERY, @@ -800,7 +799,7 @@ def _query( return self._executor.knn( KNNPlan( - Scan(coll), + scan, KNN(query_embeddings, n_results), Filter(None, where, where_document), Projection( @@ -893,6 +892,21 @@ def _get_collection(self, collection_id: UUID) -> t.Collection: ) return collections[0] + @trace_method("SegmentAPI._scan", OpenTelemetryGranularity.ALL) + def _scan(self, collection_id: UUID) -> Scan: + collection_and_segments = self._sysdb.get_collection_with_segments(collection_id) + # For now collection should have exactly one segment per scope: + # - Local scopes: vector, metadata + # - Distributed scopes: vector, metadata, record + scope_to_segment = {segment["scope"]: segment for segment in collection_and_segments["segments"]} + return Scan( + collection=collection_and_segments["collection"], + knn=scope_to_segment[t.SegmentScope.VECTOR], + metadata=scope_to_segment[t.SegmentScope.METADATA], + # Local chroma do not have record segment, and this is not used by the local executor + record=scope_to_segment.get(t.SegmentScope.RECORD, None), # type: ignore[arg-type] + ) + def _records( operation: t.Operation, diff --git a/chromadb/db/impl/grpc/client.py b/chromadb/db/impl/grpc/client.py index b663f873f8a..ea89ef24424 100644 --- a/chromadb/db/impl/grpc/client.py +++ b/chromadb/db/impl/grpc/client.py @@ -22,6 +22,8 @@ DeleteSegmentRequest, GetCollectionsRequest, GetCollectionsResponse, + GetCollectionWithSegmentsRequest, + GetCollectionWithSegmentsResponse, GetDatabaseRequest, GetSegmentsRequest, GetTenantRequest, @@ -33,6 +35,7 @@ from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor from chromadb.types import ( Collection, + CollectionAndSegments, Database, Metadata, OptionalArgument, @@ -363,6 +366,23 @@ def get_collections( ) raise InternalError() + @overrides + def get_collection_with_segments(self, collection_id: UUID) -> CollectionAndSegments: + try: + request = GetCollectionWithSegmentsRequest(id=collection_id.hex) + response: GetCollectionWithSegmentsResponse = self._sys_db_stub.GetCollectionWithSegments(request) + return CollectionAndSegments( + collection=from_proto_collection(response.collection), + segments=[from_proto_segment(segment) for segment in response.segments] + ) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.NOT_FOUND: + raise NotFoundError() + logger.error( + f"Failed to get collection {collection_id} and its segments due to error: {e}" + ) + raise InternalError() + @overrides def update_collection( self, diff --git a/chromadb/db/impl/grpc/server.py b/chromadb/db/impl/grpc/server.py index 8edbc0eeccb..b6948ee9615 100644 --- a/chromadb/db/impl/grpc/server.py +++ b/chromadb/db/impl/grpc/server.py @@ -28,6 +28,8 @@ DeleteSegmentResponse, GetCollectionsRequest, GetCollectionsResponse, + GetCollectionWithSegmentsRequest, + GetCollectionWithSegmentsResponse, GetDatabaseRequest, GetDatabaseResponse, GetSegmentsRequest, @@ -46,7 +48,7 @@ ) import grpc from google.protobuf.empty_pb2 import Empty -from chromadb.types import Collection, Metadata, 
Segment +from chromadb.types import Collection, Metadata, Segment, SegmentScope class GrpcMockSysDB(SysDBServicer, Component): @@ -370,6 +372,30 @@ def GetCollections( ] ) + @overrides(check_signature=False) + def GetCollectionWithSegments( + self, request: GetCollectionWithSegmentsRequest, context: grpc.ServicerContext + ) -> GetCollectionWithSegmentsResponse: + allCollections = {} + for tenant, databases in self._tenants_to_databases_to_collections.items(): + for database, collections in databases.items(): + allCollections.update(collections) + print( + f"Tenant: {tenant}, Database: {database}, Collections: {collections}" + ) + collection = allCollections.get(request.id, None) + if collection is None: + context.abort(grpc.StatusCode.NOT_FOUND, f"Collection with id {request.id} not found") + collection_unique_key = f"{collection.tenant}:{collection.database}:{request.id}" + segments = [self._segments[id] for id in self._collection_to_segments[collection_unique_key]] + if {segment["scope"] for segment in segments} != {SegmentScope.METADATA, SegmentScope.RECORD, SegmentScope.VECTOR}: + context.abort(grpc.StatusCode.INTERNAL, f"Incomplete segments for collection {collection}: {segments}") + + return GetCollectionWithSegmentsResponse( + collection=to_proto_collection(collection), + segments=[to_proto_segment(segment) for segment in segments] + ) + @overrides(check_signature=False) def UpdateCollection( self, request: UpdateCollectionRequest, context: grpc.ServicerContext diff --git a/chromadb/db/mixins/sysdb.py b/chromadb/db/mixins/sysdb.py index a10042cefe9..8c7f2b843ee 100644 --- a/chromadb/db/mixins/sysdb.py +++ b/chromadb/db/mixins/sysdb.py @@ -15,7 +15,7 @@ from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System from chromadb.db.base import Cursor, SqlDB, ParameterValue, get_sql from chromadb.db.system import SysDB -from chromadb.errors import NotFoundError, UniqueConstraintError +from chromadb.errors import InvalidCollectionException, NotFoundError, UniqueConstraintError from chromadb.telemetry.opentelemetry import ( add_attributes_to_current_span, OpenTelemetryClient, @@ -24,6 +24,7 @@ ) from chromadb.ingest import Producer from chromadb.types import ( + CollectionAndSegments, Database, OptionalArgument, Segment, @@ -367,6 +368,7 @@ def get_segments( scope=scope, collection=collection, metadata=metadata, + file_paths={}, ) ) @@ -488,6 +490,18 @@ def get_collections( return collections + @override + def get_collection_with_segments(self, collection_id: UUID) -> CollectionAndSegments: + collections = self.get_collections(id=collection_id) + if len(collections) == 0: + raise InvalidCollectionException( + f"Collection {collection_id} does not exist." + ) + return CollectionAndSegments( + collection=collections[0], + segments=self.get_segments(collection=collection_id), + ) + @trace_method("SqlSysDB.delete_segment", OpenTelemetryGranularity.ALL) @override def delete_segment(self, collection: UUID, id: UUID) -> None: diff --git a/chromadb/db/system.py b/chromadb/db/system.py index 11a385155a9..ec440e836ee 100644 --- a/chromadb/db/system.py +++ b/chromadb/db/system.py @@ -4,6 +4,7 @@ from chromadb.api.configuration import CollectionConfigurationInternal from chromadb.types import ( Collection, + CollectionAndSegments, Database, Tenant, Metadata, @@ -128,6 +129,16 @@ def get_collections( """Find collections by id or name. 
If name is provided, tenant and database must also be provided.""" pass + @abstractmethod + def get_collection_with_segments( + self, + collection_id: UUID + ) -> CollectionAndSegments: + """Get a consistent snapshot of a collection by id. This will return a collection with segment + information that matches the collection version and log position. + """ + pass + @abstractmethod def update_collection( self, diff --git a/chromadb/execution/executor/distributed.py b/chromadb/execution/executor/distributed.py index a38dec66135..3cf5c591c77 100644 --- a/chromadb/execution/executor/distributed.py +++ b/chromadb/execution/executor/distributed.py @@ -7,7 +7,7 @@ from chromadb.config import System from chromadb.errors import VersionMismatchError from chromadb.execution.executor.abstract import Executor -from chromadb.execution.expression.operator import Scan, SegmentScan +from chromadb.execution.expression.operator import Scan from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan from chromadb.proto import convert @@ -15,7 +15,6 @@ from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor from chromadb.segment.impl.manager.distributed import DistributedSegmentManager from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor -from chromadb.types import SegmentScope def _clean_metadata(metadata: Optional[Metadata]) -> Optional[Metadata]: @@ -55,30 +54,18 @@ def __init__(self, system: System): @overrides def count(self, plan: CountPlan) -> int: executor = self._grpc_executuor_stub(plan.scan) - plan.scan = self._segment_scan(plan.scan) try: count_result = executor.Count(convert.to_proto_count_plan(plan)) except grpc.RpcError as rpc_error: - if ( - rpc_error.code() == grpc.StatusCode.INTERNAL - and "version mismatch" in rpc_error.details() - ): - raise VersionMismatchError() raise rpc_error return convert.from_proto_count_result(count_result) @overrides def get(self, plan: GetPlan) -> GetResult: executor = self._grpc_executuor_stub(plan.scan) - plan.scan = self._segment_scan(plan.scan) try: get_result = executor.Get(convert.to_proto_get_plan(plan)) except grpc.RpcError as rpc_error: - if ( - rpc_error.code() == grpc.StatusCode.INTERNAL - and "version mismatch" in rpc_error.details() - ): - raise VersionMismatchError() raise rpc_error records = convert.from_proto_get_result(get_result) @@ -118,15 +105,9 @@ def get(self, plan: GetPlan) -> GetResult: @overrides def knn(self, plan: KNNPlan) -> QueryResult: executor = self._grpc_executuor_stub(plan.scan) - plan.scan = self._segment_scan(plan.scan) try: knn_result = executor.KNN(convert.to_proto_knn_plan(plan)) except grpc.RpcError as rpc_error: - if ( - rpc_error.code() == grpc.StatusCode.INTERNAL - and "version mismatch" in rpc_error.details() - ): - raise VersionMismatchError() raise rpc_error results = convert.from_proto_knn_batch_result(knn_result) @@ -181,19 +162,10 @@ def knn(self, plan: KNNPlan) -> QueryResult: included=plan.projection.included, ) - def _segment_scan(self, scan: Scan) -> SegmentScan: - knn = self._manager.get_segment(scan.collection.id, SegmentScope.VECTOR) - metadata = self._manager.get_segment(scan.collection.id, SegmentScope.METADATA) - record = self._manager.get_segment(scan.collection.id, SegmentScope.RECORD) - return SegmentScan( - collection=scan.collection, - knn_id=knn["id"], - metadata_id=metadata["id"], - record_id=record["id"], - ) - def _grpc_executuor_stub(self, scan: Scan) -> QueryExecutorStub: - grpc_url = self._manager.get_endpoint(scan.collection.id) + # Since grpc endpoint is 
endpoint is determined by collection uuid, + # the endpoint should be the same for all segments of the same collection + grpc_url = self._manager.get_endpoint(scan.record) if grpc_url not in self._grpc_stub_pool: channel = grpc.insecure_channel(grpc_url) interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()] diff --git a/chromadb/execution/expression/operator.py b/chromadb/execution/expression/operator.py index 2f56df78041..01dff0bb84f 100644 --- a/chromadb/execution/expression/operator.py +++ b/chromadb/execution/expression/operator.py @@ -1,14 +1,16 @@ from dataclasses import dataclass from typing import Optional -from uuid import UUID from chromadb.api.types import Embeddings, IDs, Include, IncludeEnum -from chromadb.types import Collection, RequestVersionContext, Where, WhereDocument +from chromadb.types import Collection, RequestVersionContext, Segment, Where, WhereDocument @dataclass class Scan: collection: Collection + knn: Segment + metadata: Segment + record: Segment @property def version(self) -> RequestVersionContext: @@ -17,14 +19,6 @@ def version(self) -> RequestVersionContext: log_position=self.collection.log_position, ) - -@dataclass -class SegmentScan(Scan): - knn_id: UUID - metadata_id: UUID - record_id: UUID - - @dataclass class Filter: user_ids: Optional[IDs] = None diff --git a/chromadb/proto/chroma_pb2.py b/chromadb/proto/chroma_pb2.py deleted file mode 100644 index 23855636914..00000000000 --- a/chromadb/proto/chroma_pb2.py +++ /dev/null @@ -1,127 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: chromadb/proto/chroma.proto -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma\"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 \x01(\x05\x12\x0e\n\x06vector\x18\x02 \x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding\"\x1a\n\tFilePaths\x12\r\n\x05paths\x18\x01 \x03(\t\"\x91\x02\n\x07Segment\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12#\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScope\x12\x12\n\ncollection\x18\x05 \x01(\t\x12-\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x32\n\nfile_paths\x18\x07 \x03(\x0b\x32\x1e.chroma.Segment.FilePathsEntry\x1a\x43\n\x0e\x46ilePathsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.chroma.FilePaths:\x02\x38\x01\x42\x0b\n\t_metadata\"\xf1\x01\n\nCollection\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x1e\n\x16\x63onfiguration_json_str\x18\x03 \x01(\t\x12-\n\x08metadata\x18\x04 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x05 \x01(\x05H\x01\x88\x01\x01\x12\x0e\n\x06tenant\x18\x06 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x07 \x01(\t\x12\x14\n\x0clog_position\x18\x08 \x01(\x03\x12\x0f\n\x07version\x18\t \x01(\x05\x42\x0b\n\t_metadataB\x0c\n\n_dimension\"4\n\x08\x44\x61tabase\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0e\n\x06tenant\x18\x03 \x01(\t\"\x16\n\x06Tenant\x12\x0c\n\x04name\x18\x01 
\x01(\t\"x\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x12\x14\n\nbool_value\x18\x04 \x01(\x08H\x00\x42\x07\n\x05value\"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01\"\xaf\x01\n\x0fOperationRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata\"I\n\x15RequestVersionContext\x12\x1a\n\x12\x63ollection_version\x18\x01 \x01(\r\x12\x14\n\x0clog_position\x18\x02 \x01(\x04\"x\n\x13\x43ountRecordsRequest\x12\x12\n\nsegment_id\x18\x01 \x01(\t\x12\x15\n\rcollection_id\x18\x02 \x01(\t\x12\x36\n\x0fversion_context\x18\x03 \x01(\x0b\x32\x1d.chroma.RequestVersionContext\"%\n\x14\x43ountRecordsResponse\x12\r\n\x05\x63ount\x18\x01 \x01(\r\"\xc9\x02\n\x14QueryMetadataRequest\x12\x12\n\nsegment_id\x18\x01 \x01(\t\x12\x1c\n\x05where\x18\x02 \x01(\x0b\x32\r.chroma.Where\x12-\n\x0ewhere_document\x18\x03 \x01(\x0b\x32\x15.chroma.WhereDocument\x12!\n\x03ids\x18\x04 \x01(\x0b\x32\x0f.chroma.UserIdsH\x00\x88\x01\x01\x12\x12\n\x05limit\x18\x05 \x01(\rH\x01\x88\x01\x01\x12\x13\n\x06offset\x18\x06 \x01(\rH\x02\x88\x01\x01\x12\x15\n\rcollection_id\x18\x07 \x01(\t\x12\x18\n\x10include_metadata\x18\x08 \x01(\x08\x12\x36\n\x0fversion_context\x18\t \x01(\x0b\x32\x1d.chroma.RequestVersionContextB\x06\n\x04_idsB\x08\n\x06_limitB\t\n\x07_offset\"I\n\x15QueryMetadataResponse\x12\x30\n\x07records\x18\x01 \x03(\x0b\x32\x1f.chroma.MetadataEmbeddingRecord\"O\n\x17MetadataEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12(\n\x08metadata\x18\x02 \x01(\x0b\x32\x16.chroma.UpdateMetadata\"\x16\n\x07UserIds\x12\x0b\n\x03ids\x18\x01 \x03(\t\"\x83\x01\n\rWhereDocument\x12-\n\x06\x64irect\x18\x01 \x01(\x0b\x32\x1b.chroma.DirectWhereDocumentH\x00\x12\x31\n\x08\x63hildren\x18\x02 \x01(\x0b\x32\x1d.chroma.WhereDocumentChildrenH\x00\x42\x10\n\x0ewhere_document\"X\n\x13\x44irectWhereDocument\x12\x10\n\x08\x64ocument\x18\x01 \x01(\t\x12/\n\x08operator\x18\x02 \x01(\x0e\x32\x1d.chroma.WhereDocumentOperator\"k\n\x15WhereDocumentChildren\x12\'\n\x08\x63hildren\x18\x01 \x03(\x0b\x32\x15.chroma.WhereDocument\x12)\n\x08operator\x18\x02 \x01(\x0e\x32\x17.chroma.BooleanOperator\"r\n\x05Where\x12\x35\n\x11\x64irect_comparison\x18\x01 \x01(\x0b\x32\x18.chroma.DirectComparisonH\x00\x12)\n\x08\x63hildren\x18\x02 \x01(\x0b\x32\x15.chroma.WhereChildrenH\x00\x42\x07\n\x05where\"\x91\x04\n\x10\x44irectComparison\x12\x0b\n\x03key\x18\x01 \x01(\t\x12?\n\x15single_string_operand\x18\x02 \x01(\x0b\x32\x1e.chroma.SingleStringComparisonH\x00\x12;\n\x13string_list_operand\x18\x03 \x01(\x0b\x32\x1c.chroma.StringListComparisonH\x00\x12\x39\n\x12single_int_operand\x18\x04 \x01(\x0b\x32\x1b.chroma.SingleIntComparisonH\x00\x12\x35\n\x10int_list_operand\x18\x05 \x01(\x0b\x32\x19.chroma.IntListComparisonH\x00\x12?\n\x15single_double_operand\x18\x06 \x01(\x0b\x32\x1e.chroma.SingleDoubleComparisonH\x00\x12;\n\x13\x64ouble_list_operand\x18\x07 \x01(\x0b\x32\x1c.chroma.DoubleListComparisonH\x00\x12\x37\n\x11\x62ool_list_operand\x18\x08 \x01(\x0b\x32\x1a.chroma.BoolListComparisonH\x00\x12;\n\x13single_bool_operand\x18\t 
\x01(\x0b\x32\x1c.chroma.SingleBoolComparisonH\x00\x42\x0c\n\ncomparison\"[\n\rWhereChildren\x12\x1f\n\x08\x63hildren\x18\x01 \x03(\x0b\x32\r.chroma.Where\x12)\n\x08operator\x18\x02 \x01(\x0e\x32\x17.chroma.BooleanOperator\"S\n\x14StringListComparison\x12\x0e\n\x06values\x18\x01 \x03(\t\x12+\n\rlist_operator\x18\x02 \x01(\x0e\x32\x14.chroma.ListOperator\"V\n\x16SingleStringComparison\x12\r\n\x05value\x18\x01 \x01(\t\x12-\n\ncomparator\x18\x02 \x01(\x0e\x32\x19.chroma.GenericComparator\"T\n\x14SingleBoolComparison\x12\r\n\x05value\x18\x01 \x01(\x08\x12-\n\ncomparator\x18\x02 \x01(\x0e\x32\x19.chroma.GenericComparator\"P\n\x11IntListComparison\x12\x0e\n\x06values\x18\x01 \x03(\x03\x12+\n\rlist_operator\x18\x02 \x01(\x0e\x32\x14.chroma.ListOperator\"\xa2\x01\n\x13SingleIntComparison\x12\r\n\x05value\x18\x01 \x01(\x03\x12\x37\n\x12generic_comparator\x18\x02 \x01(\x0e\x32\x19.chroma.GenericComparatorH\x00\x12\x35\n\x11number_comparator\x18\x03 \x01(\x0e\x32\x18.chroma.NumberComparatorH\x00\x42\x0c\n\ncomparator\"S\n\x14\x44oubleListComparison\x12\x0e\n\x06values\x18\x01 \x03(\x01\x12+\n\rlist_operator\x18\x02 \x01(\x0e\x32\x14.chroma.ListOperator\"Q\n\x12\x42oolListComparison\x12\x0e\n\x06values\x18\x01 \x03(\x08\x12+\n\rlist_operator\x18\x02 \x01(\x0e\x32\x14.chroma.ListOperator\"\xa5\x01\n\x16SingleDoubleComparison\x12\r\n\x05value\x18\x01 \x01(\x01\x12\x37\n\x12generic_comparator\x18\x02 \x01(\x0e\x32\x19.chroma.GenericComparatorH\x00\x12\x35\n\x11number_comparator\x18\x03 \x01(\x0e\x32\x18.chroma.NumberComparatorH\x00\x42\x0c\n\ncomparator\"\x83\x01\n\x11GetVectorsRequest\x12\x0b\n\x03ids\x18\x01 \x03(\t\x12\x12\n\nsegment_id\x18\x02 \x01(\t\x12\x15\n\rcollection_id\x18\x03 \x01(\t\x12\x36\n\x0fversion_context\x18\x04 \x01(\x0b\x32\x1d.chroma.RequestVersionContext\"D\n\x12GetVectorsResponse\x12.\n\x07records\x18\x01 \x03(\x0b\x32\x1d.chroma.VectorEmbeddingRecord\"C\n\x15VectorEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1e\n\x06vector\x18\x03 \x01(\x0b\x32\x0e.chroma.Vector\"\xd5\x01\n\x13QueryVectorsRequest\x12\x1f\n\x07vectors\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\t\n\x01k\x18\x02 \x01(\x05\x12\x13\n\x0b\x61llowed_ids\x18\x03 \x03(\t\x12\x1a\n\x12include_embeddings\x18\x04 \x01(\x08\x12\x12\n\nsegment_id\x18\x05 \x01(\t\x12\x15\n\rcollection_id\x18\x06 \x01(\t\x12\x36\n\x0fversion_context\x18\x07 \x01(\x0b\x32\x1d.chroma.RequestVersionContext\"C\n\x14QueryVectorsResponse\x12+\n\x07results\x18\x01 \x03(\x0b\x32\x1a.chroma.VectorQueryResults\"@\n\x12VectorQueryResults\x12*\n\x07results\x18\x01 \x03(\x0b\x32\x19.chroma.VectorQueryResult\"a\n\x11VectorQueryResult\x12\n\n\x02id\x18\x01 \x01(\t\x12\x10\n\x08\x64istance\x18\x03 \x01(\x02\x12#\n\x06vector\x18\x04 
\x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x42\t\n\x07_vector*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01*@\n\x0cSegmentScope\x12\n\n\x06VECTOR\x10\x00\x12\x0c\n\x08METADATA\x10\x01\x12\n\n\x06RECORD\x10\x02\x12\n\n\x06SQLITE\x10\x03*7\n\x15WhereDocumentOperator\x12\x0c\n\x08\x43ONTAINS\x10\x00\x12\x10\n\x0cNOT_CONTAINS\x10\x01*\"\n\x0f\x42ooleanOperator\x12\x07\n\x03\x41ND\x10\x00\x12\x06\n\x02OR\x10\x01*\x1f\n\x0cListOperator\x12\x06\n\x02IN\x10\x00\x12\x07\n\x03NIN\x10\x01*#\n\x11GenericComparator\x12\x06\n\x02\x45Q\x10\x00\x12\x06\n\x02NE\x10\x01*4\n\x10NumberComparator\x12\x06\n\x02GT\x10\x00\x12\x07\n\x03GTE\x10\x01\x12\x06\n\x02LT\x10\x02\x12\x07\n\x03LTE\x10\x03\x32\xad\x01\n\x0eMetadataReader\x12N\n\rQueryMetadata\x12\x1c.chroma.QueryMetadataRequest\x1a\x1d.chroma.QueryMetadataResponse\"\x00\x12K\n\x0c\x43ountRecords\x12\x1b.chroma.CountRecordsRequest\x1a\x1c.chroma.CountRecordsResponse\"\x00\x32\xa2\x01\n\x0cVectorReader\x12\x45\n\nGetVectors\x12\x19.chroma.GetVectorsRequest\x1a\x1a.chroma.GetVectorsResponse\"\x00\x12K\n\x0cQueryVectors\x12\x1b.chroma.QueryVectorsRequest\x1a\x1c.chroma.QueryVectorsResponse\"\x00\x42:Z8github.com/chroma-core/chroma/go/pkg/proto/coordinatorpbb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'chromadb.proto.chroma_pb2', _globals) -if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'Z8github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb' - _SEGMENT_FILEPATHSENTRY._options = None - _SEGMENT_FILEPATHSENTRY._serialized_options = b'8\001' - _UPDATEMETADATA_METADATAENTRY._options = None - _UPDATEMETADATA_METADATAENTRY._serialized_options = b'8\001' - _globals['_OPERATION']._serialized_start=4592 - _globals['_OPERATION']._serialized_end=4648 - _globals['_SCALARENCODING']._serialized_start=4650 - _globals['_SCALARENCODING']._serialized_end=4690 - _globals['_SEGMENTSCOPE']._serialized_start=4692 - _globals['_SEGMENTSCOPE']._serialized_end=4756 - _globals['_WHEREDOCUMENTOPERATOR']._serialized_start=4758 - _globals['_WHEREDOCUMENTOPERATOR']._serialized_end=4813 - _globals['_BOOLEANOPERATOR']._serialized_start=4815 - _globals['_BOOLEANOPERATOR']._serialized_end=4849 - _globals['_LISTOPERATOR']._serialized_start=4851 - _globals['_LISTOPERATOR']._serialized_end=4882 - _globals['_GENERICCOMPARATOR']._serialized_start=4884 - _globals['_GENERICCOMPARATOR']._serialized_end=4919 - _globals['_NUMBERCOMPARATOR']._serialized_start=4921 - _globals['_NUMBERCOMPARATOR']._serialized_end=4973 - _globals['_VECTOR']._serialized_start=39 - _globals['_VECTOR']._serialized_end=124 - _globals['_FILEPATHS']._serialized_start=126 - _globals['_FILEPATHS']._serialized_end=152 - _globals['_SEGMENT']._serialized_start=155 - _globals['_SEGMENT']._serialized_end=428 - _globals['_SEGMENT_FILEPATHSENTRY']._serialized_start=348 - _globals['_SEGMENT_FILEPATHSENTRY']._serialized_end=415 - _globals['_COLLECTION']._serialized_start=431 - _globals['_COLLECTION']._serialized_end=672 - _globals['_DATABASE']._serialized_start=674 - _globals['_DATABASE']._serialized_end=726 - _globals['_TENANT']._serialized_start=728 - _globals['_TENANT']._serialized_end=750 - _globals['_UPDATEMETADATAVALUE']._serialized_start=752 - 
_globals['_UPDATEMETADATAVALUE']._serialized_end=872 - _globals['_UPDATEMETADATA']._serialized_start=875 - _globals['_UPDATEMETADATA']._serialized_end=1025 - _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_start=949 - _globals['_UPDATEMETADATA_METADATAENTRY']._serialized_end=1025 - _globals['_OPERATIONRECORD']._serialized_start=1028 - _globals['_OPERATIONRECORD']._serialized_end=1203 - _globals['_REQUESTVERSIONCONTEXT']._serialized_start=1205 - _globals['_REQUESTVERSIONCONTEXT']._serialized_end=1278 - _globals['_COUNTRECORDSREQUEST']._serialized_start=1280 - _globals['_COUNTRECORDSREQUEST']._serialized_end=1400 - _globals['_COUNTRECORDSRESPONSE']._serialized_start=1402 - _globals['_COUNTRECORDSRESPONSE']._serialized_end=1439 - _globals['_QUERYMETADATAREQUEST']._serialized_start=1442 - _globals['_QUERYMETADATAREQUEST']._serialized_end=1771 - _globals['_QUERYMETADATARESPONSE']._serialized_start=1773 - _globals['_QUERYMETADATARESPONSE']._serialized_end=1846 - _globals['_METADATAEMBEDDINGRECORD']._serialized_start=1848 - _globals['_METADATAEMBEDDINGRECORD']._serialized_end=1927 - _globals['_USERIDS']._serialized_start=1929 - _globals['_USERIDS']._serialized_end=1951 - _globals['_WHEREDOCUMENT']._serialized_start=1954 - _globals['_WHEREDOCUMENT']._serialized_end=2085 - _globals['_DIRECTWHEREDOCUMENT']._serialized_start=2087 - _globals['_DIRECTWHEREDOCUMENT']._serialized_end=2175 - _globals['_WHEREDOCUMENTCHILDREN']._serialized_start=2177 - _globals['_WHEREDOCUMENTCHILDREN']._serialized_end=2284 - _globals['_WHERE']._serialized_start=2286 - _globals['_WHERE']._serialized_end=2400 - _globals['_DIRECTCOMPARISON']._serialized_start=2403 - _globals['_DIRECTCOMPARISON']._serialized_end=2932 - _globals['_WHERECHILDREN']._serialized_start=2934 - _globals['_WHERECHILDREN']._serialized_end=3025 - _globals['_STRINGLISTCOMPARISON']._serialized_start=3027 - _globals['_STRINGLISTCOMPARISON']._serialized_end=3110 - _globals['_SINGLESTRINGCOMPARISON']._serialized_start=3112 - _globals['_SINGLESTRINGCOMPARISON']._serialized_end=3198 - _globals['_SINGLEBOOLCOMPARISON']._serialized_start=3200 - _globals['_SINGLEBOOLCOMPARISON']._serialized_end=3284 - _globals['_INTLISTCOMPARISON']._serialized_start=3286 - _globals['_INTLISTCOMPARISON']._serialized_end=3366 - _globals['_SINGLEINTCOMPARISON']._serialized_start=3369 - _globals['_SINGLEINTCOMPARISON']._serialized_end=3531 - _globals['_DOUBLELISTCOMPARISON']._serialized_start=3533 - _globals['_DOUBLELISTCOMPARISON']._serialized_end=3616 - _globals['_BOOLLISTCOMPARISON']._serialized_start=3618 - _globals['_BOOLLISTCOMPARISON']._serialized_end=3699 - _globals['_SINGLEDOUBLECOMPARISON']._serialized_start=3702 - _globals['_SINGLEDOUBLECOMPARISON']._serialized_end=3867 - _globals['_GETVECTORSREQUEST']._serialized_start=3870 - _globals['_GETVECTORSREQUEST']._serialized_end=4001 - _globals['_GETVECTORSRESPONSE']._serialized_start=4003 - _globals['_GETVECTORSRESPONSE']._serialized_end=4071 - _globals['_VECTOREMBEDDINGRECORD']._serialized_start=4073 - _globals['_VECTOREMBEDDINGRECORD']._serialized_end=4140 - _globals['_QUERYVECTORSREQUEST']._serialized_start=4143 - _globals['_QUERYVECTORSREQUEST']._serialized_end=4356 - _globals['_QUERYVECTORSRESPONSE']._serialized_start=4358 - _globals['_QUERYVECTORSRESPONSE']._serialized_end=4425 - _globals['_VECTORQUERYRESULTS']._serialized_start=4427 - _globals['_VECTORQUERYRESULTS']._serialized_end=4491 - _globals['_VECTORQUERYRESULT']._serialized_start=4493 - _globals['_VECTORQUERYRESULT']._serialized_end=4590 - 
_globals['_METADATAREADER']._serialized_start=4976 - _globals['_METADATAREADER']._serialized_end=5149 - _globals['_VECTORREADER']._serialized_start=5152 - _globals['_VECTORREADER']._serialized_end=5314 -# @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/chroma_pb2.pyi b/chromadb/proto/chroma_pb2.pyi deleted file mode 100644 index dc013872cd2..00000000000 --- a/chromadb/proto/chroma_pb2.pyi +++ /dev/null @@ -1,451 +0,0 @@ -from google.protobuf.internal import containers as _containers -from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union - -DESCRIPTOR: _descriptor.FileDescriptor - -class Operation(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] - ADD: _ClassVar[Operation] - UPDATE: _ClassVar[Operation] - UPSERT: _ClassVar[Operation] - DELETE: _ClassVar[Operation] - -class ScalarEncoding(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] - FLOAT32: _ClassVar[ScalarEncoding] - INT32: _ClassVar[ScalarEncoding] - -class SegmentScope(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] - VECTOR: _ClassVar[SegmentScope] - METADATA: _ClassVar[SegmentScope] - RECORD: _ClassVar[SegmentScope] - SQLITE: _ClassVar[SegmentScope] - -class WhereDocumentOperator(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] - CONTAINS: _ClassVar[WhereDocumentOperator] - NOT_CONTAINS: _ClassVar[WhereDocumentOperator] - -class BooleanOperator(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] - AND: _ClassVar[BooleanOperator] - OR: _ClassVar[BooleanOperator] - -class ListOperator(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] - IN: _ClassVar[ListOperator] - NIN: _ClassVar[ListOperator] - -class GenericComparator(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] - EQ: _ClassVar[GenericComparator] - NE: _ClassVar[GenericComparator] - -class NumberComparator(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] - GT: _ClassVar[NumberComparator] - GTE: _ClassVar[NumberComparator] - LT: _ClassVar[NumberComparator] - LTE: _ClassVar[NumberComparator] -ADD: Operation -UPDATE: Operation -UPSERT: Operation -DELETE: Operation -FLOAT32: ScalarEncoding -INT32: ScalarEncoding -VECTOR: SegmentScope -METADATA: SegmentScope -RECORD: SegmentScope -SQLITE: SegmentScope -CONTAINS: WhereDocumentOperator -NOT_CONTAINS: WhereDocumentOperator -AND: BooleanOperator -OR: BooleanOperator -IN: ListOperator -NIN: ListOperator -EQ: GenericComparator -NE: GenericComparator -GT: NumberComparator -GTE: NumberComparator -LT: NumberComparator -LTE: NumberComparator - -class Vector(_message.Message): - __slots__ = ["dimension", "vector", "encoding"] - DIMENSION_FIELD_NUMBER: _ClassVar[int] - VECTOR_FIELD_NUMBER: _ClassVar[int] - ENCODING_FIELD_NUMBER: _ClassVar[int] - dimension: int - vector: bytes - encoding: ScalarEncoding - def __init__(self, dimension: _Optional[int] = ..., vector: _Optional[bytes] = ..., encoding: _Optional[_Union[ScalarEncoding, str]] = ...) -> None: ... - -class FilePaths(_message.Message): - __slots__ = ["paths"] - PATHS_FIELD_NUMBER: _ClassVar[int] - paths: _containers.RepeatedScalarFieldContainer[str] - def __init__(self, paths: _Optional[_Iterable[str]] = ...) -> None: ... 
- -class Segment(_message.Message): - __slots__ = ["id", "type", "scope", "collection", "metadata", "file_paths"] - class FilePathsEntry(_message.Message): - __slots__ = ["key", "value"] - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: FilePaths - def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[FilePaths, _Mapping]] = ...) -> None: ... - ID_FIELD_NUMBER: _ClassVar[int] - TYPE_FIELD_NUMBER: _ClassVar[int] - SCOPE_FIELD_NUMBER: _ClassVar[int] - COLLECTION_FIELD_NUMBER: _ClassVar[int] - METADATA_FIELD_NUMBER: _ClassVar[int] - FILE_PATHS_FIELD_NUMBER: _ClassVar[int] - id: str - type: str - scope: SegmentScope - collection: str - metadata: UpdateMetadata - file_paths: _containers.MessageMap[str, FilePaths] - def __init__(self, id: _Optional[str] = ..., type: _Optional[str] = ..., scope: _Optional[_Union[SegmentScope, str]] = ..., collection: _Optional[str] = ..., metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ..., file_paths: _Optional[_Mapping[str, FilePaths]] = ...) -> None: ... - -class Collection(_message.Message): - __slots__ = ["id", "name", "configuration_json_str", "metadata", "dimension", "tenant", "database", "log_position", "version"] - ID_FIELD_NUMBER: _ClassVar[int] - NAME_FIELD_NUMBER: _ClassVar[int] - CONFIGURATION_JSON_STR_FIELD_NUMBER: _ClassVar[int] - METADATA_FIELD_NUMBER: _ClassVar[int] - DIMENSION_FIELD_NUMBER: _ClassVar[int] - TENANT_FIELD_NUMBER: _ClassVar[int] - DATABASE_FIELD_NUMBER: _ClassVar[int] - LOG_POSITION_FIELD_NUMBER: _ClassVar[int] - VERSION_FIELD_NUMBER: _ClassVar[int] - id: str - name: str - configuration_json_str: str - metadata: UpdateMetadata - dimension: int - tenant: str - database: str - log_position: int - version: int - def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., configuration_json_str: _Optional[str] = ..., metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ..., dimension: _Optional[int] = ..., tenant: _Optional[str] = ..., database: _Optional[str] = ..., log_position: _Optional[int] = ..., version: _Optional[int] = ...) -> None: ... - -class Database(_message.Message): - __slots__ = ["id", "name", "tenant"] - ID_FIELD_NUMBER: _ClassVar[int] - NAME_FIELD_NUMBER: _ClassVar[int] - TENANT_FIELD_NUMBER: _ClassVar[int] - id: str - name: str - tenant: str - def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., tenant: _Optional[str] = ...) -> None: ... - -class Tenant(_message.Message): - __slots__ = ["name"] - NAME_FIELD_NUMBER: _ClassVar[int] - name: str - def __init__(self, name: _Optional[str] = ...) -> None: ... - -class UpdateMetadataValue(_message.Message): - __slots__ = ["string_value", "int_value", "float_value", "bool_value"] - STRING_VALUE_FIELD_NUMBER: _ClassVar[int] - INT_VALUE_FIELD_NUMBER: _ClassVar[int] - FLOAT_VALUE_FIELD_NUMBER: _ClassVar[int] - BOOL_VALUE_FIELD_NUMBER: _ClassVar[int] - string_value: str - int_value: int - float_value: float - bool_value: bool - def __init__(self, string_value: _Optional[str] = ..., int_value: _Optional[int] = ..., float_value: _Optional[float] = ..., bool_value: bool = ...) -> None: ... - -class UpdateMetadata(_message.Message): - __slots__ = ["metadata"] - class MetadataEntry(_message.Message): - __slots__ = ["key", "value"] - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: UpdateMetadataValue - def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[UpdateMetadataValue, _Mapping]] = ...) -> None: ... 
- METADATA_FIELD_NUMBER: _ClassVar[int] - metadata: _containers.MessageMap[str, UpdateMetadataValue] - def __init__(self, metadata: _Optional[_Mapping[str, UpdateMetadataValue]] = ...) -> None: ... - -class OperationRecord(_message.Message): - __slots__ = ["id", "vector", "metadata", "operation"] - ID_FIELD_NUMBER: _ClassVar[int] - VECTOR_FIELD_NUMBER: _ClassVar[int] - METADATA_FIELD_NUMBER: _ClassVar[int] - OPERATION_FIELD_NUMBER: _ClassVar[int] - id: str - vector: Vector - metadata: UpdateMetadata - operation: Operation - def __init__(self, id: _Optional[str] = ..., vector: _Optional[_Union[Vector, _Mapping]] = ..., metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ..., operation: _Optional[_Union[Operation, str]] = ...) -> None: ... - -class RequestVersionContext(_message.Message): - __slots__ = ["collection_version", "log_position"] - COLLECTION_VERSION_FIELD_NUMBER: _ClassVar[int] - LOG_POSITION_FIELD_NUMBER: _ClassVar[int] - collection_version: int - log_position: int - def __init__(self, collection_version: _Optional[int] = ..., log_position: _Optional[int] = ...) -> None: ... - -class CountRecordsRequest(_message.Message): - __slots__ = ["segment_id", "collection_id", "version_context"] - SEGMENT_ID_FIELD_NUMBER: _ClassVar[int] - COLLECTION_ID_FIELD_NUMBER: _ClassVar[int] - VERSION_CONTEXT_FIELD_NUMBER: _ClassVar[int] - segment_id: str - collection_id: str - version_context: RequestVersionContext - def __init__(self, segment_id: _Optional[str] = ..., collection_id: _Optional[str] = ..., version_context: _Optional[_Union[RequestVersionContext, _Mapping]] = ...) -> None: ... - -class CountRecordsResponse(_message.Message): - __slots__ = ["count"] - COUNT_FIELD_NUMBER: _ClassVar[int] - count: int - def __init__(self, count: _Optional[int] = ...) -> None: ... - -class QueryMetadataRequest(_message.Message): - __slots__ = ["segment_id", "where", "where_document", "ids", "limit", "offset", "collection_id", "include_metadata", "version_context"] - SEGMENT_ID_FIELD_NUMBER: _ClassVar[int] - WHERE_FIELD_NUMBER: _ClassVar[int] - WHERE_DOCUMENT_FIELD_NUMBER: _ClassVar[int] - IDS_FIELD_NUMBER: _ClassVar[int] - LIMIT_FIELD_NUMBER: _ClassVar[int] - OFFSET_FIELD_NUMBER: _ClassVar[int] - COLLECTION_ID_FIELD_NUMBER: _ClassVar[int] - INCLUDE_METADATA_FIELD_NUMBER: _ClassVar[int] - VERSION_CONTEXT_FIELD_NUMBER: _ClassVar[int] - segment_id: str - where: Where - where_document: WhereDocument - ids: UserIds - limit: int - offset: int - collection_id: str - include_metadata: bool - version_context: RequestVersionContext - def __init__(self, segment_id: _Optional[str] = ..., where: _Optional[_Union[Where, _Mapping]] = ..., where_document: _Optional[_Union[WhereDocument, _Mapping]] = ..., ids: _Optional[_Union[UserIds, _Mapping]] = ..., limit: _Optional[int] = ..., offset: _Optional[int] = ..., collection_id: _Optional[str] = ..., include_metadata: bool = ..., version_context: _Optional[_Union[RequestVersionContext, _Mapping]] = ...) -> None: ... - -class QueryMetadataResponse(_message.Message): - __slots__ = ["records"] - RECORDS_FIELD_NUMBER: _ClassVar[int] - records: _containers.RepeatedCompositeFieldContainer[MetadataEmbeddingRecord] - def __init__(self, records: _Optional[_Iterable[_Union[MetadataEmbeddingRecord, _Mapping]]] = ...) -> None: ... 
- -class MetadataEmbeddingRecord(_message.Message): - __slots__ = ["id", "metadata"] - ID_FIELD_NUMBER: _ClassVar[int] - METADATA_FIELD_NUMBER: _ClassVar[int] - id: str - metadata: UpdateMetadata - def __init__(self, id: _Optional[str] = ..., metadata: _Optional[_Union[UpdateMetadata, _Mapping]] = ...) -> None: ... - -class UserIds(_message.Message): - __slots__ = ["ids"] - IDS_FIELD_NUMBER: _ClassVar[int] - ids: _containers.RepeatedScalarFieldContainer[str] - def __init__(self, ids: _Optional[_Iterable[str]] = ...) -> None: ... - -class WhereDocument(_message.Message): - __slots__ = ["direct", "children"] - DIRECT_FIELD_NUMBER: _ClassVar[int] - CHILDREN_FIELD_NUMBER: _ClassVar[int] - direct: DirectWhereDocument - children: WhereDocumentChildren - def __init__(self, direct: _Optional[_Union[DirectWhereDocument, _Mapping]] = ..., children: _Optional[_Union[WhereDocumentChildren, _Mapping]] = ...) -> None: ... - -class DirectWhereDocument(_message.Message): - __slots__ = ["document", "operator"] - DOCUMENT_FIELD_NUMBER: _ClassVar[int] - OPERATOR_FIELD_NUMBER: _ClassVar[int] - document: str - operator: WhereDocumentOperator - def __init__(self, document: _Optional[str] = ..., operator: _Optional[_Union[WhereDocumentOperator, str]] = ...) -> None: ... - -class WhereDocumentChildren(_message.Message): - __slots__ = ["children", "operator"] - CHILDREN_FIELD_NUMBER: _ClassVar[int] - OPERATOR_FIELD_NUMBER: _ClassVar[int] - children: _containers.RepeatedCompositeFieldContainer[WhereDocument] - operator: BooleanOperator - def __init__(self, children: _Optional[_Iterable[_Union[WhereDocument, _Mapping]]] = ..., operator: _Optional[_Union[BooleanOperator, str]] = ...) -> None: ... - -class Where(_message.Message): - __slots__ = ["direct_comparison", "children"] - DIRECT_COMPARISON_FIELD_NUMBER: _ClassVar[int] - CHILDREN_FIELD_NUMBER: _ClassVar[int] - direct_comparison: DirectComparison - children: WhereChildren - def __init__(self, direct_comparison: _Optional[_Union[DirectComparison, _Mapping]] = ..., children: _Optional[_Union[WhereChildren, _Mapping]] = ...) -> None: ... 
- -class DirectComparison(_message.Message): - __slots__ = ["key", "single_string_operand", "string_list_operand", "single_int_operand", "int_list_operand", "single_double_operand", "double_list_operand", "bool_list_operand", "single_bool_operand"] - KEY_FIELD_NUMBER: _ClassVar[int] - SINGLE_STRING_OPERAND_FIELD_NUMBER: _ClassVar[int] - STRING_LIST_OPERAND_FIELD_NUMBER: _ClassVar[int] - SINGLE_INT_OPERAND_FIELD_NUMBER: _ClassVar[int] - INT_LIST_OPERAND_FIELD_NUMBER: _ClassVar[int] - SINGLE_DOUBLE_OPERAND_FIELD_NUMBER: _ClassVar[int] - DOUBLE_LIST_OPERAND_FIELD_NUMBER: _ClassVar[int] - BOOL_LIST_OPERAND_FIELD_NUMBER: _ClassVar[int] - SINGLE_BOOL_OPERAND_FIELD_NUMBER: _ClassVar[int] - key: str - single_string_operand: SingleStringComparison - string_list_operand: StringListComparison - single_int_operand: SingleIntComparison - int_list_operand: IntListComparison - single_double_operand: SingleDoubleComparison - double_list_operand: DoubleListComparison - bool_list_operand: BoolListComparison - single_bool_operand: SingleBoolComparison - def __init__(self, key: _Optional[str] = ..., single_string_operand: _Optional[_Union[SingleStringComparison, _Mapping]] = ..., string_list_operand: _Optional[_Union[StringListComparison, _Mapping]] = ..., single_int_operand: _Optional[_Union[SingleIntComparison, _Mapping]] = ..., int_list_operand: _Optional[_Union[IntListComparison, _Mapping]] = ..., single_double_operand: _Optional[_Union[SingleDoubleComparison, _Mapping]] = ..., double_list_operand: _Optional[_Union[DoubleListComparison, _Mapping]] = ..., bool_list_operand: _Optional[_Union[BoolListComparison, _Mapping]] = ..., single_bool_operand: _Optional[_Union[SingleBoolComparison, _Mapping]] = ...) -> None: ... - -class WhereChildren(_message.Message): - __slots__ = ["children", "operator"] - CHILDREN_FIELD_NUMBER: _ClassVar[int] - OPERATOR_FIELD_NUMBER: _ClassVar[int] - children: _containers.RepeatedCompositeFieldContainer[Where] - operator: BooleanOperator - def __init__(self, children: _Optional[_Iterable[_Union[Where, _Mapping]]] = ..., operator: _Optional[_Union[BooleanOperator, str]] = ...) -> None: ... - -class StringListComparison(_message.Message): - __slots__ = ["values", "list_operator"] - VALUES_FIELD_NUMBER: _ClassVar[int] - LIST_OPERATOR_FIELD_NUMBER: _ClassVar[int] - values: _containers.RepeatedScalarFieldContainer[str] - list_operator: ListOperator - def __init__(self, values: _Optional[_Iterable[str]] = ..., list_operator: _Optional[_Union[ListOperator, str]] = ...) -> None: ... - -class SingleStringComparison(_message.Message): - __slots__ = ["value", "comparator"] - VALUE_FIELD_NUMBER: _ClassVar[int] - COMPARATOR_FIELD_NUMBER: _ClassVar[int] - value: str - comparator: GenericComparator - def __init__(self, value: _Optional[str] = ..., comparator: _Optional[_Union[GenericComparator, str]] = ...) -> None: ... - -class SingleBoolComparison(_message.Message): - __slots__ = ["value", "comparator"] - VALUE_FIELD_NUMBER: _ClassVar[int] - COMPARATOR_FIELD_NUMBER: _ClassVar[int] - value: bool - comparator: GenericComparator - def __init__(self, value: bool = ..., comparator: _Optional[_Union[GenericComparator, str]] = ...) -> None: ... 
- -class IntListComparison(_message.Message): - __slots__ = ["values", "list_operator"] - VALUES_FIELD_NUMBER: _ClassVar[int] - LIST_OPERATOR_FIELD_NUMBER: _ClassVar[int] - values: _containers.RepeatedScalarFieldContainer[int] - list_operator: ListOperator - def __init__(self, values: _Optional[_Iterable[int]] = ..., list_operator: _Optional[_Union[ListOperator, str]] = ...) -> None: ... - -class SingleIntComparison(_message.Message): - __slots__ = ["value", "generic_comparator", "number_comparator"] - VALUE_FIELD_NUMBER: _ClassVar[int] - GENERIC_COMPARATOR_FIELD_NUMBER: _ClassVar[int] - NUMBER_COMPARATOR_FIELD_NUMBER: _ClassVar[int] - value: int - generic_comparator: GenericComparator - number_comparator: NumberComparator - def __init__(self, value: _Optional[int] = ..., generic_comparator: _Optional[_Union[GenericComparator, str]] = ..., number_comparator: _Optional[_Union[NumberComparator, str]] = ...) -> None: ... - -class DoubleListComparison(_message.Message): - __slots__ = ["values", "list_operator"] - VALUES_FIELD_NUMBER: _ClassVar[int] - LIST_OPERATOR_FIELD_NUMBER: _ClassVar[int] - values: _containers.RepeatedScalarFieldContainer[float] - list_operator: ListOperator - def __init__(self, values: _Optional[_Iterable[float]] = ..., list_operator: _Optional[_Union[ListOperator, str]] = ...) -> None: ... - -class BoolListComparison(_message.Message): - __slots__ = ["values", "list_operator"] - VALUES_FIELD_NUMBER: _ClassVar[int] - LIST_OPERATOR_FIELD_NUMBER: _ClassVar[int] - values: _containers.RepeatedScalarFieldContainer[bool] - list_operator: ListOperator - def __init__(self, values: _Optional[_Iterable[bool]] = ..., list_operator: _Optional[_Union[ListOperator, str]] = ...) -> None: ... - -class SingleDoubleComparison(_message.Message): - __slots__ = ["value", "generic_comparator", "number_comparator"] - VALUE_FIELD_NUMBER: _ClassVar[int] - GENERIC_COMPARATOR_FIELD_NUMBER: _ClassVar[int] - NUMBER_COMPARATOR_FIELD_NUMBER: _ClassVar[int] - value: float - generic_comparator: GenericComparator - number_comparator: NumberComparator - def __init__(self, value: _Optional[float] = ..., generic_comparator: _Optional[_Union[GenericComparator, str]] = ..., number_comparator: _Optional[_Union[NumberComparator, str]] = ...) -> None: ... - -class GetVectorsRequest(_message.Message): - __slots__ = ["ids", "segment_id", "collection_id", "version_context"] - IDS_FIELD_NUMBER: _ClassVar[int] - SEGMENT_ID_FIELD_NUMBER: _ClassVar[int] - COLLECTION_ID_FIELD_NUMBER: _ClassVar[int] - VERSION_CONTEXT_FIELD_NUMBER: _ClassVar[int] - ids: _containers.RepeatedScalarFieldContainer[str] - segment_id: str - collection_id: str - version_context: RequestVersionContext - def __init__(self, ids: _Optional[_Iterable[str]] = ..., segment_id: _Optional[str] = ..., collection_id: _Optional[str] = ..., version_context: _Optional[_Union[RequestVersionContext, _Mapping]] = ...) -> None: ... - -class GetVectorsResponse(_message.Message): - __slots__ = ["records"] - RECORDS_FIELD_NUMBER: _ClassVar[int] - records: _containers.RepeatedCompositeFieldContainer[VectorEmbeddingRecord] - def __init__(self, records: _Optional[_Iterable[_Union[VectorEmbeddingRecord, _Mapping]]] = ...) -> None: ... - -class VectorEmbeddingRecord(_message.Message): - __slots__ = ["id", "vector"] - ID_FIELD_NUMBER: _ClassVar[int] - VECTOR_FIELD_NUMBER: _ClassVar[int] - id: str - vector: Vector - def __init__(self, id: _Optional[str] = ..., vector: _Optional[_Union[Vector, _Mapping]] = ...) -> None: ... 
- -class QueryVectorsRequest(_message.Message): - __slots__ = ["vectors", "k", "allowed_ids", "include_embeddings", "segment_id", "collection_id", "version_context"] - VECTORS_FIELD_NUMBER: _ClassVar[int] - K_FIELD_NUMBER: _ClassVar[int] - ALLOWED_IDS_FIELD_NUMBER: _ClassVar[int] - INCLUDE_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int] - SEGMENT_ID_FIELD_NUMBER: _ClassVar[int] - COLLECTION_ID_FIELD_NUMBER: _ClassVar[int] - VERSION_CONTEXT_FIELD_NUMBER: _ClassVar[int] - vectors: _containers.RepeatedCompositeFieldContainer[Vector] - k: int - allowed_ids: _containers.RepeatedScalarFieldContainer[str] - include_embeddings: bool - segment_id: str - collection_id: str - version_context: RequestVersionContext - def __init__(self, vectors: _Optional[_Iterable[_Union[Vector, _Mapping]]] = ..., k: _Optional[int] = ..., allowed_ids: _Optional[_Iterable[str]] = ..., include_embeddings: bool = ..., segment_id: _Optional[str] = ..., collection_id: _Optional[str] = ..., version_context: _Optional[_Union[RequestVersionContext, _Mapping]] = ...) -> None: ... - -class QueryVectorsResponse(_message.Message): - __slots__ = ["results"] - RESULTS_FIELD_NUMBER: _ClassVar[int] - results: _containers.RepeatedCompositeFieldContainer[VectorQueryResults] - def __init__(self, results: _Optional[_Iterable[_Union[VectorQueryResults, _Mapping]]] = ...) -> None: ... - -class VectorQueryResults(_message.Message): - __slots__ = ["results"] - RESULTS_FIELD_NUMBER: _ClassVar[int] - results: _containers.RepeatedCompositeFieldContainer[VectorQueryResult] - def __init__(self, results: _Optional[_Iterable[_Union[VectorQueryResult, _Mapping]]] = ...) -> None: ... - -class VectorQueryResult(_message.Message): - __slots__ = ["id", "distance", "vector"] - ID_FIELD_NUMBER: _ClassVar[int] - DISTANCE_FIELD_NUMBER: _ClassVar[int] - VECTOR_FIELD_NUMBER: _ClassVar[int] - id: str - distance: float - vector: Vector - def __init__(self, id: _Optional[str] = ..., distance: _Optional[float] = ..., vector: _Optional[_Union[Vector, _Mapping]] = ...) -> None: ... diff --git a/chromadb/proto/chroma_pb2_grpc.py b/chromadb/proto/chroma_pb2_grpc.py deleted file mode 100644 index 6a0bad681f7..00000000000 --- a/chromadb/proto/chroma_pb2_grpc.py +++ /dev/null @@ -1,205 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc - -from chromadb.proto import chroma_pb2 as chromadb_dot_proto_dot_chroma__pb2 - - -class MetadataReaderStub(object): - """Metadata Reader Interface - - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. 
- """ - self.QueryMetadata = channel.unary_unary( - '/chroma.MetadataReader/QueryMetadata', - request_serializer=chromadb_dot_proto_dot_chroma__pb2.QueryMetadataRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_chroma__pb2.QueryMetadataResponse.FromString, - ) - self.CountRecords = channel.unary_unary( - '/chroma.MetadataReader/CountRecords', - request_serializer=chromadb_dot_proto_dot_chroma__pb2.CountRecordsRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_chroma__pb2.CountRecordsResponse.FromString, - ) - - -class MetadataReaderServicer(object): - """Metadata Reader Interface - - """ - - def QueryMetadata(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def CountRecords(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_MetadataReaderServicer_to_server(servicer, server): - rpc_method_handlers = { - 'QueryMetadata': grpc.unary_unary_rpc_method_handler( - servicer.QueryMetadata, - request_deserializer=chromadb_dot_proto_dot_chroma__pb2.QueryMetadataRequest.FromString, - response_serializer=chromadb_dot_proto_dot_chroma__pb2.QueryMetadataResponse.SerializeToString, - ), - 'CountRecords': grpc.unary_unary_rpc_method_handler( - servicer.CountRecords, - request_deserializer=chromadb_dot_proto_dot_chroma__pb2.CountRecordsRequest.FromString, - response_serializer=chromadb_dot_proto_dot_chroma__pb2.CountRecordsResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'chroma.MetadataReader', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - - - # This class is part of an EXPERIMENTAL API. -class MetadataReader(object): - """Metadata Reader Interface - - """ - - @staticmethod - def QueryMetadata(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.MetadataReader/QueryMetadata', - chromadb_dot_proto_dot_chroma__pb2.QueryMetadataRequest.SerializeToString, - chromadb_dot_proto_dot_chroma__pb2.QueryMetadataResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def CountRecords(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.MetadataReader/CountRecords', - chromadb_dot_proto_dot_chroma__pb2.CountRecordsRequest.SerializeToString, - chromadb_dot_proto_dot_chroma__pb2.CountRecordsResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - -class VectorReaderStub(object): - """Vector Reader Interface - - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. 
- """ - self.GetVectors = channel.unary_unary( - '/chroma.VectorReader/GetVectors', - request_serializer=chromadb_dot_proto_dot_chroma__pb2.GetVectorsRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_chroma__pb2.GetVectorsResponse.FromString, - ) - self.QueryVectors = channel.unary_unary( - '/chroma.VectorReader/QueryVectors', - request_serializer=chromadb_dot_proto_dot_chroma__pb2.QueryVectorsRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_chroma__pb2.QueryVectorsResponse.FromString, - ) - - -class VectorReaderServicer(object): - """Vector Reader Interface - - """ - - def GetVectors(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def QueryVectors(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_VectorReaderServicer_to_server(servicer, server): - rpc_method_handlers = { - 'GetVectors': grpc.unary_unary_rpc_method_handler( - servicer.GetVectors, - request_deserializer=chromadb_dot_proto_dot_chroma__pb2.GetVectorsRequest.FromString, - response_serializer=chromadb_dot_proto_dot_chroma__pb2.GetVectorsResponse.SerializeToString, - ), - 'QueryVectors': grpc.unary_unary_rpc_method_handler( - servicer.QueryVectors, - request_deserializer=chromadb_dot_proto_dot_chroma__pb2.QueryVectorsRequest.FromString, - response_serializer=chromadb_dot_proto_dot_chroma__pb2.QueryVectorsResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'chroma.VectorReader', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - - - # This class is part of an EXPERIMENTAL API. 
-class VectorReader(object): - """Vector Reader Interface - - """ - - @staticmethod - def GetVectors(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.VectorReader/GetVectors', - chromadb_dot_proto_dot_chroma__pb2.GetVectorsRequest.SerializeToString, - chromadb_dot_proto_dot_chroma__pb2.GetVectorsResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def QueryVectors(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.VectorReader/QueryVectors', - chromadb_dot_proto_dot_chroma__pb2.QueryVectorsRequest.SerializeToString, - chromadb_dot_proto_dot_chroma__pb2.QueryVectorsResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/chromadb/proto/convert.py b/chromadb/proto/convert.py index 47f5e0e08b9..51d30bc608a 100644 --- a/chromadb/proto/convert.py +++ b/chromadb/proto/convert.py @@ -13,7 +13,7 @@ Filter, Limit, Projection, - SegmentScan, + Scan, ) from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan from chromadb.types import ( @@ -161,6 +161,7 @@ def from_proto_segment(segment: chroma_pb.Segment) -> Segment: metadata=from_proto_metadata(segment.metadata) if segment.HasField("metadata") else None, + file_paths={name: [path for path in paths.paths] for name, paths in segment.file_paths.items()} ) @@ -173,6 +174,7 @@ def to_proto_segment(segment: Segment) -> chroma_pb.Segment: metadata=None if segment["metadata"] is None else to_proto_update_metadata(segment["metadata"]), + file_paths={name: chroma_pb.FilePaths(paths=paths) for name, paths in segment["file_paths"].items()} ) @@ -568,12 +570,12 @@ def to_proto_where_document(where_document: WhereDocument) -> chroma_pb.WhereDoc return response -def to_proto_scan(scan: SegmentScan) -> query_pb.ScanOperator: +def to_proto_scan(scan: Scan) -> query_pb.ScanOperator: return query_pb.ScanOperator( collection=to_proto_collection(scan.collection), - knn_id=scan.knn_id.hex, - metadata_id=scan.metadata_id.hex, - record_id=scan.record_id.hex, + knn=to_proto_segment(scan.knn), + metadata=to_proto_segment(scan.metadata), + record=to_proto_segment(scan.record), ) diff --git a/chromadb/proto/coordinator_pb2.py b/chromadb/proto/coordinator_pb2.py deleted file mode 100644 index 45e95958415..00000000000 --- a/chromadb/proto/coordinator_pb2.py +++ /dev/null @@ -1,101 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: chromadb/proto/coordinator.proto -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from chromadb.proto import chroma_pb2 as chromadb_dot_proto_dot_chroma__pb2 -from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n chromadb/proto/coordinator.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\x1a\x1bgoogle/protobuf/empty.proto\"A\n\x15\x43reateDatabaseRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0e\n\x06tenant\x18\x03 \x01(\t\"&\n\x16\x43reateDatabaseResponseJ\x04\x08\x01\x10\x02R\x06status\"2\n\x12GetDatabaseRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06tenant\x18\x02 \x01(\t\"G\n\x13GetDatabaseResponse\x12\"\n\x08\x64\x61tabase\x18\x01 \x01(\x0b\x32\x10.chroma.DatabaseJ\x04\x08\x02\x10\x03R\x06status\"#\n\x13\x43reateTenantRequest\x12\x0c\n\x04name\x18\x02 \x01(\t\"$\n\x14\x43reateTenantResponseJ\x04\x08\x01\x10\x02R\x06status\" \n\x10GetTenantRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"A\n\x11GetTenantResponse\x12\x1e\n\x06tenant\x18\x01 \x01(\x0b\x32\x0e.chroma.TenantJ\x04\x08\x02\x10\x03R\x06status\"8\n\x14\x43reateSegmentRequest\x12 \n\x07segment\x18\x01 \x01(\x0b\x32\x0f.chroma.Segment\"%\n\x15\x43reateSegmentResponseJ\x04\x08\x01\x10\x02R\x06status\"6\n\x14\x44\x65leteSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x12\n\ncollection\x18\x02 \x01(\t\"%\n\x15\x44\x65leteSegmentResponseJ\x04\x08\x01\x10\x02R\x06status\"\x90\x01\n\x12GetSegmentsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04type\x18\x02 \x01(\tH\x01\x88\x01\x01\x12(\n\x05scope\x18\x03 \x01(\x0e\x32\x14.chroma.SegmentScopeH\x02\x88\x01\x01\x12\x12\n\ncollection\x18\x04 \x01(\tB\x05\n\x03_idB\x07\n\x05_typeB\x08\n\x06_scope\"F\n\x13GetSegmentsResponse\x12!\n\x08segments\x18\x01 \x03(\x0b\x32\x0f.chroma.SegmentJ\x04\x08\x02\x10\x03R\x06status\"\x8f\x01\n\x14UpdateSegmentRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x12\n\ncollection\x18\x04 \x01(\t\x12*\n\x08metadata\x18\x06 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x12\x18\n\x0ereset_metadata\x18\x07 \x01(\x08H\x00\x42\x11\n\x0fmetadata_update\"%\n\x15UpdateSegmentResponseJ\x04\x08\x01\x10\x02R\x06status\"\xa8\x02\n\x17\x43reateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x1e\n\x16\x63onfiguration_json_str\x18\x03 \x01(\t\x12-\n\x08metadata\x18\x04 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x88\x01\x01\x12\x16\n\tdimension\x18\x05 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rget_or_create\x18\x06 \x01(\x08H\x02\x88\x01\x01\x12\x0e\n\x06tenant\x18\x07 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x08 \x01(\t\x12!\n\x08segments\x18\t \x03(\x0b\x32\x0f.chroma.SegmentB\x0b\n\t_metadataB\x0c\n\n_dimensionB\x10\n\x0e_get_or_create\"a\n\x18\x43reateCollectionResponse\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x0f\n\x07\x63reated\x18\x02 \x01(\x08J\x04\x08\x03\x10\x04R\x06status\"\\\n\x17\x44\x65leteCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06tenant\x18\x02 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\x12\x13\n\x0bsegment_ids\x18\x04 
\x03(\t\"(\n\x18\x44\x65leteCollectionResponseJ\x04\x08\x01\x10\x02R\x06status\"\xab\x01\n\x15GetCollectionsRequest\x12\x0f\n\x02id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x0e\n\x06tenant\x18\x04 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x05 \x01(\t\x12\x12\n\x05limit\x18\x06 \x01(\x05H\x02\x88\x01\x01\x12\x13\n\x06offset\x18\x07 \x01(\x05H\x03\x88\x01\x01\x42\x05\n\x03_idB\x07\n\x05_nameB\x08\n\x06_limitB\t\n\x07_offset\"O\n\x16GetCollectionsResponse\x12\'\n\x0b\x63ollections\x18\x01 \x03(\x0b\x32\x12.chroma.CollectionJ\x04\x08\x02\x10\x03R\x06status\"1\n\x17\x43heckCollectionsRequest\x12\x16\n\x0e\x63ollection_ids\x18\x01 \x03(\t\"+\n\x18\x43heckCollectionsResponse\x12\x0f\n\x07\x64\x65leted\x18\x01 \x03(\x08\"\xc0\x01\n\x17UpdateCollectionRequest\x12\n\n\x02id\x18\x01 \x01(\t\x12\x11\n\x04name\x18\x03 \x01(\tH\x01\x88\x01\x01\x12\x16\n\tdimension\x18\x04 \x01(\x05H\x02\x88\x01\x01\x12*\n\x08metadata\x18\x05 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x00\x12\x18\n\x0ereset_metadata\x18\x06 \x01(\x08H\x00\x42\x11\n\x0fmetadata_updateB\x07\n\x05_nameB\x0c\n\n_dimension\"(\n\x18UpdateCollectionResponseJ\x04\x08\x01\x10\x02R\x06status\"\"\n\x12ResetStateResponseJ\x04\x08\x01\x10\x02R\x06status\":\n%GetLastCompactionTimeForTenantRequest\x12\x11\n\ttenant_id\x18\x01 \x03(\t\"K\n\x18TenantLastCompactionTime\x12\x11\n\ttenant_id\x18\x01 \x01(\t\x12\x1c\n\x14last_compaction_time\x18\x02 \x01(\x03\"o\n&GetLastCompactionTimeForTenantResponse\x12\x45\n\x1btenant_last_compaction_time\x18\x01 \x03(\x0b\x32 .chroma.TenantLastCompactionTime\"n\n%SetLastCompactionTimeForTenantRequest\x12\x45\n\x1btenant_last_compaction_time\x18\x01 \x01(\x0b\x32 .chroma.TenantLastCompactionTime\"\xbc\x01\n\x1a\x46lushSegmentCompactionInfo\x12\x12\n\nsegment_id\x18\x01 \x01(\t\x12\x45\n\nfile_paths\x18\x02 \x03(\x0b\x32\x31.chroma.FlushSegmentCompactionInfo.FilePathsEntry\x1a\x43\n\x0e\x46ilePathsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.chroma.FilePaths:\x02\x38\x01\"\xc3\x01\n FlushCollectionCompactionRequest\x12\x11\n\ttenant_id\x18\x01 \x01(\t\x12\x15\n\rcollection_id\x18\x02 \x01(\t\x12\x14\n\x0clog_position\x18\x03 \x01(\x03\x12\x1a\n\x12\x63ollection_version\x18\x04 \x01(\x05\x12\x43\n\x17segment_compaction_info\x18\x05 \x03(\x0b\x32\".chroma.FlushSegmentCompactionInfo\"t\n!FlushCollectionCompactionResponse\x12\x15\n\rcollection_id\x18\x01 \x01(\t\x12\x1a\n\x12\x63ollection_version\x18\x02 \x01(\x05\x12\x1c\n\x14last_compaction_time\x18\x03 \x01(\x03\x32\xcd\x0b\n\x05SysDB\x12Q\n\x0e\x43reateDatabase\x12\x1d.chroma.CreateDatabaseRequest\x1a\x1e.chroma.CreateDatabaseResponse\"\x00\x12H\n\x0bGetDatabase\x12\x1a.chroma.GetDatabaseRequest\x1a\x1b.chroma.GetDatabaseResponse\"\x00\x12K\n\x0c\x43reateTenant\x12\x1b.chroma.CreateTenantRequest\x1a\x1c.chroma.CreateTenantResponse\"\x00\x12\x42\n\tGetTenant\x12\x18.chroma.GetTenantRequest\x1a\x19.chroma.GetTenantResponse\"\x00\x12N\n\rCreateSegment\x12\x1c.chroma.CreateSegmentRequest\x1a\x1d.chroma.CreateSegmentResponse\"\x00\x12N\n\rDeleteSegment\x12\x1c.chroma.DeleteSegmentRequest\x1a\x1d.chroma.DeleteSegmentResponse\"\x00\x12H\n\x0bGetSegments\x12\x1a.chroma.GetSegmentsRequest\x1a\x1b.chroma.GetSegmentsResponse\"\x00\x12N\n\rUpdateSegment\x12\x1c.chroma.UpdateSegmentRequest\x1a\x1d.chroma.UpdateSegmentResponse\"\x00\x12W\n\x10\x43reateCollection\x12\x1f.chroma.CreateCollectionRequest\x1a 
.chroma.CreateCollectionResponse\"\x00\x12W\n\x10\x44\x65leteCollection\x12\x1f.chroma.DeleteCollectionRequest\x1a .chroma.DeleteCollectionResponse\"\x00\x12Q\n\x0eGetCollections\x12\x1d.chroma.GetCollectionsRequest\x1a\x1e.chroma.GetCollectionsResponse\"\x00\x12W\n\x10\x43heckCollections\x12\x1f.chroma.CheckCollectionsRequest\x1a .chroma.CheckCollectionsResponse\"\x00\x12W\n\x10UpdateCollection\x12\x1f.chroma.UpdateCollectionRequest\x1a .chroma.UpdateCollectionResponse\"\x00\x12\x42\n\nResetState\x12\x16.google.protobuf.Empty\x1a\x1a.chroma.ResetStateResponse\"\x00\x12\x81\x01\n\x1eGetLastCompactionTimeForTenant\x12-.chroma.GetLastCompactionTimeForTenantRequest\x1a..chroma.GetLastCompactionTimeForTenantResponse\"\x00\x12i\n\x1eSetLastCompactionTimeForTenant\x12-.chroma.SetLastCompactionTimeForTenantRequest\x1a\x16.google.protobuf.Empty\"\x00\x12r\n\x19\x46lushCollectionCompaction\x12(.chroma.FlushCollectionCompactionRequest\x1a).chroma.FlushCollectionCompactionResponse\"\x00\x42:Z8github.com/chroma-core/chroma/go/pkg/proto/coordinatorpbb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'chromadb.proto.coordinator_pb2', _globals) -if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'Z8github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb' - _FLUSHSEGMENTCOMPACTIONINFO_FILEPATHSENTRY._options = None - _FLUSHSEGMENTCOMPACTIONINFO_FILEPATHSENTRY._serialized_options = b'8\001' - _globals['_CREATEDATABASEREQUEST']._serialized_start=102 - _globals['_CREATEDATABASEREQUEST']._serialized_end=167 - _globals['_CREATEDATABASERESPONSE']._serialized_start=169 - _globals['_CREATEDATABASERESPONSE']._serialized_end=207 - _globals['_GETDATABASEREQUEST']._serialized_start=209 - _globals['_GETDATABASEREQUEST']._serialized_end=259 - _globals['_GETDATABASERESPONSE']._serialized_start=261 - _globals['_GETDATABASERESPONSE']._serialized_end=332 - _globals['_CREATETENANTREQUEST']._serialized_start=334 - _globals['_CREATETENANTREQUEST']._serialized_end=369 - _globals['_CREATETENANTRESPONSE']._serialized_start=371 - _globals['_CREATETENANTRESPONSE']._serialized_end=407 - _globals['_GETTENANTREQUEST']._serialized_start=409 - _globals['_GETTENANTREQUEST']._serialized_end=441 - _globals['_GETTENANTRESPONSE']._serialized_start=443 - _globals['_GETTENANTRESPONSE']._serialized_end=508 - _globals['_CREATESEGMENTREQUEST']._serialized_start=510 - _globals['_CREATESEGMENTREQUEST']._serialized_end=566 - _globals['_CREATESEGMENTRESPONSE']._serialized_start=568 - _globals['_CREATESEGMENTRESPONSE']._serialized_end=605 - _globals['_DELETESEGMENTREQUEST']._serialized_start=607 - _globals['_DELETESEGMENTREQUEST']._serialized_end=661 - _globals['_DELETESEGMENTRESPONSE']._serialized_start=663 - _globals['_DELETESEGMENTRESPONSE']._serialized_end=700 - _globals['_GETSEGMENTSREQUEST']._serialized_start=703 - _globals['_GETSEGMENTSREQUEST']._serialized_end=847 - _globals['_GETSEGMENTSRESPONSE']._serialized_start=849 - _globals['_GETSEGMENTSRESPONSE']._serialized_end=919 - _globals['_UPDATESEGMENTREQUEST']._serialized_start=922 - _globals['_UPDATESEGMENTREQUEST']._serialized_end=1065 - _globals['_UPDATESEGMENTRESPONSE']._serialized_start=1067 - _globals['_UPDATESEGMENTRESPONSE']._serialized_end=1104 - _globals['_CREATECOLLECTIONREQUEST']._serialized_start=1107 - _globals['_CREATECOLLECTIONREQUEST']._serialized_end=1403 - 
_globals['_CREATECOLLECTIONRESPONSE']._serialized_start=1405 - _globals['_CREATECOLLECTIONRESPONSE']._serialized_end=1502 - _globals['_DELETECOLLECTIONREQUEST']._serialized_start=1504 - _globals['_DELETECOLLECTIONREQUEST']._serialized_end=1596 - _globals['_DELETECOLLECTIONRESPONSE']._serialized_start=1598 - _globals['_DELETECOLLECTIONRESPONSE']._serialized_end=1638 - _globals['_GETCOLLECTIONSREQUEST']._serialized_start=1641 - _globals['_GETCOLLECTIONSREQUEST']._serialized_end=1812 - _globals['_GETCOLLECTIONSRESPONSE']._serialized_start=1814 - _globals['_GETCOLLECTIONSRESPONSE']._serialized_end=1893 - _globals['_CHECKCOLLECTIONSREQUEST']._serialized_start=1895 - _globals['_CHECKCOLLECTIONSREQUEST']._serialized_end=1944 - _globals['_CHECKCOLLECTIONSRESPONSE']._serialized_start=1946 - _globals['_CHECKCOLLECTIONSRESPONSE']._serialized_end=1989 - _globals['_UPDATECOLLECTIONREQUEST']._serialized_start=1992 - _globals['_UPDATECOLLECTIONREQUEST']._serialized_end=2184 - _globals['_UPDATECOLLECTIONRESPONSE']._serialized_start=2186 - _globals['_UPDATECOLLECTIONRESPONSE']._serialized_end=2226 - _globals['_RESETSTATERESPONSE']._serialized_start=2228 - _globals['_RESETSTATERESPONSE']._serialized_end=2262 - _globals['_GETLASTCOMPACTIONTIMEFORTENANTREQUEST']._serialized_start=2264 - _globals['_GETLASTCOMPACTIONTIMEFORTENANTREQUEST']._serialized_end=2322 - _globals['_TENANTLASTCOMPACTIONTIME']._serialized_start=2324 - _globals['_TENANTLASTCOMPACTIONTIME']._serialized_end=2399 - _globals['_GETLASTCOMPACTIONTIMEFORTENANTRESPONSE']._serialized_start=2401 - _globals['_GETLASTCOMPACTIONTIMEFORTENANTRESPONSE']._serialized_end=2512 - _globals['_SETLASTCOMPACTIONTIMEFORTENANTREQUEST']._serialized_start=2514 - _globals['_SETLASTCOMPACTIONTIMEFORTENANTREQUEST']._serialized_end=2624 - _globals['_FLUSHSEGMENTCOMPACTIONINFO']._serialized_start=2627 - _globals['_FLUSHSEGMENTCOMPACTIONINFO']._serialized_end=2815 - _globals['_FLUSHSEGMENTCOMPACTIONINFO_FILEPATHSENTRY']._serialized_start=2748 - _globals['_FLUSHSEGMENTCOMPACTIONINFO_FILEPATHSENTRY']._serialized_end=2815 - _globals['_FLUSHCOLLECTIONCOMPACTIONREQUEST']._serialized_start=2818 - _globals['_FLUSHCOLLECTIONCOMPACTIONREQUEST']._serialized_end=3013 - _globals['_FLUSHCOLLECTIONCOMPACTIONRESPONSE']._serialized_start=3015 - _globals['_FLUSHCOLLECTIONCOMPACTIONRESPONSE']._serialized_end=3131 - _globals['_SYSDB']._serialized_start=3134 - _globals['_SYSDB']._serialized_end=4619 -# @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/coordinator_pb2.pyi b/chromadb/proto/coordinator_pb2.pyi deleted file mode 100644 index 4c10bb4a4bf..00000000000 --- a/chromadb/proto/coordinator_pb2.pyi +++ /dev/null @@ -1,281 +0,0 @@ -from chromadb.proto import chroma_pb2 as _chroma_pb2 -from google.protobuf import empty_pb2 as _empty_pb2 -from google.protobuf.internal import containers as _containers -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union - -DESCRIPTOR: _descriptor.FileDescriptor - -class CreateDatabaseRequest(_message.Message): - __slots__ = ["id", "name", "tenant"] - ID_FIELD_NUMBER: _ClassVar[int] - NAME_FIELD_NUMBER: _ClassVar[int] - TENANT_FIELD_NUMBER: _ClassVar[int] - id: str - name: str - tenant: str - def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., tenant: _Optional[str] = ...) -> None: ... 
- -class CreateDatabaseResponse(_message.Message): - __slots__ = [] - def __init__(self) -> None: ... - -class GetDatabaseRequest(_message.Message): - __slots__ = ["name", "tenant"] - NAME_FIELD_NUMBER: _ClassVar[int] - TENANT_FIELD_NUMBER: _ClassVar[int] - name: str - tenant: str - def __init__(self, name: _Optional[str] = ..., tenant: _Optional[str] = ...) -> None: ... - -class GetDatabaseResponse(_message.Message): - __slots__ = ["database"] - DATABASE_FIELD_NUMBER: _ClassVar[int] - database: _chroma_pb2.Database - def __init__(self, database: _Optional[_Union[_chroma_pb2.Database, _Mapping]] = ...) -> None: ... - -class CreateTenantRequest(_message.Message): - __slots__ = ["name"] - NAME_FIELD_NUMBER: _ClassVar[int] - name: str - def __init__(self, name: _Optional[str] = ...) -> None: ... - -class CreateTenantResponse(_message.Message): - __slots__ = [] - def __init__(self) -> None: ... - -class GetTenantRequest(_message.Message): - __slots__ = ["name"] - NAME_FIELD_NUMBER: _ClassVar[int] - name: str - def __init__(self, name: _Optional[str] = ...) -> None: ... - -class GetTenantResponse(_message.Message): - __slots__ = ["tenant"] - TENANT_FIELD_NUMBER: _ClassVar[int] - tenant: _chroma_pb2.Tenant - def __init__(self, tenant: _Optional[_Union[_chroma_pb2.Tenant, _Mapping]] = ...) -> None: ... - -class CreateSegmentRequest(_message.Message): - __slots__ = ["segment"] - SEGMENT_FIELD_NUMBER: _ClassVar[int] - segment: _chroma_pb2.Segment - def __init__(self, segment: _Optional[_Union[_chroma_pb2.Segment, _Mapping]] = ...) -> None: ... - -class CreateSegmentResponse(_message.Message): - __slots__ = [] - def __init__(self) -> None: ... - -class DeleteSegmentRequest(_message.Message): - __slots__ = ["id", "collection"] - ID_FIELD_NUMBER: _ClassVar[int] - COLLECTION_FIELD_NUMBER: _ClassVar[int] - id: str - collection: str - def __init__(self, id: _Optional[str] = ..., collection: _Optional[str] = ...) -> None: ... - -class DeleteSegmentResponse(_message.Message): - __slots__ = [] - def __init__(self) -> None: ... - -class GetSegmentsRequest(_message.Message): - __slots__ = ["id", "type", "scope", "collection"] - ID_FIELD_NUMBER: _ClassVar[int] - TYPE_FIELD_NUMBER: _ClassVar[int] - SCOPE_FIELD_NUMBER: _ClassVar[int] - COLLECTION_FIELD_NUMBER: _ClassVar[int] - id: str - type: str - scope: _chroma_pb2.SegmentScope - collection: str - def __init__(self, id: _Optional[str] = ..., type: _Optional[str] = ..., scope: _Optional[_Union[_chroma_pb2.SegmentScope, str]] = ..., collection: _Optional[str] = ...) -> None: ... - -class GetSegmentsResponse(_message.Message): - __slots__ = ["segments"] - SEGMENTS_FIELD_NUMBER: _ClassVar[int] - segments: _containers.RepeatedCompositeFieldContainer[_chroma_pb2.Segment] - def __init__(self, segments: _Optional[_Iterable[_Union[_chroma_pb2.Segment, _Mapping]]] = ...) -> None: ... - -class UpdateSegmentRequest(_message.Message): - __slots__ = ["id", "collection", "metadata", "reset_metadata"] - ID_FIELD_NUMBER: _ClassVar[int] - COLLECTION_FIELD_NUMBER: _ClassVar[int] - METADATA_FIELD_NUMBER: _ClassVar[int] - RESET_METADATA_FIELD_NUMBER: _ClassVar[int] - id: str - collection: str - metadata: _chroma_pb2.UpdateMetadata - reset_metadata: bool - def __init__(self, id: _Optional[str] = ..., collection: _Optional[str] = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., reset_metadata: bool = ...) -> None: ... - -class UpdateSegmentResponse(_message.Message): - __slots__ = [] - def __init__(self) -> None: ... 
- -class CreateCollectionRequest(_message.Message): - __slots__ = ["id", "name", "configuration_json_str", "metadata", "dimension", "get_or_create", "tenant", "database", "segments"] - ID_FIELD_NUMBER: _ClassVar[int] - NAME_FIELD_NUMBER: _ClassVar[int] - CONFIGURATION_JSON_STR_FIELD_NUMBER: _ClassVar[int] - METADATA_FIELD_NUMBER: _ClassVar[int] - DIMENSION_FIELD_NUMBER: _ClassVar[int] - GET_OR_CREATE_FIELD_NUMBER: _ClassVar[int] - TENANT_FIELD_NUMBER: _ClassVar[int] - DATABASE_FIELD_NUMBER: _ClassVar[int] - SEGMENTS_FIELD_NUMBER: _ClassVar[int] - id: str - name: str - configuration_json_str: str - metadata: _chroma_pb2.UpdateMetadata - dimension: int - get_or_create: bool - tenant: str - database: str - segments: _containers.RepeatedCompositeFieldContainer[_chroma_pb2.Segment] - def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., configuration_json_str: _Optional[str] = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., dimension: _Optional[int] = ..., get_or_create: bool = ..., tenant: _Optional[str] = ..., database: _Optional[str] = ..., segments: _Optional[_Iterable[_Union[_chroma_pb2.Segment, _Mapping]]] = ...) -> None: ... - -class CreateCollectionResponse(_message.Message): - __slots__ = ["collection", "created"] - COLLECTION_FIELD_NUMBER: _ClassVar[int] - CREATED_FIELD_NUMBER: _ClassVar[int] - collection: _chroma_pb2.Collection - created: bool - def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., created: bool = ...) -> None: ... - -class DeleteCollectionRequest(_message.Message): - __slots__ = ["id", "tenant", "database", "segment_ids"] - ID_FIELD_NUMBER: _ClassVar[int] - TENANT_FIELD_NUMBER: _ClassVar[int] - DATABASE_FIELD_NUMBER: _ClassVar[int] - SEGMENT_IDS_FIELD_NUMBER: _ClassVar[int] - id: str - tenant: str - database: str - segment_ids: _containers.RepeatedScalarFieldContainer[str] - def __init__(self, id: _Optional[str] = ..., tenant: _Optional[str] = ..., database: _Optional[str] = ..., segment_ids: _Optional[_Iterable[str]] = ...) -> None: ... - -class DeleteCollectionResponse(_message.Message): - __slots__ = [] - def __init__(self) -> None: ... - -class GetCollectionsRequest(_message.Message): - __slots__ = ["id", "name", "tenant", "database", "limit", "offset"] - ID_FIELD_NUMBER: _ClassVar[int] - NAME_FIELD_NUMBER: _ClassVar[int] - TENANT_FIELD_NUMBER: _ClassVar[int] - DATABASE_FIELD_NUMBER: _ClassVar[int] - LIMIT_FIELD_NUMBER: _ClassVar[int] - OFFSET_FIELD_NUMBER: _ClassVar[int] - id: str - name: str - tenant: str - database: str - limit: int - offset: int - def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., tenant: _Optional[str] = ..., database: _Optional[str] = ..., limit: _Optional[int] = ..., offset: _Optional[int] = ...) -> None: ... - -class GetCollectionsResponse(_message.Message): - __slots__ = ["collections"] - COLLECTIONS_FIELD_NUMBER: _ClassVar[int] - collections: _containers.RepeatedCompositeFieldContainer[_chroma_pb2.Collection] - def __init__(self, collections: _Optional[_Iterable[_Union[_chroma_pb2.Collection, _Mapping]]] = ...) -> None: ... - -class CheckCollectionsRequest(_message.Message): - __slots__ = ["collection_ids"] - COLLECTION_IDS_FIELD_NUMBER: _ClassVar[int] - collection_ids: _containers.RepeatedScalarFieldContainer[str] - def __init__(self, collection_ids: _Optional[_Iterable[str]] = ...) -> None: ... 
- -class CheckCollectionsResponse(_message.Message): - __slots__ = ["deleted"] - DELETED_FIELD_NUMBER: _ClassVar[int] - deleted: _containers.RepeatedScalarFieldContainer[bool] - def __init__(self, deleted: _Optional[_Iterable[bool]] = ...) -> None: ... - -class UpdateCollectionRequest(_message.Message): - __slots__ = ["id", "name", "dimension", "metadata", "reset_metadata"] - ID_FIELD_NUMBER: _ClassVar[int] - NAME_FIELD_NUMBER: _ClassVar[int] - DIMENSION_FIELD_NUMBER: _ClassVar[int] - METADATA_FIELD_NUMBER: _ClassVar[int] - RESET_METADATA_FIELD_NUMBER: _ClassVar[int] - id: str - name: str - dimension: int - metadata: _chroma_pb2.UpdateMetadata - reset_metadata: bool - def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., dimension: _Optional[int] = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ..., reset_metadata: bool = ...) -> None: ... - -class UpdateCollectionResponse(_message.Message): - __slots__ = [] - def __init__(self) -> None: ... - -class ResetStateResponse(_message.Message): - __slots__ = [] - def __init__(self) -> None: ... - -class GetLastCompactionTimeForTenantRequest(_message.Message): - __slots__ = ["tenant_id"] - TENANT_ID_FIELD_NUMBER: _ClassVar[int] - tenant_id: _containers.RepeatedScalarFieldContainer[str] - def __init__(self, tenant_id: _Optional[_Iterable[str]] = ...) -> None: ... - -class TenantLastCompactionTime(_message.Message): - __slots__ = ["tenant_id", "last_compaction_time"] - TENANT_ID_FIELD_NUMBER: _ClassVar[int] - LAST_COMPACTION_TIME_FIELD_NUMBER: _ClassVar[int] - tenant_id: str - last_compaction_time: int - def __init__(self, tenant_id: _Optional[str] = ..., last_compaction_time: _Optional[int] = ...) -> None: ... - -class GetLastCompactionTimeForTenantResponse(_message.Message): - __slots__ = ["tenant_last_compaction_time"] - TENANT_LAST_COMPACTION_TIME_FIELD_NUMBER: _ClassVar[int] - tenant_last_compaction_time: _containers.RepeatedCompositeFieldContainer[TenantLastCompactionTime] - def __init__(self, tenant_last_compaction_time: _Optional[_Iterable[_Union[TenantLastCompactionTime, _Mapping]]] = ...) -> None: ... - -class SetLastCompactionTimeForTenantRequest(_message.Message): - __slots__ = ["tenant_last_compaction_time"] - TENANT_LAST_COMPACTION_TIME_FIELD_NUMBER: _ClassVar[int] - tenant_last_compaction_time: TenantLastCompactionTime - def __init__(self, tenant_last_compaction_time: _Optional[_Union[TenantLastCompactionTime, _Mapping]] = ...) -> None: ... - -class FlushSegmentCompactionInfo(_message.Message): - __slots__ = ["segment_id", "file_paths"] - class FilePathsEntry(_message.Message): - __slots__ = ["key", "value"] - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: _chroma_pb2.FilePaths - def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[_chroma_pb2.FilePaths, _Mapping]] = ...) -> None: ... - SEGMENT_ID_FIELD_NUMBER: _ClassVar[int] - FILE_PATHS_FIELD_NUMBER: _ClassVar[int] - segment_id: str - file_paths: _containers.MessageMap[str, _chroma_pb2.FilePaths] - def __init__(self, segment_id: _Optional[str] = ..., file_paths: _Optional[_Mapping[str, _chroma_pb2.FilePaths]] = ...) -> None: ... 
- -class FlushCollectionCompactionRequest(_message.Message): - __slots__ = ["tenant_id", "collection_id", "log_position", "collection_version", "segment_compaction_info"] - TENANT_ID_FIELD_NUMBER: _ClassVar[int] - COLLECTION_ID_FIELD_NUMBER: _ClassVar[int] - LOG_POSITION_FIELD_NUMBER: _ClassVar[int] - COLLECTION_VERSION_FIELD_NUMBER: _ClassVar[int] - SEGMENT_COMPACTION_INFO_FIELD_NUMBER: _ClassVar[int] - tenant_id: str - collection_id: str - log_position: int - collection_version: int - segment_compaction_info: _containers.RepeatedCompositeFieldContainer[FlushSegmentCompactionInfo] - def __init__(self, tenant_id: _Optional[str] = ..., collection_id: _Optional[str] = ..., log_position: _Optional[int] = ..., collection_version: _Optional[int] = ..., segment_compaction_info: _Optional[_Iterable[_Union[FlushSegmentCompactionInfo, _Mapping]]] = ...) -> None: ... - -class FlushCollectionCompactionResponse(_message.Message): - __slots__ = ["collection_id", "collection_version", "last_compaction_time"] - COLLECTION_ID_FIELD_NUMBER: _ClassVar[int] - COLLECTION_VERSION_FIELD_NUMBER: _ClassVar[int] - LAST_COMPACTION_TIME_FIELD_NUMBER: _ClassVar[int] - collection_id: str - collection_version: int - last_compaction_time: int - def __init__(self, collection_id: _Optional[str] = ..., collection_version: _Optional[int] = ..., last_compaction_time: _Optional[int] = ...) -> None: ... diff --git a/chromadb/proto/coordinator_pb2_grpc.py b/chromadb/proto/coordinator_pb2_grpc.py deleted file mode 100644 index 557c834f24b..00000000000 --- a/chromadb/proto/coordinator_pb2_grpc.py +++ /dev/null @@ -1,595 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc - -from chromadb.proto import coordinator_pb2 as chromadb_dot_proto_dot_coordinator__pb2 -from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 - - -class SysDBStub(object): - """Missing associated documentation comment in .proto file.""" - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. 
- """ - self.CreateDatabase = channel.unary_unary( - '/chroma.SysDB/CreateDatabase', - request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseResponse.FromString, - ) - self.GetDatabase = channel.unary_unary( - '/chroma.SysDB/GetDatabase', - request_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetDatabaseRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetDatabaseResponse.FromString, - ) - self.CreateTenant = channel.unary_unary( - '/chroma.SysDB/CreateTenant', - request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateTenantRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateTenantResponse.FromString, - ) - self.GetTenant = channel.unary_unary( - '/chroma.SysDB/GetTenant', - request_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetTenantRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetTenantResponse.FromString, - ) - self.CreateSegment = channel.unary_unary( - '/chroma.SysDB/CreateSegment', - request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentResponse.FromString, - ) - self.DeleteSegment = channel.unary_unary( - '/chroma.SysDB/DeleteSegment', - request_serializer=chromadb_dot_proto_dot_coordinator__pb2.DeleteSegmentRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.DeleteSegmentResponse.FromString, - ) - self.GetSegments = channel.unary_unary( - '/chroma.SysDB/GetSegments', - request_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetSegmentsRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetSegmentsResponse.FromString, - ) - self.UpdateSegment = channel.unary_unary( - '/chroma.SysDB/UpdateSegment', - request_serializer=chromadb_dot_proto_dot_coordinator__pb2.UpdateSegmentRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.UpdateSegmentResponse.FromString, - ) - self.CreateCollection = channel.unary_unary( - '/chroma.SysDB/CreateCollection', - request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateCollectionRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateCollectionResponse.FromString, - ) - self.DeleteCollection = channel.unary_unary( - '/chroma.SysDB/DeleteCollection', - request_serializer=chromadb_dot_proto_dot_coordinator__pb2.DeleteCollectionRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.DeleteCollectionResponse.FromString, - ) - self.GetCollections = channel.unary_unary( - '/chroma.SysDB/GetCollections', - request_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetCollectionsRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetCollectionsResponse.FromString, - ) - self.CheckCollections = channel.unary_unary( - '/chroma.SysDB/CheckCollections', - request_serializer=chromadb_dot_proto_dot_coordinator__pb2.CheckCollectionsRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CheckCollectionsResponse.FromString, - ) - self.UpdateCollection = channel.unary_unary( - '/chroma.SysDB/UpdateCollection', - 
request_serializer=chromadb_dot_proto_dot_coordinator__pb2.UpdateCollectionRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.UpdateCollectionResponse.FromString, - ) - self.ResetState = channel.unary_unary( - '/chroma.SysDB/ResetState', - request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.ResetStateResponse.FromString, - ) - self.GetLastCompactionTimeForTenant = channel.unary_unary( - '/chroma.SysDB/GetLastCompactionTimeForTenant', - request_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetLastCompactionTimeForTenantRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetLastCompactionTimeForTenantResponse.FromString, - ) - self.SetLastCompactionTimeForTenant = channel.unary_unary( - '/chroma.SysDB/SetLastCompactionTimeForTenant', - request_serializer=chromadb_dot_proto_dot_coordinator__pb2.SetLastCompactionTimeForTenantRequest.SerializeToString, - response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - ) - self.FlushCollectionCompaction = channel.unary_unary( - '/chroma.SysDB/FlushCollectionCompaction', - request_serializer=chromadb_dot_proto_dot_coordinator__pb2.FlushCollectionCompactionRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_coordinator__pb2.FlushCollectionCompactionResponse.FromString, - ) - - -class SysDBServicer(object): - """Missing associated documentation comment in .proto file.""" - - def CreateDatabase(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetDatabase(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def CreateTenant(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetTenant(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def CreateSegment(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def DeleteSegment(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetSegments(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def UpdateSegment(self, request, context): - """Missing associated documentation comment in .proto file.""" - 
context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def CreateCollection(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def DeleteCollection(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetCollections(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def CheckCollections(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def UpdateCollection(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def ResetState(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetLastCompactionTimeForTenant(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def SetLastCompactionTimeForTenant(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def FlushCollectionCompaction(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_SysDBServicer_to_server(servicer, server): - rpc_method_handlers = { - 'CreateDatabase': grpc.unary_unary_rpc_method_handler( - servicer.CreateDatabase, - request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseRequest.FromString, - response_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseResponse.SerializeToString, - ), - 'GetDatabase': grpc.unary_unary_rpc_method_handler( - servicer.GetDatabase, - request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetDatabaseRequest.FromString, - response_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetDatabaseResponse.SerializeToString, - ), - 'CreateTenant': grpc.unary_unary_rpc_method_handler( - servicer.CreateTenant, - request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateTenantRequest.FromString, - response_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateTenantResponse.SerializeToString, - ), - 'GetTenant': 
grpc.unary_unary_rpc_method_handler( - servicer.GetTenant, - request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetTenantRequest.FromString, - response_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetTenantResponse.SerializeToString, - ), - 'CreateSegment': grpc.unary_unary_rpc_method_handler( - servicer.CreateSegment, - request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.FromString, - response_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentResponse.SerializeToString, - ), - 'DeleteSegment': grpc.unary_unary_rpc_method_handler( - servicer.DeleteSegment, - request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.DeleteSegmentRequest.FromString, - response_serializer=chromadb_dot_proto_dot_coordinator__pb2.DeleteSegmentResponse.SerializeToString, - ), - 'GetSegments': grpc.unary_unary_rpc_method_handler( - servicer.GetSegments, - request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetSegmentsRequest.FromString, - response_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetSegmentsResponse.SerializeToString, - ), - 'UpdateSegment': grpc.unary_unary_rpc_method_handler( - servicer.UpdateSegment, - request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.UpdateSegmentRequest.FromString, - response_serializer=chromadb_dot_proto_dot_coordinator__pb2.UpdateSegmentResponse.SerializeToString, - ), - 'CreateCollection': grpc.unary_unary_rpc_method_handler( - servicer.CreateCollection, - request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CreateCollectionRequest.FromString, - response_serializer=chromadb_dot_proto_dot_coordinator__pb2.CreateCollectionResponse.SerializeToString, - ), - 'DeleteCollection': grpc.unary_unary_rpc_method_handler( - servicer.DeleteCollection, - request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.DeleteCollectionRequest.FromString, - response_serializer=chromadb_dot_proto_dot_coordinator__pb2.DeleteCollectionResponse.SerializeToString, - ), - 'GetCollections': grpc.unary_unary_rpc_method_handler( - servicer.GetCollections, - request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetCollectionsRequest.FromString, - response_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetCollectionsResponse.SerializeToString, - ), - 'CheckCollections': grpc.unary_unary_rpc_method_handler( - servicer.CheckCollections, - request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.CheckCollectionsRequest.FromString, - response_serializer=chromadb_dot_proto_dot_coordinator__pb2.CheckCollectionsResponse.SerializeToString, - ), - 'UpdateCollection': grpc.unary_unary_rpc_method_handler( - servicer.UpdateCollection, - request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.UpdateCollectionRequest.FromString, - response_serializer=chromadb_dot_proto_dot_coordinator__pb2.UpdateCollectionResponse.SerializeToString, - ), - 'ResetState': grpc.unary_unary_rpc_method_handler( - servicer.ResetState, - request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - response_serializer=chromadb_dot_proto_dot_coordinator__pb2.ResetStateResponse.SerializeToString, - ), - 'GetLastCompactionTimeForTenant': grpc.unary_unary_rpc_method_handler( - servicer.GetLastCompactionTimeForTenant, - request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.GetLastCompactionTimeForTenantRequest.FromString, - response_serializer=chromadb_dot_proto_dot_coordinator__pb2.GetLastCompactionTimeForTenantResponse.SerializeToString, - ), - 'SetLastCompactionTimeForTenant': 
grpc.unary_unary_rpc_method_handler( - servicer.SetLastCompactionTimeForTenant, - request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.SetLastCompactionTimeForTenantRequest.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - ), - 'FlushCollectionCompaction': grpc.unary_unary_rpc_method_handler( - servicer.FlushCollectionCompaction, - request_deserializer=chromadb_dot_proto_dot_coordinator__pb2.FlushCollectionCompactionRequest.FromString, - response_serializer=chromadb_dot_proto_dot_coordinator__pb2.FlushCollectionCompactionResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'chroma.SysDB', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - - - # This class is part of an EXPERIMENTAL API. -class SysDB(object): - """Missing associated documentation comment in .proto file.""" - - @staticmethod - def CreateDatabase(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SysDB/CreateDatabase', - chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseRequest.SerializeToString, - chromadb_dot_proto_dot_coordinator__pb2.CreateDatabaseResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def GetDatabase(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SysDB/GetDatabase', - chromadb_dot_proto_dot_coordinator__pb2.GetDatabaseRequest.SerializeToString, - chromadb_dot_proto_dot_coordinator__pb2.GetDatabaseResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def CreateTenant(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SysDB/CreateTenant', - chromadb_dot_proto_dot_coordinator__pb2.CreateTenantRequest.SerializeToString, - chromadb_dot_proto_dot_coordinator__pb2.CreateTenantResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def GetTenant(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SysDB/GetTenant', - chromadb_dot_proto_dot_coordinator__pb2.GetTenantRequest.SerializeToString, - chromadb_dot_proto_dot_coordinator__pb2.GetTenantResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def CreateSegment(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SysDB/CreateSegment', - 
chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentRequest.SerializeToString, - chromadb_dot_proto_dot_coordinator__pb2.CreateSegmentResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def DeleteSegment(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SysDB/DeleteSegment', - chromadb_dot_proto_dot_coordinator__pb2.DeleteSegmentRequest.SerializeToString, - chromadb_dot_proto_dot_coordinator__pb2.DeleteSegmentResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def GetSegments(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SysDB/GetSegments', - chromadb_dot_proto_dot_coordinator__pb2.GetSegmentsRequest.SerializeToString, - chromadb_dot_proto_dot_coordinator__pb2.GetSegmentsResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def UpdateSegment(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SysDB/UpdateSegment', - chromadb_dot_proto_dot_coordinator__pb2.UpdateSegmentRequest.SerializeToString, - chromadb_dot_proto_dot_coordinator__pb2.UpdateSegmentResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def CreateCollection(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SysDB/CreateCollection', - chromadb_dot_proto_dot_coordinator__pb2.CreateCollectionRequest.SerializeToString, - chromadb_dot_proto_dot_coordinator__pb2.CreateCollectionResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def DeleteCollection(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SysDB/DeleteCollection', - chromadb_dot_proto_dot_coordinator__pb2.DeleteCollectionRequest.SerializeToString, - chromadb_dot_proto_dot_coordinator__pb2.DeleteCollectionResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def GetCollections(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SysDB/GetCollections', - chromadb_dot_proto_dot_coordinator__pb2.GetCollectionsRequest.SerializeToString, - 
chromadb_dot_proto_dot_coordinator__pb2.GetCollectionsResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def CheckCollections(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SysDB/CheckCollections', - chromadb_dot_proto_dot_coordinator__pb2.CheckCollectionsRequest.SerializeToString, - chromadb_dot_proto_dot_coordinator__pb2.CheckCollectionsResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def UpdateCollection(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SysDB/UpdateCollection', - chromadb_dot_proto_dot_coordinator__pb2.UpdateCollectionRequest.SerializeToString, - chromadb_dot_proto_dot_coordinator__pb2.UpdateCollectionResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def ResetState(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SysDB/ResetState', - google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - chromadb_dot_proto_dot_coordinator__pb2.ResetStateResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def GetLastCompactionTimeForTenant(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SysDB/GetLastCompactionTimeForTenant', - chromadb_dot_proto_dot_coordinator__pb2.GetLastCompactionTimeForTenantRequest.SerializeToString, - chromadb_dot_proto_dot_coordinator__pb2.GetLastCompactionTimeForTenantResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def SetLastCompactionTimeForTenant(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SysDB/SetLastCompactionTimeForTenant', - chromadb_dot_proto_dot_coordinator__pb2.SetLastCompactionTimeForTenantRequest.SerializeToString, - google_dot_protobuf_dot_empty__pb2.Empty.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def FlushCollectionCompaction(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.SysDB/FlushCollectionCompaction', - 
chromadb_dot_proto_dot_coordinator__pb2.FlushCollectionCompactionRequest.SerializeToString, - chromadb_dot_proto_dot_coordinator__pb2.FlushCollectionCompactionResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/chromadb/proto/logservice_pb2.py b/chromadb/proto/logservice_pb2.py deleted file mode 100644 index 51bfb042cd1..00000000000 --- a/chromadb/proto/logservice_pb2.py +++ /dev/null @@ -1,48 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: chromadb/proto/logservice.proto -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from chromadb.proto import chroma_pb2 as chromadb_dot_proto_dot_chroma__pb2 - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1f\x63hromadb/proto/logservice.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\"R\n\x0fPushLogsRequest\x12\x15\n\rcollection_id\x18\x01 \x01(\t\x12(\n\x07records\x18\x02 \x03(\x0b\x32\x17.chroma.OperationRecord\"(\n\x10PushLogsResponse\x12\x14\n\x0crecord_count\x18\x01 \x01(\x05\"n\n\x0fPullLogsRequest\x12\x15\n\rcollection_id\x18\x01 \x01(\t\x12\x19\n\x11start_from_offset\x18\x02 \x01(\x03\x12\x12\n\nbatch_size\x18\x03 \x01(\x05\x12\x15\n\rend_timestamp\x18\x04 \x01(\x03\"H\n\tLogRecord\x12\x12\n\nlog_offset\x18\x01 \x01(\x03\x12\'\n\x06record\x18\x02 \x01(\x0b\x32\x17.chroma.OperationRecord\"6\n\x10PullLogsResponse\x12\"\n\x07records\x18\x01 \x03(\x0b\x32\x11.chroma.LogRecord\"W\n\x0e\x43ollectionInfo\x12\x15\n\rcollection_id\x18\x01 \x01(\t\x12\x18\n\x10\x66irst_log_offset\x18\x02 \x01(\x03\x12\x14\n\x0c\x66irst_log_ts\x18\x03 \x01(\x03\"C\n$GetAllCollectionInfoToCompactRequest\x12\x1b\n\x13min_compaction_size\x18\x01 \x01(\x04\"\\\n%GetAllCollectionInfoToCompactResponse\x12\x33\n\x13\x61ll_collection_info\x18\x01 \x03(\x0b\x32\x16.chroma.CollectionInfo\"M\n UpdateCollectionLogOffsetRequest\x12\x15\n\rcollection_id\x18\x01 \x01(\t\x12\x12\n\nlog_offset\x18\x02 \x01(\x03\"#\n!UpdateCollectionLogOffsetResponse2\x82\x03\n\nLogService\x12?\n\x08PushLogs\x12\x17.chroma.PushLogsRequest\x1a\x18.chroma.PushLogsResponse\"\x00\x12?\n\x08PullLogs\x12\x17.chroma.PullLogsRequest\x1a\x18.chroma.PullLogsResponse\"\x00\x12~\n\x1dGetAllCollectionInfoToCompact\x12,.chroma.GetAllCollectionInfoToCompactRequest\x1a-.chroma.GetAllCollectionInfoToCompactResponse\"\x00\x12r\n\x19UpdateCollectionLogOffset\x12(.chroma.UpdateCollectionLogOffsetRequest\x1a).chroma.UpdateCollectionLogOffsetResponse\"\x00\x42\x39Z7github.com/chroma-core/chroma/go/pkg/proto/logservicepbb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'chromadb.proto.logservice_pb2', _globals) -if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'Z7github.com/chroma-core/chroma/go/pkg/proto/logservicepb' - _globals['_PUSHLOGSREQUEST']._serialized_start=72 - _globals['_PUSHLOGSREQUEST']._serialized_end=154 - _globals['_PUSHLOGSRESPONSE']._serialized_start=156 - _globals['_PUSHLOGSRESPONSE']._serialized_end=196 - _globals['_PULLLOGSREQUEST']._serialized_start=198 - 
_globals['_PULLLOGSREQUEST']._serialized_end=308 - _globals['_LOGRECORD']._serialized_start=310 - _globals['_LOGRECORD']._serialized_end=382 - _globals['_PULLLOGSRESPONSE']._serialized_start=384 - _globals['_PULLLOGSRESPONSE']._serialized_end=438 - _globals['_COLLECTIONINFO']._serialized_start=440 - _globals['_COLLECTIONINFO']._serialized_end=527 - _globals['_GETALLCOLLECTIONINFOTOCOMPACTREQUEST']._serialized_start=529 - _globals['_GETALLCOLLECTIONINFOTOCOMPACTREQUEST']._serialized_end=596 - _globals['_GETALLCOLLECTIONINFOTOCOMPACTRESPONSE']._serialized_start=598 - _globals['_GETALLCOLLECTIONINFOTOCOMPACTRESPONSE']._serialized_end=690 - _globals['_UPDATECOLLECTIONLOGOFFSETREQUEST']._serialized_start=692 - _globals['_UPDATECOLLECTIONLOGOFFSETREQUEST']._serialized_end=769 - _globals['_UPDATECOLLECTIONLOGOFFSETRESPONSE']._serialized_start=771 - _globals['_UPDATECOLLECTIONLOGOFFSETRESPONSE']._serialized_end=806 - _globals['_LOGSERVICE']._serialized_start=809 - _globals['_LOGSERVICE']._serialized_end=1195 -# @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/logservice_pb2.pyi b/chromadb/proto/logservice_pb2.pyi deleted file mode 100644 index a1d8f77b06e..00000000000 --- a/chromadb/proto/logservice_pb2.pyi +++ /dev/null @@ -1,81 +0,0 @@ -from chromadb.proto import chroma_pb2 as _chroma_pb2 -from google.protobuf.internal import containers as _containers -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union - -DESCRIPTOR: _descriptor.FileDescriptor - -class PushLogsRequest(_message.Message): - __slots__ = ["collection_id", "records"] - COLLECTION_ID_FIELD_NUMBER: _ClassVar[int] - RECORDS_FIELD_NUMBER: _ClassVar[int] - collection_id: str - records: _containers.RepeatedCompositeFieldContainer[_chroma_pb2.OperationRecord] - def __init__(self, collection_id: _Optional[str] = ..., records: _Optional[_Iterable[_Union[_chroma_pb2.OperationRecord, _Mapping]]] = ...) -> None: ... - -class PushLogsResponse(_message.Message): - __slots__ = ["record_count"] - RECORD_COUNT_FIELD_NUMBER: _ClassVar[int] - record_count: int - def __init__(self, record_count: _Optional[int] = ...) -> None: ... - -class PullLogsRequest(_message.Message): - __slots__ = ["collection_id", "start_from_offset", "batch_size", "end_timestamp"] - COLLECTION_ID_FIELD_NUMBER: _ClassVar[int] - START_FROM_OFFSET_FIELD_NUMBER: _ClassVar[int] - BATCH_SIZE_FIELD_NUMBER: _ClassVar[int] - END_TIMESTAMP_FIELD_NUMBER: _ClassVar[int] - collection_id: str - start_from_offset: int - batch_size: int - end_timestamp: int - def __init__(self, collection_id: _Optional[str] = ..., start_from_offset: _Optional[int] = ..., batch_size: _Optional[int] = ..., end_timestamp: _Optional[int] = ...) -> None: ... - -class LogRecord(_message.Message): - __slots__ = ["log_offset", "record"] - LOG_OFFSET_FIELD_NUMBER: _ClassVar[int] - RECORD_FIELD_NUMBER: _ClassVar[int] - log_offset: int - record: _chroma_pb2.OperationRecord - def __init__(self, log_offset: _Optional[int] = ..., record: _Optional[_Union[_chroma_pb2.OperationRecord, _Mapping]] = ...) -> None: ... - -class PullLogsResponse(_message.Message): - __slots__ = ["records"] - RECORDS_FIELD_NUMBER: _ClassVar[int] - records: _containers.RepeatedCompositeFieldContainer[LogRecord] - def __init__(self, records: _Optional[_Iterable[_Union[LogRecord, _Mapping]]] = ...) -> None: ... 
- -class CollectionInfo(_message.Message): - __slots__ = ["collection_id", "first_log_offset", "first_log_ts"] - COLLECTION_ID_FIELD_NUMBER: _ClassVar[int] - FIRST_LOG_OFFSET_FIELD_NUMBER: _ClassVar[int] - FIRST_LOG_TS_FIELD_NUMBER: _ClassVar[int] - collection_id: str - first_log_offset: int - first_log_ts: int - def __init__(self, collection_id: _Optional[str] = ..., first_log_offset: _Optional[int] = ..., first_log_ts: _Optional[int] = ...) -> None: ... - -class GetAllCollectionInfoToCompactRequest(_message.Message): - __slots__ = ["min_compaction_size"] - MIN_COMPACTION_SIZE_FIELD_NUMBER: _ClassVar[int] - min_compaction_size: int - def __init__(self, min_compaction_size: _Optional[int] = ...) -> None: ... - -class GetAllCollectionInfoToCompactResponse(_message.Message): - __slots__ = ["all_collection_info"] - ALL_COLLECTION_INFO_FIELD_NUMBER: _ClassVar[int] - all_collection_info: _containers.RepeatedCompositeFieldContainer[CollectionInfo] - def __init__(self, all_collection_info: _Optional[_Iterable[_Union[CollectionInfo, _Mapping]]] = ...) -> None: ... - -class UpdateCollectionLogOffsetRequest(_message.Message): - __slots__ = ["collection_id", "log_offset"] - COLLECTION_ID_FIELD_NUMBER: _ClassVar[int] - LOG_OFFSET_FIELD_NUMBER: _ClassVar[int] - collection_id: str - log_offset: int - def __init__(self, collection_id: _Optional[str] = ..., log_offset: _Optional[int] = ...) -> None: ... - -class UpdateCollectionLogOffsetResponse(_message.Message): - __slots__ = [] - def __init__(self) -> None: ... diff --git a/chromadb/proto/logservice_pb2_grpc.py b/chromadb/proto/logservice_pb2_grpc.py deleted file mode 100644 index 1044460c60d..00000000000 --- a/chromadb/proto/logservice_pb2_grpc.py +++ /dev/null @@ -1,165 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc - -from chromadb.proto import logservice_pb2 as chromadb_dot_proto_dot_logservice__pb2 - - -class LogServiceStub(object): - """Missing associated documentation comment in .proto file.""" - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. 
- """ - self.PushLogs = channel.unary_unary( - '/chroma.LogService/PushLogs', - request_serializer=chromadb_dot_proto_dot_logservice__pb2.PushLogsRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_logservice__pb2.PushLogsResponse.FromString, - ) - self.PullLogs = channel.unary_unary( - '/chroma.LogService/PullLogs', - request_serializer=chromadb_dot_proto_dot_logservice__pb2.PullLogsRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_logservice__pb2.PullLogsResponse.FromString, - ) - self.GetAllCollectionInfoToCompact = channel.unary_unary( - '/chroma.LogService/GetAllCollectionInfoToCompact', - request_serializer=chromadb_dot_proto_dot_logservice__pb2.GetAllCollectionInfoToCompactRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_logservice__pb2.GetAllCollectionInfoToCompactResponse.FromString, - ) - self.UpdateCollectionLogOffset = channel.unary_unary( - '/chroma.LogService/UpdateCollectionLogOffset', - request_serializer=chromadb_dot_proto_dot_logservice__pb2.UpdateCollectionLogOffsetRequest.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_logservice__pb2.UpdateCollectionLogOffsetResponse.FromString, - ) - - -class LogServiceServicer(object): - """Missing associated documentation comment in .proto file.""" - - def PushLogs(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def PullLogs(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetAllCollectionInfoToCompact(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def UpdateCollectionLogOffset(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_LogServiceServicer_to_server(servicer, server): - rpc_method_handlers = { - 'PushLogs': grpc.unary_unary_rpc_method_handler( - servicer.PushLogs, - request_deserializer=chromadb_dot_proto_dot_logservice__pb2.PushLogsRequest.FromString, - response_serializer=chromadb_dot_proto_dot_logservice__pb2.PushLogsResponse.SerializeToString, - ), - 'PullLogs': grpc.unary_unary_rpc_method_handler( - servicer.PullLogs, - request_deserializer=chromadb_dot_proto_dot_logservice__pb2.PullLogsRequest.FromString, - response_serializer=chromadb_dot_proto_dot_logservice__pb2.PullLogsResponse.SerializeToString, - ), - 'GetAllCollectionInfoToCompact': grpc.unary_unary_rpc_method_handler( - servicer.GetAllCollectionInfoToCompact, - request_deserializer=chromadb_dot_proto_dot_logservice__pb2.GetAllCollectionInfoToCompactRequest.FromString, - response_serializer=chromadb_dot_proto_dot_logservice__pb2.GetAllCollectionInfoToCompactResponse.SerializeToString, - ), - 'UpdateCollectionLogOffset': grpc.unary_unary_rpc_method_handler( - servicer.UpdateCollectionLogOffset, - 
request_deserializer=chromadb_dot_proto_dot_logservice__pb2.UpdateCollectionLogOffsetRequest.FromString, - response_serializer=chromadb_dot_proto_dot_logservice__pb2.UpdateCollectionLogOffsetResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'chroma.LogService', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - - - # This class is part of an EXPERIMENTAL API. -class LogService(object): - """Missing associated documentation comment in .proto file.""" - - @staticmethod - def PushLogs(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.LogService/PushLogs', - chromadb_dot_proto_dot_logservice__pb2.PushLogsRequest.SerializeToString, - chromadb_dot_proto_dot_logservice__pb2.PushLogsResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def PullLogs(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.LogService/PullLogs', - chromadb_dot_proto_dot_logservice__pb2.PullLogsRequest.SerializeToString, - chromadb_dot_proto_dot_logservice__pb2.PullLogsResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def GetAllCollectionInfoToCompact(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.LogService/GetAllCollectionInfoToCompact', - chromadb_dot_proto_dot_logservice__pb2.GetAllCollectionInfoToCompactRequest.SerializeToString, - chromadb_dot_proto_dot_logservice__pb2.GetAllCollectionInfoToCompactResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def UpdateCollectionLogOffset(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.LogService/UpdateCollectionLogOffset', - chromadb_dot_proto_dot_logservice__pb2.UpdateCollectionLogOffsetRequest.SerializeToString, - chromadb_dot_proto_dot_logservice__pb2.UpdateCollectionLogOffsetResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/chromadb/proto/query_executor_pb2.py b/chromadb/proto/query_executor_pb2.py deleted file mode 100644 index d89a21c189a..00000000000 --- a/chromadb/proto/query_executor_pb2.py +++ /dev/null @@ -1,57 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: chromadb/proto/query_executor.proto -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from chromadb.proto import chroma_pb2 as chromadb_dot_proto_dot_chroma__pb2 - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n#chromadb/proto/query_executor.proto\x12\x06\x63hroma\x1a\x1b\x63hromadb/proto/chroma.proto\"n\n\x0cScanOperator\x12&\n\ncollection\x18\x01 \x01(\x0b\x32\x12.chroma.Collection\x12\x0e\n\x06knn_id\x18\x02 \x01(\t\x12\x13\n\x0bmetadata_id\x18\x03 \x01(\t\x12\x11\n\trecord_id\x18\x04 \x01(\t\"\xaf\x01\n\x0e\x46ilterOperator\x12!\n\x03ids\x18\x01 \x01(\x0b\x32\x0f.chroma.UserIdsH\x00\x88\x01\x01\x12!\n\x05where\x18\x02 \x01(\x0b\x32\r.chroma.WhereH\x01\x88\x01\x01\x12\x32\n\x0ewhere_document\x18\x03 \x01(\x0b\x32\x15.chroma.WhereDocumentH\x02\x88\x01\x01\x42\x06\n\x04_idsB\x08\n\x06_whereB\x11\n\x0f_where_document\"@\n\x0bKNNOperator\x12\"\n\nembeddings\x18\x01 \x03(\x0b\x32\x0e.chroma.Vector\x12\r\n\x05\x66\x65tch\x18\x02 \x01(\r\";\n\rLimitOperator\x12\x0c\n\x04skip\x18\x01 \x01(\r\x12\x12\n\x05\x66\x65tch\x18\x02 \x01(\rH\x00\x88\x01\x01\x42\x08\n\x06_fetch\"K\n\x12ProjectionOperator\x12\x10\n\x08\x64ocument\x18\x01 \x01(\x08\x12\x11\n\tembedding\x18\x02 \x01(\x08\x12\x10\n\x08metadata\x18\x03 \x01(\x08\"Y\n\x15KNNProjectionOperator\x12.\n\nprojection\x18\x01 \x01(\x0b\x32\x1a.chroma.ProjectionOperator\x12\x10\n\x08\x64istance\x18\x02 \x01(\x08\"/\n\tCountPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\"\x1c\n\x0b\x43ountResult\x12\r\n\x05\x63ount\x18\x01 \x01(\r\"\xab\x01\n\x07GetPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\x12&\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\x16.chroma.FilterOperator\x12$\n\x05limit\x18\x03 \x01(\x0b\x32\x15.chroma.LimitOperator\x12.\n\nprojection\x18\x04 \x01(\x0b\x32\x1a.chroma.ProjectionOperator\"\xb4\x01\n\x10ProjectionRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x08\x64ocument\x18\x02 \x01(\tH\x00\x88\x01\x01\x12&\n\tembedding\x18\x03 \x01(\x0b\x32\x0e.chroma.VectorH\x01\x88\x01\x01\x12-\n\x08metadata\x18\x04 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x02\x88\x01\x01\x42\x0b\n\t_documentB\x0c\n\n_embeddingB\x0b\n\t_metadata\"6\n\tGetResult\x12)\n\x07records\x18\x01 \x03(\x0b\x32\x18.chroma.ProjectionRecord\"\xaa\x01\n\x07KNNPlan\x12\"\n\x04scan\x18\x01 \x01(\x0b\x32\x14.chroma.ScanOperator\x12&\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\x16.chroma.FilterOperator\x12 \n\x03knn\x18\x03 \x01(\x0b\x32\x13.chroma.KNNOperator\x12\x31\n\nprojection\x18\x04 \x01(\x0b\x32\x1d.chroma.KNNProjectionOperator\"c\n\x13KNNProjectionRecord\x12(\n\x06record\x18\x01 \x01(\x0b\x32\x18.chroma.ProjectionRecord\x12\x15\n\x08\x64istance\x18\x02 \x01(\x02H\x00\x88\x01\x01\x42\x0b\n\t_distance\"9\n\tKNNResult\x12,\n\x07records\x18\x01 \x03(\x0b\x32\x1b.chroma.KNNProjectionRecord\"4\n\x0eKNNBatchResult\x12\"\n\x07results\x18\x01 \x03(\x0b\x32\x11.chroma.KNNResult2\xa1\x01\n\rQueryExecutor\x12\x31\n\x05\x43ount\x12\x11.chroma.CountPlan\x1a\x13.chroma.CountResult\"\x00\x12+\n\x03Get\x12\x0f.chroma.GetPlan\x1a\x11.chroma.GetResult\"\x00\x12\x30\n\x03KNN\x12\x0f.chroma.KNNPlan\x1a\x16.chroma.KNNBatchResult\"\x00\x62\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'chromadb.proto.query_executor_pb2', _globals) -if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - _globals['_SCANOPERATOR']._serialized_start=76 - _globals['_SCANOPERATOR']._serialized_end=186 - _globals['_FILTEROPERATOR']._serialized_start=189 - _globals['_FILTEROPERATOR']._serialized_end=364 - _globals['_KNNOPERATOR']._serialized_start=366 - _globals['_KNNOPERATOR']._serialized_end=430 - _globals['_LIMITOPERATOR']._serialized_start=432 - _globals['_LIMITOPERATOR']._serialized_end=491 - _globals['_PROJECTIONOPERATOR']._serialized_start=493 - _globals['_PROJECTIONOPERATOR']._serialized_end=568 - _globals['_KNNPROJECTIONOPERATOR']._serialized_start=570 - _globals['_KNNPROJECTIONOPERATOR']._serialized_end=659 - _globals['_COUNTPLAN']._serialized_start=661 - _globals['_COUNTPLAN']._serialized_end=708 - _globals['_COUNTRESULT']._serialized_start=710 - _globals['_COUNTRESULT']._serialized_end=738 - _globals['_GETPLAN']._serialized_start=741 - _globals['_GETPLAN']._serialized_end=912 - _globals['_PROJECTIONRECORD']._serialized_start=915 - _globals['_PROJECTIONRECORD']._serialized_end=1095 - _globals['_GETRESULT']._serialized_start=1097 - _globals['_GETRESULT']._serialized_end=1151 - _globals['_KNNPLAN']._serialized_start=1154 - _globals['_KNNPLAN']._serialized_end=1324 - _globals['_KNNPROJECTIONRECORD']._serialized_start=1326 - _globals['_KNNPROJECTIONRECORD']._serialized_end=1425 - _globals['_KNNRESULT']._serialized_start=1427 - _globals['_KNNRESULT']._serialized_end=1484 - _globals['_KNNBATCHRESULT']._serialized_start=1486 - _globals['_KNNBATCHRESULT']._serialized_end=1538 - _globals['_QUERYEXECUTOR']._serialized_start=1541 - _globals['_QUERYEXECUTOR']._serialized_end=1702 -# @@protoc_insertion_point(module_scope) diff --git a/chromadb/proto/query_executor_pb2.pyi b/chromadb/proto/query_executor_pb2.pyi deleted file mode 100644 index 53483f8445f..00000000000 --- a/chromadb/proto/query_executor_pb2.pyi +++ /dev/null @@ -1,137 +0,0 @@ -from chromadb.proto import chroma_pb2 as _chroma_pb2 -from google.protobuf.internal import containers as _containers -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union - -DESCRIPTOR: _descriptor.FileDescriptor - -class ScanOperator(_message.Message): - __slots__ = ["collection", "knn_id", "metadata_id", "record_id"] - COLLECTION_FIELD_NUMBER: _ClassVar[int] - KNN_ID_FIELD_NUMBER: _ClassVar[int] - METADATA_ID_FIELD_NUMBER: _ClassVar[int] - RECORD_ID_FIELD_NUMBER: _ClassVar[int] - collection: _chroma_pb2.Collection - knn_id: str - metadata_id: str - record_id: str - def __init__(self, collection: _Optional[_Union[_chroma_pb2.Collection, _Mapping]] = ..., knn_id: _Optional[str] = ..., metadata_id: _Optional[str] = ..., record_id: _Optional[str] = ...) -> None: ... - -class FilterOperator(_message.Message): - __slots__ = ["ids", "where", "where_document"] - IDS_FIELD_NUMBER: _ClassVar[int] - WHERE_FIELD_NUMBER: _ClassVar[int] - WHERE_DOCUMENT_FIELD_NUMBER: _ClassVar[int] - ids: _chroma_pb2.UserIds - where: _chroma_pb2.Where - where_document: _chroma_pb2.WhereDocument - def __init__(self, ids: _Optional[_Union[_chroma_pb2.UserIds, _Mapping]] = ..., where: _Optional[_Union[_chroma_pb2.Where, _Mapping]] = ..., where_document: _Optional[_Union[_chroma_pb2.WhereDocument, _Mapping]] = ...) -> None: ... 
- -class KNNOperator(_message.Message): - __slots__ = ["embeddings", "fetch"] - EMBEDDINGS_FIELD_NUMBER: _ClassVar[int] - FETCH_FIELD_NUMBER: _ClassVar[int] - embeddings: _containers.RepeatedCompositeFieldContainer[_chroma_pb2.Vector] - fetch: int - def __init__(self, embeddings: _Optional[_Iterable[_Union[_chroma_pb2.Vector, _Mapping]]] = ..., fetch: _Optional[int] = ...) -> None: ... - -class LimitOperator(_message.Message): - __slots__ = ["skip", "fetch"] - SKIP_FIELD_NUMBER: _ClassVar[int] - FETCH_FIELD_NUMBER: _ClassVar[int] - skip: int - fetch: int - def __init__(self, skip: _Optional[int] = ..., fetch: _Optional[int] = ...) -> None: ... - -class ProjectionOperator(_message.Message): - __slots__ = ["document", "embedding", "metadata"] - DOCUMENT_FIELD_NUMBER: _ClassVar[int] - EMBEDDING_FIELD_NUMBER: _ClassVar[int] - METADATA_FIELD_NUMBER: _ClassVar[int] - document: bool - embedding: bool - metadata: bool - def __init__(self, document: bool = ..., embedding: bool = ..., metadata: bool = ...) -> None: ... - -class KNNProjectionOperator(_message.Message): - __slots__ = ["projection", "distance"] - PROJECTION_FIELD_NUMBER: _ClassVar[int] - DISTANCE_FIELD_NUMBER: _ClassVar[int] - projection: ProjectionOperator - distance: bool - def __init__(self, projection: _Optional[_Union[ProjectionOperator, _Mapping]] = ..., distance: bool = ...) -> None: ... - -class CountPlan(_message.Message): - __slots__ = ["scan"] - SCAN_FIELD_NUMBER: _ClassVar[int] - scan: ScanOperator - def __init__(self, scan: _Optional[_Union[ScanOperator, _Mapping]] = ...) -> None: ... - -class CountResult(_message.Message): - __slots__ = ["count"] - COUNT_FIELD_NUMBER: _ClassVar[int] - count: int - def __init__(self, count: _Optional[int] = ...) -> None: ... - -class GetPlan(_message.Message): - __slots__ = ["scan", "filter", "limit", "projection"] - SCAN_FIELD_NUMBER: _ClassVar[int] - FILTER_FIELD_NUMBER: _ClassVar[int] - LIMIT_FIELD_NUMBER: _ClassVar[int] - PROJECTION_FIELD_NUMBER: _ClassVar[int] - scan: ScanOperator - filter: FilterOperator - limit: LimitOperator - projection: ProjectionOperator - def __init__(self, scan: _Optional[_Union[ScanOperator, _Mapping]] = ..., filter: _Optional[_Union[FilterOperator, _Mapping]] = ..., limit: _Optional[_Union[LimitOperator, _Mapping]] = ..., projection: _Optional[_Union[ProjectionOperator, _Mapping]] = ...) -> None: ... - -class ProjectionRecord(_message.Message): - __slots__ = ["id", "document", "embedding", "metadata"] - ID_FIELD_NUMBER: _ClassVar[int] - DOCUMENT_FIELD_NUMBER: _ClassVar[int] - EMBEDDING_FIELD_NUMBER: _ClassVar[int] - METADATA_FIELD_NUMBER: _ClassVar[int] - id: str - document: str - embedding: _chroma_pb2.Vector - metadata: _chroma_pb2.UpdateMetadata - def __init__(self, id: _Optional[str] = ..., document: _Optional[str] = ..., embedding: _Optional[_Union[_chroma_pb2.Vector, _Mapping]] = ..., metadata: _Optional[_Union[_chroma_pb2.UpdateMetadata, _Mapping]] = ...) -> None: ... - -class GetResult(_message.Message): - __slots__ = ["records"] - RECORDS_FIELD_NUMBER: _ClassVar[int] - records: _containers.RepeatedCompositeFieldContainer[ProjectionRecord] - def __init__(self, records: _Optional[_Iterable[_Union[ProjectionRecord, _Mapping]]] = ...) -> None: ... 
- -class KNNPlan(_message.Message): - __slots__ = ["scan", "filter", "knn", "projection"] - SCAN_FIELD_NUMBER: _ClassVar[int] - FILTER_FIELD_NUMBER: _ClassVar[int] - KNN_FIELD_NUMBER: _ClassVar[int] - PROJECTION_FIELD_NUMBER: _ClassVar[int] - scan: ScanOperator - filter: FilterOperator - knn: KNNOperator - projection: KNNProjectionOperator - def __init__(self, scan: _Optional[_Union[ScanOperator, _Mapping]] = ..., filter: _Optional[_Union[FilterOperator, _Mapping]] = ..., knn: _Optional[_Union[KNNOperator, _Mapping]] = ..., projection: _Optional[_Union[KNNProjectionOperator, _Mapping]] = ...) -> None: ... - -class KNNProjectionRecord(_message.Message): - __slots__ = ["record", "distance"] - RECORD_FIELD_NUMBER: _ClassVar[int] - DISTANCE_FIELD_NUMBER: _ClassVar[int] - record: ProjectionRecord - distance: float - def __init__(self, record: _Optional[_Union[ProjectionRecord, _Mapping]] = ..., distance: _Optional[float] = ...) -> None: ... - -class KNNResult(_message.Message): - __slots__ = ["records"] - RECORDS_FIELD_NUMBER: _ClassVar[int] - records: _containers.RepeatedCompositeFieldContainer[KNNProjectionRecord] - def __init__(self, records: _Optional[_Iterable[_Union[KNNProjectionRecord, _Mapping]]] = ...) -> None: ... - -class KNNBatchResult(_message.Message): - __slots__ = ["results"] - RESULTS_FIELD_NUMBER: _ClassVar[int] - results: _containers.RepeatedCompositeFieldContainer[KNNResult] - def __init__(self, results: _Optional[_Iterable[_Union[KNNResult, _Mapping]]] = ...) -> None: ... diff --git a/chromadb/proto/query_executor_pb2_grpc.py b/chromadb/proto/query_executor_pb2_grpc.py deleted file mode 100644 index d1333518dfa..00000000000 --- a/chromadb/proto/query_executor_pb2_grpc.py +++ /dev/null @@ -1,132 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc - -from chromadb.proto import query_executor_pb2 as chromadb_dot_proto_dot_query__executor__pb2 - - -class QueryExecutorStub(object): - """Missing associated documentation comment in .proto file.""" - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. 
- """ - self.Count = channel.unary_unary( - '/chroma.QueryExecutor/Count', - request_serializer=chromadb_dot_proto_dot_query__executor__pb2.CountPlan.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_query__executor__pb2.CountResult.FromString, - ) - self.Get = channel.unary_unary( - '/chroma.QueryExecutor/Get', - request_serializer=chromadb_dot_proto_dot_query__executor__pb2.GetPlan.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_query__executor__pb2.GetResult.FromString, - ) - self.KNN = channel.unary_unary( - '/chroma.QueryExecutor/KNN', - request_serializer=chromadb_dot_proto_dot_query__executor__pb2.KNNPlan.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_query__executor__pb2.KNNBatchResult.FromString, - ) - - -class QueryExecutorServicer(object): - """Missing associated documentation comment in .proto file.""" - - def Count(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Get(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def KNN(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_QueryExecutorServicer_to_server(servicer, server): - rpc_method_handlers = { - 'Count': grpc.unary_unary_rpc_method_handler( - servicer.Count, - request_deserializer=chromadb_dot_proto_dot_query__executor__pb2.CountPlan.FromString, - response_serializer=chromadb_dot_proto_dot_query__executor__pb2.CountResult.SerializeToString, - ), - 'Get': grpc.unary_unary_rpc_method_handler( - servicer.Get, - request_deserializer=chromadb_dot_proto_dot_query__executor__pb2.GetPlan.FromString, - response_serializer=chromadb_dot_proto_dot_query__executor__pb2.GetResult.SerializeToString, - ), - 'KNN': grpc.unary_unary_rpc_method_handler( - servicer.KNN, - request_deserializer=chromadb_dot_proto_dot_query__executor__pb2.KNNPlan.FromString, - response_serializer=chromadb_dot_proto_dot_query__executor__pb2.KNNBatchResult.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'chroma.QueryExecutor', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - - - # This class is part of an EXPERIMENTAL API. 
-class QueryExecutor(object): - """Missing associated documentation comment in .proto file.""" - - @staticmethod - def Count(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.QueryExecutor/Count', - chromadb_dot_proto_dot_query__executor__pb2.CountPlan.SerializeToString, - chromadb_dot_proto_dot_query__executor__pb2.CountResult.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def Get(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.QueryExecutor/Get', - chromadb_dot_proto_dot_query__executor__pb2.GetPlan.SerializeToString, - chromadb_dot_proto_dot_query__executor__pb2.GetResult.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def KNN(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/chroma.QueryExecutor/KNN', - chromadb_dot_proto_dot_query__executor__pb2.KNNPlan.SerializeToString, - chromadb_dot_proto_dot_query__executor__pb2.KNNBatchResult.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/chromadb/segment/impl/manager/distributed.py b/chromadb/segment/impl/manager/distributed.py index 033cf375e9c..7acdaa4d860 100644 --- a/chromadb/segment/impl/manager/distributed.py +++ b/chromadb/segment/impl/manager/distributed.py @@ -1,4 +1,3 @@ -from collections import defaultdict from threading import Lock from typing import Dict, Sequence from uuid import UUID, uuid4 @@ -19,7 +18,7 @@ OpenTelemetryGranularity, trace_method, ) -from chromadb.types import Collection, Operation, Segment, SegmentScope +from chromadb.types import Collection, CollectionAndSegments, Operation, Segment, SegmentScope class DistributedSegmentManager(SegmentManager): @@ -27,9 +26,6 @@ class DistributedSegmentManager(SegmentManager): _system: System _opentelemetry_client: OpenTelemetryClient _instances: Dict[UUID, SegmentImplementation] - _segment_cache: Dict[ - UUID, Dict[SegmentScope, Segment] - ] # collection_id -> scope -> segment _segment_directory: SegmentDirectory _lock: Lock # _segment_server_stubs: Dict[str, SegmentServerStub] # grpc_url -> grpc stub @@ -41,7 +37,6 @@ def __init__(self, system: System): self._system = system self._opentelemetry_client = system.require(OpenTelemetryClient) self._instances = {} - self._segment_cache = defaultdict(dict) self._lock = Lock() @trace_method( @@ -60,6 +55,7 @@ def prepare_segments_for_new_collection( metadata=PersistentHnswParams.extract(collection.metadata) if collection.metadata else None, + file_paths={}, ) metadata_segment = Segment( id=uuid4(), @@ -67,6 +63,7 @@ def prepare_segments_for_new_collection( scope=SegmentScope.METADATA, collection=collection.id, metadata=None, + file_paths={}, ) record_segment = Segment( id=uuid4(), @@ -74,6 +71,7 @@ def prepare_segments_for_new_collection( scope=SegmentScope.RECORD, 
collection=collection.id, metadata=None, + file_paths={}, ) return [vector_segment, record_segment, metadata_segment] @@ -82,27 +80,12 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: segments = self._sysdb.get_segments(collection=collection_id) return [s["id"] for s in segments] - @trace_method( - "DistributedSegmentManager.get_segment", - OpenTelemetryGranularity.OPERATION_AND_SEGMENT, - ) - def get_segment(self, collection_id: UUID, scope: SegmentScope) -> Segment: - if scope not in self._segment_cache[collection_id]: - # For now, there is exactly one segment per scope for a given collection - segment = self._sysdb.get_segments(collection=collection_id, scope=scope)[0] - self._segment_cache[collection_id][scope] = segment - return self._segment_cache[collection_id][scope] - @trace_method( "DistributedSegmentManager.get_endpoint", OpenTelemetryGranularity.OPERATION_AND_SEGMENT, ) - def get_endpoint(self, collection_id: UUID) -> str: - # Get grpc endpoint from record segment. Since grpc endpoint is endpoint is - # determined by collection uuid, the endpoint should be the same for all - # segments of the same collection - record_segment = self.get_segment(collection_id, SegmentScope.RECORD) - return self._segment_directory.get_segment_endpoint(record_segment) + def get_endpoint(self, segment: Segment) -> str: + return self._segment_directory.get_segment_endpoint(segment) @trace_method( "DistributedSegmentManager.hint_use_collection", diff --git a/chromadb/segment/impl/manager/local.py b/chromadb/segment/impl/manager/local.py index 296ace7f9e7..06b152cc7f3 100644 --- a/chromadb/segment/impl/manager/local.py +++ b/chromadb/segment/impl/manager/local.py @@ -263,4 +263,5 @@ def _segment(type: SegmentType, scope: SegmentScope, collection: Collection) -> scope=scope, collection=collection.id, metadata=metadata, + file_paths={}, ) diff --git a/chromadb/segment/impl/metadata/grpc_segment.py b/chromadb/segment/impl/metadata/grpc_segment.py deleted file mode 100644 index 53ffdc72734..00000000000 --- a/chromadb/segment/impl/metadata/grpc_segment.py +++ /dev/null @@ -1,394 +0,0 @@ -from typing import Dict, List, Optional, Sequence -from chromadb.proto.convert import to_proto_request_version_context -from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor -from chromadb.segment import MetadataReader -from chromadb.config import System -from chromadb.errors import InvalidArgumentError, VersionMismatchError -from chromadb.types import Segment, RequestVersionContext -from overrides import override -from chromadb.telemetry.opentelemetry import ( - OpenTelemetryGranularity, - trace_method, -) -from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor -from chromadb.types import ( - Where, - WhereDocument, - MetadataEmbeddingRecord, -) -from chromadb.proto.chroma_pb2_grpc import MetadataReaderStub -import chromadb.proto.chroma_pb2 as pb -import grpc - - -class GrpcMetadataSegment(MetadataReader): - """Embedding Metadata segment interface""" - - _request_timeout_seconds: int - _metadata_reader_stub: MetadataReaderStub - _segment: Segment - - def __init__(self, system: System, segment: Segment) -> None: - super().__init__(system, segment) # type: ignore[safe-super] - if not segment["metadata"] or not segment["metadata"]["grpc_url"]: - raise Exception("Missing grpc_url in segment metadata") - - self._segment = segment - self._request_timeout_seconds = system.settings.require( - "chroma_query_request_timeout_seconds" - ) - - @override - def start(self) -> None: - if not 
self._segment["metadata"] or not self._segment["metadata"]["grpc_url"]: - raise Exception("Missing grpc_url in segment metadata") - - channel = grpc.insecure_channel(self._segment["metadata"]["grpc_url"]) - interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()] - channel = grpc.intercept_channel(channel, *interceptors) - self._metadata_reader_stub = MetadataReaderStub(channel) # type: ignore - - @override - def count(self, request_version_context: RequestVersionContext) -> int: - request: pb.CountRecordsRequest = pb.CountRecordsRequest( - segment_id=self._segment["id"].hex, - collection_id=self._segment["collection"].hex, - version_context=to_proto_request_version_context(request_version_context), - ) - - try: - response: pb.CountRecordsResponse = self._metadata_reader_stub.CountRecords( - request, - timeout=self._request_timeout_seconds, - ) - except grpc.RpcError as rpc_error: - message = rpc_error.details() - if "Collection version mismatch" in message: - raise VersionMismatchError() - raise rpc_error - - return response.count - - @override - def delete(self, where: Optional[Where] = None) -> None: - raise NotImplementedError() - - @override - def max_seqid(self) -> int: - raise NotImplementedError() - - @trace_method( - "GrpcMetadataSegment.get_metadata", - OpenTelemetryGranularity.ALL, - ) - @override - def get_metadata( - self, - request_version_context: RequestVersionContext, - where: Optional[Where] = None, - where_document: Optional[WhereDocument] = None, - ids: Optional[Sequence[str]] = None, - limit: Optional[int] = None, - offset: Optional[int] = None, - include_metadata: bool = True, - ) -> Sequence[MetadataEmbeddingRecord]: - """Query for embedding metadata.""" - - if limit is not None and limit < 0: - raise InvalidArgumentError(f"Limit cannot be negative: {limit}") - - if offset is not None and offset < 0: - raise InvalidArgumentError(f"Offset cannot be negative: {offset}") - - request: pb.QueryMetadataRequest = pb.QueryMetadataRequest( - segment_id=self._segment["id"].hex, - collection_id=self._segment["collection"].hex, - where=self._where_to_proto(where) - if where is not None and len(where) > 0 - else None, - where_document=( - self._where_document_to_proto(where_document) - if where_document is not None and len(where_document) > 0 - else None - ), - ids=pb.UserIds(ids=ids) if ids is not None else None, - limit=limit, - offset=offset, - include_metadata=include_metadata, - version_context=to_proto_request_version_context(request_version_context), - ) - - try: - response: pb.QueryMetadataResponse = ( - self._metadata_reader_stub.QueryMetadata( - request, - timeout=self._request_timeout_seconds, - ) - ) - except grpc.RpcError as rpc_error: - message = rpc_error.details() - if "Collection version mismatch" in message: - raise VersionMismatchError() - raise rpc_error - - results: List[MetadataEmbeddingRecord] = [] - for record in response.records: - result = self._from_proto(record) - results.append(result) - - return results - - def _where_to_proto(self, where: Optional[Where]) -> pb.Where: - response = pb.Where() - if where is None: - return response - if len(where) != 1: - raise ValueError( - f"Expected where to have exactly one operator, got {where}" - ) - - for key, value in where.items(): - if not isinstance(key, str): - raise ValueError(f"Expected where key to be a str, got {key}") - - if key == "$and" or key == "$or": - if not isinstance(value, list): - raise ValueError( - f"Expected where value for $and or $or to be a list of where expressions, got 
{value}" - ) - children: pb.WhereChildren = pb.WhereChildren( - children=[self._where_to_proto(w) for w in value] - ) - if key == "$and": - children.operator = pb.BooleanOperator.AND - else: - children.operator = pb.BooleanOperator.OR - - response.children.CopyFrom(children) - return response - - # At this point we know we're at a direct comparison. It can either - # be of the form {"key": "value"} or {"key": {"$operator": "value"}}. - - dc = pb.DirectComparison() - dc.key = key - - if not isinstance(value, dict): - # {'key': 'value'} case - if type(value) is str: - ssc = pb.SingleStringComparison() - ssc.value = value - ssc.comparator = pb.GenericComparator.EQ - dc.single_string_operand.CopyFrom(ssc) - elif type(value) is bool: - sbc = pb.SingleBoolComparison() - sbc.value = value - sbc.comparator = pb.GenericComparator.EQ - dc.single_bool_operand.CopyFrom(sbc) - elif type(value) is int: - sic = pb.SingleIntComparison() - sic.value = value - sic.generic_comparator = pb.GenericComparator.EQ - dc.single_int_operand.CopyFrom(sic) - elif type(value) is float: - sdc = pb.SingleDoubleComparison() - sdc.value = value - sdc.generic_comparator = pb.GenericComparator.EQ - dc.single_double_operand.CopyFrom(sdc) - else: - raise ValueError( - f"Expected where value to be a string, int, or float, got {value}" - ) - else: - for operator, operand in value.items(): - if operator in ["$in", "$nin"]: - if not isinstance(operand, list): - raise ValueError( - f"Expected where value for $in or $nin to be a list of values, got {value}" - ) - if len(operand) == 0 or not all( - isinstance(x, type(operand[0])) for x in operand - ): - raise ValueError( - f"Expected where operand value to be a non-empty list, and all values to be of the same type " - f"got {operand}" - ) - list_operator = None - if operator == "$in": - list_operator = pb.ListOperator.IN - else: - list_operator = pb.ListOperator.NIN - if type(operand[0]) is str: - slo = pb.StringListComparison() - for x in operand: - slo.values.extend([x]) # type: ignore - slo.list_operator = list_operator - dc.string_list_operand.CopyFrom(slo) - elif type(operand[0]) is bool: - blo = pb.BoolListComparison() - for x in operand: - blo.values.extend([x]) # type: ignore - blo.list_operator = list_operator - dc.bool_list_operand.CopyFrom(blo) - elif type(operand[0]) is int: - ilo = pb.IntListComparison() - for x in operand: - ilo.values.extend([x]) # type: ignore - ilo.list_operator = list_operator - dc.int_list_operand.CopyFrom(ilo) - elif type(operand[0]) is float: - dlo = pb.DoubleListComparison() - for x in operand: - dlo.values.extend([x]) # type: ignore - dlo.list_operator = list_operator - dc.double_list_operand.CopyFrom(dlo) - else: - raise ValueError( - f"Expected where operand value to be a list of strings, ints, or floats, got {operand}" - ) - elif operator in ["$eq", "$ne", "$gt", "$lt", "$gte", "$lte"]: - # Direct comparison to a single value. 
- if type(operand) is str: - ssc = pb.SingleStringComparison() - ssc.value = operand - if operator == "$eq": - ssc.comparator = pb.GenericComparator.EQ - elif operator == "$ne": - ssc.comparator = pb.GenericComparator.NE - else: - raise ValueError( - f"Expected where operator to be $eq or $ne, got {operator}" - ) - dc.single_string_operand.CopyFrom(ssc) - elif type(operand) is bool: - sbc = pb.SingleBoolComparison() - sbc.value = operand - if operator == "$eq": - sbc.comparator = pb.GenericComparator.EQ - elif operator == "$ne": - sbc.comparator = pb.GenericComparator.NE - else: - raise ValueError( - f"Expected where operator to be $eq or $ne, got {operator}" - ) - dc.single_bool_operand.CopyFrom(sbc) - elif type(operand) is int: - sic = pb.SingleIntComparison() - sic.value = operand - if operator == "$eq": - sic.generic_comparator = pb.GenericComparator.EQ - elif operator == "$ne": - sic.generic_comparator = pb.GenericComparator.NE - elif operator == "$gt": - sic.number_comparator = pb.NumberComparator.GT - elif operator == "$lt": - sic.number_comparator = pb.NumberComparator.LT - elif operator == "$gte": - sic.number_comparator = pb.NumberComparator.GTE - elif operator == "$lte": - sic.number_comparator = pb.NumberComparator.LTE - else: - raise ValueError( - f"Expected where operator to be one of $eq, $ne, $gt, $lt, $gte, $lte, got {operator}" - ) - dc.single_int_operand.CopyFrom(sic) - elif type(operand) is float: - sfc = pb.SingleDoubleComparison() - sfc.value = operand - if operator == "$eq": - sfc.generic_comparator = pb.GenericComparator.EQ - elif operator == "$ne": - sfc.generic_comparator = pb.GenericComparator.NE - elif operator == "$gt": - sfc.number_comparator = pb.NumberComparator.GT - elif operator == "$lt": - sfc.number_comparator = pb.NumberComparator.LT - elif operator == "$gte": - sfc.number_comparator = pb.NumberComparator.GTE - elif operator == "$lte": - sfc.number_comparator = pb.NumberComparator.LTE - else: - raise ValueError( - f"Expected where operator to be one of $eq, $ne, $gt, $lt, $gte, $lte, got {operator}" - ) - dc.single_double_operand.CopyFrom(sfc) - else: - raise ValueError( - f"Expected where operand value to be a string, int, or float, got {operand}" - ) - else: - # This case should never happen, as we've already - # handled the case for direct comparisons. - pass - - response.direct_comparison.CopyFrom(dc) - return response - - def _where_document_to_proto( - self, where_document: Optional[WhereDocument] - ) -> pb.WhereDocument: - response = pb.WhereDocument() - if where_document is None: - return response - if len(where_document) != 1: - raise ValueError( - f"Expected where_document to have exactly one operator, got {where_document}" - ) - - for operator, operand in where_document.items(): - if operator == "$and" or operator == "$or": - # Nested "$and" or "$or" expression. - if not isinstance(operand, list): - raise ValueError( - f"Expected where_document value for $and or $or to be a list of where_document expressions, got {operand}" - ) - children: pb.WhereDocumentChildren = pb.WhereDocumentChildren( - children=[self._where_document_to_proto(w) for w in operand] - ) - if operator == "$and": - children.operator = pb.BooleanOperator.AND - else: - children.operator = pb.BooleanOperator.OR - - response.children.CopyFrom(children) - else: - # Direct "$contains" or "$not_contains" comparison to a single - # value. 
- if not isinstance(operand, str): - raise ValueError( - f"Expected where_document operand to be a string, got {operand}" - ) - dwd = pb.DirectWhereDocument() - dwd.document = operand - if operator == "$contains": - dwd.operator = pb.WhereDocumentOperator.CONTAINS - elif operator == "$not_contains": - dwd.operator = pb.WhereDocumentOperator.NOT_CONTAINS - else: - raise ValueError( - f"Expected where_document operator to be one of $contains, $not_contains, got {operator}" - ) - response.direct.CopyFrom(dwd) - - return response - - def _from_proto( - self, record: pb.MetadataEmbeddingRecord - ) -> MetadataEmbeddingRecord: - translated_metadata: Dict[str, str | int | float | bool] = {} - record_metadata_map = record.metadata.metadata - for key, value in record_metadata_map.items(): - if value.HasField("bool_value"): - translated_metadata[key] = value.bool_value - elif value.HasField("string_value"): - translated_metadata[key] = value.string_value - elif value.HasField("int_value"): - translated_metadata[key] = value.int_value - elif value.HasField("float_value"): - translated_metadata[key] = value.float_value - else: - raise ValueError(f"Unknown metadata value type: {value}") - - mer = MetadataEmbeddingRecord(id=record.id, metadata=translated_metadata) - - return mer diff --git a/chromadb/segment/impl/vector/grpc_segment.py b/chromadb/segment/impl/vector/grpc_segment.py deleted file mode 100644 index a66de4a71cd..00000000000 --- a/chromadb/segment/impl/vector/grpc_segment.py +++ /dev/null @@ -1,145 +0,0 @@ -from overrides import EnforceOverrides, override -from typing import List, Optional, Sequence -from chromadb.config import System -from chromadb.proto.convert import ( - from_proto_vector_embedding_record, - from_proto_vector_query_result, - to_proto_request_version_context, - to_proto_vector, -) -from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor -from chromadb.segment import VectorReader -from chromadb.errors import VersionMismatchError -from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams -from chromadb.telemetry.opentelemetry import ( - OpenTelemetryGranularity, - trace_method, -) -from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor -from chromadb.types import ( - Metadata, - RequestVersionContext, - ScalarEncoding, - Segment, - VectorEmbeddingRecord, - VectorQuery, - VectorQueryResult, -) -from chromadb.proto.chroma_pb2_grpc import VectorReaderStub -from chromadb.proto.chroma_pb2 import ( - GetVectorsRequest, - GetVectorsResponse, - QueryVectorsRequest, - QueryVectorsResponse, -) -import grpc - - -class GrpcVectorSegment(VectorReader, EnforceOverrides): - _vector_reader_stub: VectorReaderStub - _segment: Segment - _request_timeout_seconds: int - - def __init__(self, system: System, segment: Segment): - # TODO: move to start() method - # TODO: close channel in stop() method - if segment["metadata"] is None or segment["metadata"]["grpc_url"] is None: - raise Exception("Missing grpc_url in segment metadata") - - channel = grpc.insecure_channel(segment["metadata"]["grpc_url"]) - interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()] - channel = grpc.intercept_channel(channel, *interceptors) - self._vector_reader_stub = VectorReaderStub(channel) # type: ignore - self._segment = segment - self._request_timeout_seconds = system.settings.require( - "chroma_query_request_timeout_seconds" - ) - - @trace_method("GrpcVectorSegment.get_vectors", OpenTelemetryGranularity.ALL) - @override - def get_vectors( - self, - 
request_version_context: RequestVersionContext, - ids: Optional[Sequence[str]] = None, - ) -> Sequence[VectorEmbeddingRecord]: - request = GetVectorsRequest( - ids=ids, - segment_id=self._segment["id"].hex, - collection_id=self._segment["collection"].hex, - version_context=to_proto_request_version_context(request_version_context), - ) - - try: - response: GetVectorsResponse = self._vector_reader_stub.GetVectors( - request, - timeout=self._request_timeout_seconds, - ) - except grpc.RpcError as rpc_error: - message = rpc_error.details() - if "Collection version mismatch" in message: - raise VersionMismatchError() - raise rpc_error - - results: List[VectorEmbeddingRecord] = [] - for vector in response.records: - result = from_proto_vector_embedding_record(vector) - results.append(result) - return results - - @trace_method("GrpcVectorSegment.query_vectors", OpenTelemetryGranularity.ALL) - @override - def query_vectors( - self, query: VectorQuery - ) -> Sequence[Sequence[VectorQueryResult]]: - request = QueryVectorsRequest( - vectors=[ - to_proto_vector(vector=v, encoding=ScalarEncoding.FLOAT32) - for v in query["vectors"] - ], - k=query["k"], - allowed_ids=query["allowed_ids"], - include_embeddings=query["include_embeddings"], - segment_id=self._segment["id"].hex, - collection_id=self._segment["collection"].hex, - version_context=to_proto_request_version_context( - query["request_version_context"] - ), - ) - - try: - response: QueryVectorsResponse = self._vector_reader_stub.QueryVectors( - request, - timeout=self._request_timeout_seconds, - ) - except grpc.RpcError as rpc_error: - message = rpc_error.details() - if "Collection version mismatch" in message: - raise VersionMismatchError() - raise rpc_error - - results: List[List[VectorQueryResult]] = [] - for result in response.results: - curr_result: List[VectorQueryResult] = [] - for r in result.results: - curr_result.append(from_proto_vector_query_result(r)) - results.append(curr_result) - return results - - @override - def count(self, request_version_context: RequestVersionContext) -> int: - raise NotImplementedError() - - @override - def max_seqid(self) -> int: - return 0 - - @staticmethod - @override - def propagate_collection_metadata(metadata: Metadata) -> Optional[Metadata]: - # Great example of why language sharing is nice. 
- segment_metadata = PersistentHnswParams.extract(metadata) - return segment_metadata - - @override - def delete(self) -> None: - raise NotImplementedError() diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index f3d9909718b..4f8aeca38a8 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -15,9 +15,9 @@ to_thread, CapacityLimiter, ) -from fastapi import FastAPI as _FastAPI, Response, Request, Body -from fastapi.responses import JSONResponse, ORJSONResponse +from fastapi import FastAPI as _FastAPI, Response, Request from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import ORJSONResponse from fastapi.routing import APIRoute from fastapi import HTTPException, status @@ -108,18 +108,18 @@ async def catch_exceptions_middleware( except ChromaError as e: return fastapi_json_response(e) except ValueError as e: - return JSONResponse( + return ORJSONResponse( content={"error": "InvalidArgumentError", "message": str(e)}, status_code=400, ) except TypeError as e: - return JSONResponse( + return ORJSONResponse( content={"error": "InvalidArgumentError", "message": str(e)}, status_code=400, ) except Exception as e: logger.exception(e) - return JSONResponse(content={"error": repr(e)}, status_code=500) + return ORJSONResponse(content={"error": repr(e)}, status_code=500) async def check_http_version_middleware( @@ -142,6 +142,18 @@ def validate_model(model: Type[D], data: Any) -> D: # type: ignore return model.parse_obj(data) # pydantic 1.x +def get_openapi_extras_for_model(request_model: Type[D]) -> Dict[str, Any]: + openapi_extra = { + "requestBody": { + "content": { + "application/json": {"schema": request_model.model_json_schema()} + }, + "required": True, + } + } + return openapi_extra + + class ChromaAPIRouter(fastapi.APIRouter): # type: ignore # A simple subclass of fastapi's APIRouter which treats URLs with a # trailing "/" the same as URLs without. 
Docs will only contain URLs @@ -241,6 +253,7 @@ def setup_v2_routes(self) -> None: self.create_database, methods=["POST"], response_model=None, + openapi_extra=get_openapi_extras_for_model(CreateDatabase), ) self.router.add_api_route( @@ -255,6 +268,7 @@ def setup_v2_routes(self) -> None: self.create_tenant, methods=["POST"], response_model=None, + openapi_extra=get_openapi_extras_for_model(CreateTenant), ) self.router.add_api_route( @@ -281,6 +295,7 @@ def setup_v2_routes(self) -> None: self.create_collection, methods=["POST"], response_model=None, + openapi_extra=get_openapi_extras_for_model(CreateCollection), ) self.router.add_api_route( @@ -289,30 +304,35 @@ def setup_v2_routes(self) -> None: methods=["POST"], status_code=status.HTTP_201_CREATED, response_model=None, + openapi_extra=get_openapi_extras_for_model(AddEmbedding), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/update", self.update, methods=["POST"], response_model=None, + openapi_extra=get_openapi_extras_for_model(UpdateEmbedding), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/upsert", self.upsert, methods=["POST"], response_model=None, + openapi_extra=get_openapi_extras_for_model(AddEmbedding), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/get", self.get, methods=["POST"], response_model=None, + openapi_extra=get_openapi_extras_for_model(GetEmbedding), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/delete", self.delete, methods=["POST"], response_model=None, + openapi_extra=get_openapi_extras_for_model(DeleteEmbedding), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/count", @@ -325,6 +345,7 @@ def setup_v2_routes(self) -> None: self.get_nearest_neighbors, methods=["POST"], response_model=None, + openapi_extra=get_openapi_extras_for_model(request_model=QueryEmbedding), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_name}", @@ -337,6 +358,7 @@ def setup_v2_routes(self) -> None: self.update_collection, methods=["PUT"], response_model=None, + openapi_extra=get_openapi_extras_for_model(UpdateCollection), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_name}", @@ -353,8 +375,8 @@ def app(self) -> fastapi.FastAPI: async def rate_limit_exception_handler( self, request: Request, exc: RateLimitError - ) -> JSONResponse: - return JSONResponse( + ) -> ORJSONResponse: + return ORJSONResponse( status_code=429, content={"message": "Rate limit exceeded."}, ) @@ -364,8 +386,8 @@ def root(self) -> Dict[str, int]: async def quota_exception_handler( self, request: Request, exc: QuotaError - ) -> JSONResponse: - return JSONResponse( + ) -> ORJSONResponse: + return ORJSONResponse( status_code=400, content={"message": exc.message()}, ) @@ -452,7 +474,6 @@ async def create_database( self, request: Request, tenant: str, - body: CreateDatabase = Body(...), ) -> None: def process_create_database( tenant: str, headers: Headers, raw_body: bytes @@ -506,7 +527,8 @@ async def get_database( @trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION) async def create_tenant( - self, request: Request, body: CreateTenant = Body(...) 
+ self, + request: Request, ) -> None: def process_create_tenant(request: Request, raw_body: bytes) -> None: tenant = validate_model(CreateTenant, orjson.loads(raw_body)) @@ -625,7 +647,6 @@ async def create_collection( request: Request, tenant: str, database_name: str, - body: CreateCollection = Body(...), ) -> CollectionModel: def process_create_collection( request: Request, tenant: str, database: str, raw_body: bytes @@ -708,7 +729,6 @@ async def update_collection( database_name: str, collection_id: str, request: Request, - body: UpdateCollection = Body(...), ) -> None: def process_update_collection( request: Request, collection_id: str, raw_body: bytes @@ -771,7 +791,6 @@ async def add( tenant: str, database_name: str, collection_id: str, - body: AddEmbedding = Body(...), ) -> bool: try: @@ -821,7 +840,6 @@ async def update( tenant: str, database_name: str, collection_id: str, - body: UpdateEmbedding = Body(...), ) -> None: def process_update(request: Request, raw_body: bytes) -> bool: update = validate_model(UpdateEmbedding, orjson.loads(raw_body)) @@ -863,7 +881,6 @@ async def upsert( tenant: str, database_name: str, collection_id: str, - body: AddEmbedding = Body(...), ) -> None: def process_upsert(request: Request, raw_body: bytes) -> bool: upsert = validate_model(AddEmbedding, orjson.loads(raw_body)) @@ -908,7 +925,6 @@ async def get( tenant: str, database_name: str, request: Request, - body: GetEmbedding = Body(...), ) -> GetResult: def process_get(request: Request, raw_body: bytes) -> GetResult: get = validate_model(GetEmbedding, orjson.loads(raw_body)) @@ -959,7 +975,6 @@ async def delete( tenant: str, database_name: str, request: Request, - body: DeleteEmbedding = Body(...), ) -> None: def process_delete(request: Request, raw_body: bytes) -> None: delete = validate_model(DeleteEmbedding, orjson.loads(raw_body)) @@ -1044,7 +1059,6 @@ async def get_nearest_neighbors( database_name: str, collection_id: str, request: Request, - body: QueryEmbedding = Body(...), ) -> QueryResult: def process_query(request: Request, raw_body: bytes) -> QueryResult: query = validate_model(QueryEmbedding, orjson.loads(raw_body)) @@ -1124,6 +1138,7 @@ def setup_v1_routes(self) -> None: self.create_database_v1, methods=["POST"], response_model=None, + openapi_extra=get_openapi_extras_for_model(CreateDatabase), ) self.router.add_api_route( @@ -1138,6 +1153,7 @@ def setup_v1_routes(self) -> None: self.create_tenant_v1, methods=["POST"], response_model=None, + openapi_extra=get_openapi_extras_for_model(CreateTenant), ) self.router.add_api_route( @@ -1164,6 +1180,7 @@ def setup_v1_routes(self) -> None: self.create_collection_v1, methods=["POST"], response_model=None, + openapi_extra=get_openapi_extras_for_model(CreateCollection), ) self.router.add_api_route( @@ -1172,30 +1189,35 @@ def setup_v1_routes(self) -> None: methods=["POST"], status_code=status.HTTP_201_CREATED, response_model=None, + openapi_extra=get_openapi_extras_for_model(AddEmbedding), ) self.router.add_api_route( "/api/v1/collections/{collection_id}/update", self.update_v1, methods=["POST"], response_model=None, + openapi_extra=get_openapi_extras_for_model(UpdateEmbedding), ) self.router.add_api_route( "/api/v1/collections/{collection_id}/upsert", self.upsert_v1, methods=["POST"], response_model=None, + openapi_extra=get_openapi_extras_for_model(AddEmbedding), ) self.router.add_api_route( "/api/v1/collections/{collection_id}/get", self.get_v1, methods=["POST"], response_model=None, + openapi_extra=get_openapi_extras_for_model(GetEmbedding), ) 
self.router.add_api_route( "/api/v1/collections/{collection_id}/delete", self.delete_v1, methods=["POST"], response_model=None, + openapi_extra=get_openapi_extras_for_model(DeleteEmbedding), ) self.router.add_api_route( "/api/v1/collections/{collection_id}/count", @@ -1208,6 +1230,7 @@ def setup_v1_routes(self) -> None: self.get_nearest_neighbors_v1, methods=["POST"], response_model=None, + openapi_extra=get_openapi_extras_for_model(QueryEmbedding), ) self.router.add_api_route( "/api/v1/collections/{collection_name}", @@ -1220,6 +1243,7 @@ def setup_v1_routes(self) -> None: self.update_collection_v1, methods=["PUT"], response_model=None, + openapi_extra=get_openapi_extras_for_model(UpdateCollection), ) self.router.add_api_route( "/api/v1/collections/{collection_name}", @@ -1300,7 +1324,6 @@ async def create_database_v1( self, request: Request, tenant: str = DEFAULT_TENANT, - body: CreateDatabase = Body(...), ) -> None: def process_create_database( tenant: str, headers: Headers, raw_body: bytes @@ -1366,7 +1389,8 @@ async def get_database_v1( @trace_method("FastAPI.create_tenant_v1", OpenTelemetryGranularity.OPERATION) async def create_tenant_v1( - self, request: Request, body: CreateTenant = Body(...) + self, + request: Request, ) -> None: def process_create_tenant(request: Request, raw_body: bytes) -> None: tenant = validate_model(CreateTenant, orjson.loads(raw_body)) @@ -1491,7 +1515,6 @@ async def create_collection_v1( request: Request, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, - body: CreateCollection = Body(...), ) -> CollectionModel: def process_create_collection( request: Request, tenant: str, database: str, raw_body: bytes @@ -1563,21 +1586,25 @@ async def get_collection_v1( if maybe_database: database = maybe_database - api_collection_model = cast( - CollectionModel, - await to_thread.run_sync( - self._api.get_collection, - collection_name, - tenant, - database, - limiter=self._capacity_limiter, - ), - ) - return api_collection_model + async def inner(): + api_collection_model = cast( + CollectionModel, + await to_thread.run_sync( + self._api.get_collection, + collection_name, + tenant, + database, + limiter=self._capacity_limiter, + ), + ) + return api_collection_model + return await inner() @trace_method("FastAPI.update_collection_v1", OpenTelemetryGranularity.OPERATION) async def update_collection_v1( - self, collection_id: str, request: Request, body: UpdateCollection = Body(...) + self, + collection_id: str, + request: Request, ) -> None: def process_update_collection( request: Request, collection_id: str, raw_body: bytes @@ -1637,7 +1664,9 @@ async def delete_collection_v1( @trace_method("FastAPI.add_v1", OpenTelemetryGranularity.OPERATION) async def add_v1( - self, request: Request, collection_id: str, body: AddEmbedding = Body(...) + self, + request: Request, + collection_id: str, ) -> bool: try: @@ -1678,7 +1707,9 @@ def process_add(request: Request, raw_body: bytes) -> bool: @trace_method("FastAPI.update_v1", OpenTelemetryGranularity.OPERATION) async def update_v1( - self, request: Request, collection_id: str, body: UpdateEmbedding = Body(...) 
+ self, + request: Request, + collection_id: str, ) -> None: def process_update(request: Request, raw_body: bytes) -> bool: update = validate_model(UpdateEmbedding, orjson.loads(raw_body)) @@ -1711,7 +1742,9 @@ def process_update(request: Request, raw_body: bytes) -> bool: @trace_method("FastAPI.upsert_v1", OpenTelemetryGranularity.OPERATION) async def upsert_v1( - self, request: Request, collection_id: str, body: AddEmbedding = Body(...) + self, + request: Request, + collection_id: str, ) -> None: def process_upsert(request: Request, raw_body: bytes) -> bool: upsert = validate_model(AddEmbedding, orjson.loads(raw_body)) @@ -1747,7 +1780,9 @@ def process_upsert(request: Request, raw_body: bytes) -> bool: @trace_method("FastAPI.get_v1", OpenTelemetryGranularity.OPERATION) async def get_v1( - self, collection_id: str, request: Request, body: GetEmbedding = Body(...) + self, + collection_id: str, + request: Request, ) -> GetResult: def process_get(request: Request, raw_body: bytes) -> GetResult: get = validate_model(GetEmbedding, orjson.loads(raw_body)) @@ -1789,7 +1824,9 @@ def process_get(request: Request, raw_body: bytes) -> GetResult: @trace_method("FastAPI.delete_v1", OpenTelemetryGranularity.OPERATION) async def delete_v1( - self, collection_id: str, request: Request, body: DeleteEmbedding = Body(...) + self, + collection_id: str, + request: Request, ) -> None: def process_delete(request: Request, raw_body: bytes) -> None: delete = validate_model(DeleteEmbedding, orjson.loads(raw_body)) @@ -1865,7 +1902,6 @@ async def get_nearest_neighbors_v1( self, collection_id: str, request: Request, - body: QueryEmbedding = Body(...), ) -> QueryResult: def process_query(request: Request, raw_body: bytes) -> QueryResult: query = validate_model(QueryEmbedding, orjson.loads(raw_body)) diff --git a/chromadb/test/client/test_http_client_v1_compatability.py b/chromadb/test/client/test_http_client_v1_compatability.py index 983770ee7ad..f282aa6361b 100644 --- a/chromadb/test/client/test_http_client_v1_compatability.py +++ b/chromadb/test/client/test_http_client_v1_compatability.py @@ -42,15 +42,7 @@ def test_http_client_bw_compatibility() -> None: old_version = "0.5.11" # Module with known v1 client - # Version <3.9 requires bounding tokenizers<=0.20.3 - # TOOD(hammadb): This code is duplicated in test_cross_version_persist.py - # for expediency on 11/27/2024 I am copy pasting rather than refactoring - # to DRY. Refactor later. 
- (major, minor, _) = pysys.version_info[:3] - if major == 3 and minor < 9: - install_version(old_version, {"tokenizers": "<=0.20.3"}) - else: - install_version(old_version, {}) + install_version(old_version, {}) ctx = multiprocessing.get_context("spawn") conn1, conn2 = multiprocessing.Pipe() diff --git a/chromadb/test/db/test_system.py b/chromadb/test/db/test_system.py index af90bf876b8..20384e739c1 100644 --- a/chromadb/test/db/test_system.py +++ b/chromadb/test/db/test_system.py @@ -174,6 +174,7 @@ def sample_segment(collection_id: uuid.UUID = uuid.uuid4(), scope=scope, collection=collection_id, metadata=metadata, + file_paths={}, ) # region Collection tests @@ -184,13 +185,17 @@ def test_create_get_delete_collections(sysdb: SysDB) -> None: segments_created_with_collection = [] for collection in sample_collections: logger.debug(f"Creating collection: {collection.name}") - segment = sample_segment(collection_id=collection.id) - segments_created_with_collection.append(segment) + segments = [ + sample_segment(collection_id=collection.id, scope=SegmentScope.METADATA), + sample_segment(collection_id=collection.id, scope=SegmentScope.RECORD), + sample_segment(collection_id=collection.id, scope=SegmentScope.VECTOR), + ] + segments_created_with_collection.extend(segments) sysdb.create_collection( id=collection.id, name=collection.name, configuration=collection.get_configuration(), - segments=[segment], + segments=segments, metadata=collection["metadata"], dimension=collection["dimension"], ) @@ -222,6 +227,12 @@ def test_create_get_delete_collections(sysdb: SysDB) -> None: result = sysdb.get_collections(id=collection["id"]) assert result == [collection] + # Verify segment information + for collection in sample_collections: + collection_with_segments_result = sysdb.get_collection_with_segments(collection.id) + assert collection_with_segments_result["collection"] == collection + assert all([segment["collection"] == collection.id for segment in collection_with_segments_result["segments"]]) + # Delete c1 = sample_collections[0] sysdb.delete_collection(id=c1.id) @@ -287,6 +298,7 @@ def test_update_collections(sysdb: SysDB) -> None: scope=SegmentScope.VECTOR, collection=coll.id, metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3}, + file_paths={}, ) ], metadata=coll["metadata"], @@ -335,6 +347,7 @@ def test_get_or_create_collection(sysdb: SysDB) -> None: scope=SegmentScope.VECTOR, collection=collection.id, metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3}, + file_paths={}, ) ], metadata=collection["metadata"], @@ -355,6 +368,7 @@ def test_get_or_create_collection(sysdb: SysDB) -> None: scope=SegmentScope.VECTOR, collection=sample_collections[1].id, metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3}, + file_paths={}, ) ], # This could have been empty - []. 
metadata=collection["metadata"], @@ -377,6 +391,7 @@ def test_get_or_create_collection(sysdb: SysDB) -> None: scope=SegmentScope.VECTOR, collection=sample_collections[1].id, metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3}, + file_paths={}, ) ], get_or_create=True, @@ -396,6 +411,7 @@ def test_get_or_create_collection(sysdb: SysDB) -> None: scope=SegmentScope.VECTOR, collection=sample_collections[2].id, metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3}, + file_paths={}, ) ], get_or_create=False, @@ -417,6 +433,7 @@ def test_get_or_create_collection(sysdb: SysDB) -> None: scope=SegmentScope.VECTOR, collection=sample_collections[2].id, metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3}, + file_paths={}, ) ], metadata=collection["metadata"], @@ -439,6 +456,7 @@ def test_get_or_create_collection(sysdb: SysDB) -> None: scope=SegmentScope.VECTOR, collection=sample_collections[2].id, metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3}, + file_paths={}, ) ], get_or_create=True, @@ -765,6 +783,7 @@ def test_get_database_with_tenants(sysdb: SysDB) -> None: scope=SegmentScope.VECTOR, collection=sample_collections[0].id, metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3}, + file_paths={}, ), Segment( id=uuid.UUID("11111111-d7d7-413b-92e1-731098a6e492"), @@ -772,6 +791,7 @@ def test_get_database_with_tenants(sysdb: SysDB) -> None: scope=SegmentScope.VECTOR, collection=sample_collections[1].id, metadata={"test_str": "str2", "test_int": 2, "test_float": 2.3}, + file_paths={}, ), ] @@ -860,6 +880,7 @@ def test_update_segment(sysdb: SysDB) -> None: scope=SegmentScope.VECTOR, collection=sample_collections[0].id, metadata=metadata, + file_paths={}, ) sysdb.reset_state() diff --git a/chromadb/test/distributed/test_sanity.py b/chromadb/test/distributed/test_sanity.py index 34f04759623..8ffbd140ce4 100644 --- a/chromadb/test/distributed/test_sanity.py +++ b/chromadb/test/distributed/test_sanity.py @@ -5,11 +5,14 @@ import time from chromadb.api import ClientAPI from chromadb.test.conftest import ( - COMPACTION_SLEEP, reset, skip_if_not_cluster, ) from chromadb.test.property import invariants +from chromadb.test.utils.wait_for_version_increase import ( + wait_for_version_increase, + get_collection_version, +) import numpy as np @@ -78,7 +81,7 @@ def test_add_include_all_with_compaction_delay(client: ClientAPI) -> None: documents=[documents[-1]], ) - time.sleep(COMPACTION_SLEEP) # Wait for the documents to be compacted + wait_for_version_increase(client, collection.name, get_collection_version(client, collection.name), 120) random_query_1 = np.random.rand(1, 3)[0] random_query_2 = np.random.rand(1, 3)[0] diff --git a/chromadb/test/distributed/test_version_mismatch.py b/chromadb/test/distributed/test_version_mismatch.py deleted file mode 100644 index 817abf5c0e4..00000000000 --- a/chromadb/test/distributed/test_version_mismatch.py +++ /dev/null @@ -1,219 +0,0 @@ -import random -from typing import List, Tuple -import uuid -from chromadb.api.models.Collection import Collection -from chromadb.config import Settings, System -from chromadb.db.impl.grpc.client import GrpcSysDB -from chromadb.db.system import SysDB -from chromadb.errors import VersionMismatchError -from chromadb.segment import MetadataReader, VectorReader -from chromadb.segment.impl.metadata.grpc_segment import GrpcMetadataSegment -from chromadb.segment.impl.vector.grpc_segment import GrpcVectorSegment -from chromadb.test.conftest import reset, skip_if_not_cluster -from chromadb.api 
import ClientAPI -from chromadb.test.utils.wait_for_version_increase import wait_for_version_increase -from chromadb.types import RequestVersionContext, SegmentScope, VectorQuery - - -# Helpers -def create_test_collection(client: ClientAPI, name: str) -> Collection: - return client.create_collection( - name=name, - metadata={"hnsw:construction_ef": 128, "hnsw:search_ef": 128, "hnsw:M": 128}, - ) - - -def add_random_records_and_wait_for_compaction( - client: ClientAPI, collection: Collection, n: int -) -> Tuple[List[str], List[List[float]], int]: - ids = [] - embeddings = [] - for i in range(n): - ids.append(str(i)) - embeddings.append([random.random(), random.random(), random.random()]) - collection.add( - ids=[str(i)], - embeddings=[embeddings[-1]], # type: ignore - ) - final_version = wait_for_version_increase( - client=client, collection_name=collection.name, initial_version=0 - ) - return ids, embeddings, final_version - - -def get_mock_frontend_system() -> System: - settings = Settings( - chroma_coordinator_host="localhost", chroma_server_grpc_port=50051 - ) - return System(settings) - - -def get_vector_segment( - system: System, sysdb: SysDB, collection: uuid.UUID -) -> GrpcVectorSegment: - segment = sysdb.get_segments(collection=collection, scope=SegmentScope.VECTOR)[0] - if segment["metadata"] is None: - segment["metadata"] = {} - # Inject the url, replicating the behavior of the segment manager, we use the tilt grpc server url - segment["metadata"]["grpc_url"] = "localhost:50053" # type: ignore - ret_segment = GrpcVectorSegment(system, segment) - ret_segment.start() - return ret_segment - - -def get_metadata_segment( - system: System, sysdb: SysDB, collection: uuid.UUID -) -> GrpcMetadataSegment: - segment = sysdb.get_segments(collection=collection, scope=SegmentScope.METADATA)[0] - if segment["metadata"] is None: - segment["metadata"] = {} - # Inject the url, replicating the behavior of the segment manager, we use the tilt grpc server url - segment["metadata"]["grpc_url"] = "localhost:50053" # type: ignore - ret_segment = GrpcMetadataSegment(system, segment) - ret_segment.start() - return ret_segment - - -def setup_vector_test( - client: ClientAPI, n: int -) -> Tuple[VectorReader, List[str], List[List[float]], int, int]: - reset(client) - collection = create_test_collection(client=client, name="test_version_mismatch") - ids, embeddings, version = add_random_records_and_wait_for_compaction( - client=client, collection=collection, n=n - ) - log_position = client.get_collection(collection.name)._model.log_position - - fe_system = get_mock_frontend_system() - sysdb = GrpcSysDB(fe_system) - sysdb.start() - - return ( - get_vector_segment(system=fe_system, sysdb=sysdb, collection=collection.id), - ids, - embeddings, - version, - log_position, - ) - - -def setup_metadata_test( - client: ClientAPI, n: int -) -> Tuple[MetadataReader, List[str], List[List[float]], int, int]: - reset(client) - collection = create_test_collection(client=client, name="test_version_mismatch") - ids, embeddings, version = add_random_records_and_wait_for_compaction( - client=client, collection=collection, n=n - ) - log_position = client.get_collection(collection.name)._model.log_position - - fe_system = get_mock_frontend_system() - sysdb = GrpcSysDB(fe_system) - sysdb.start() - - return ( - get_metadata_segment(system=fe_system, sysdb=sysdb, collection=collection.id), - ids, - embeddings, - version, - log_position, - ) - - -@skip_if_not_cluster() -def test_version_mistmatch_query_vectors( - client: ClientAPI, -) 
-> None: - N = 100 - reader, _, embeddings, compacted_version, log_position = setup_vector_test( - client=client, n=N - ) - request = VectorQuery( - vectors=[embeddings[0]], - request_version_context=RequestVersionContext( - collection_version=compacted_version, log_position=log_position - ), - k=10, - include_embeddings=False, - allowed_ids=None, - options=None, - ) - - reader.query_vectors(query=request) - # Now change the collection version to > N, which should cause a version mismatch - request["request_version_context"]["collection_version"] = N + 1 - try: - reader.query_vectors(request) - except VersionMismatchError: - pass - except Exception as e: - assert False, f"Unexpected exception {e}" - - -@skip_if_not_cluster() -def test_version_mistmatch_get_vectors( - client: ClientAPI, -) -> None: - N = 100 - reader, _, _, compacted_version, log_position = setup_vector_test( - client=client, n=N - ) - request_version_context = RequestVersionContext( - collection_version=compacted_version, log_position=log_position - ) - - reader.get_vectors(ids=None, request_version_context=request_version_context) - # Now change the collection version to > N, which should cause a version mismatch - request_version_context["collection_version"] = N + 1 - try: - reader.get_vectors(request_version_context) - except VersionMismatchError: - pass - except Exception as e: - assert False, f"Unexpected exception {e}" - - -@skip_if_not_cluster() -def test_version_mismatch_metadata_get( - client: ClientAPI, -) -> None: - N = 100 - reader, _, _, compacted_version, log_position = setup_metadata_test( - client=client, n=N - ) - request_version_context = RequestVersionContext( - collection_version=compacted_version, log_position=log_position - ) - - reader.get_metadata(request_version_context=request_version_context) - # Now change the collection version to > N, which should cause a version mismatch - request_version_context["collection_version"] = N + 1 - try: - reader.get_metadata(request_version_context) - except VersionMismatchError: - pass - except Exception as e: - assert False, f"Unexpected exception {e}" - - -@skip_if_not_cluster() -def test_version_mismatch_metadata_count( - client: ClientAPI, -) -> None: - N = 100 - reader, _, _, compacted_version, log_position = setup_metadata_test( - client=client, n=N - ) - request_version_context = RequestVersionContext( - collection_version=compacted_version, log_position=log_position - ) - - reader.count(request_version_context) - # Now change the collection version to > N, which should cause a version mismatch - request_version_context["collection_version"] = N + 1 - try: - reader.count(request_version_context) - except VersionMismatchError: - pass - except Exception as e: - assert False, f"Unexpected exception {e}" diff --git a/chromadb/test/property/test_cross_version_persist.py b/chromadb/test/property/test_cross_version_persist.py index 677426a2082..4b1d14936d8 100644 --- a/chromadb/test/property/test_cross_version_persist.py +++ b/chromadb/test/property/test_cross_version_persist.py @@ -150,13 +150,7 @@ def version_settings(request) -> Generator[Tuple[str, Settings], None, None]: configuration = request.param version = configuration[0] - # Version <3.9 requires bounding tokenizers<=0.20.3 - (major, minor, patch) = sys.version_info[:3] - if major == 3 and minor < 9: - install_version(version, {"tokenizers": "<=0.20.3"}) - else: - install_version(version, {}) - + install_version(version, {}) yield configuration # Cleanup the installed version path = 
get_path_to_version_install(version) diff --git a/chromadb/test/property/test_sysdb.py b/chromadb/test/property/test_sysdb.py new file mode 100644 index 00000000000..13708c21435 --- /dev/null +++ b/chromadb/test/property/test_sysdb.py @@ -0,0 +1,158 @@ +import pytest +from hypothesis.stateful import ( + Bundle, + RuleBasedStateMachine, + rule, + initialize, + multiple, + consumes, + run_state_machine_as_test, + MultipleResults, +) +from typing import Dict +from uuid import uuid4 + +import chromadb.test.property.strategies as strategies +from chromadb.api.configuration import CollectionConfigurationInternal +from chromadb.config import System +from chromadb.db.system import SysDB +from chromadb.segment import SegmentType +from chromadb.test.conftest import NOT_CLUSTER_ONLY +from chromadb.test.db.test_system import sqlite, grpc_with_real_server +from chromadb.types import Segment, SegmentScope + + +class SysDBStateMachine(RuleBasedStateMachine): + collections: Bundle[strategies.Collection] = Bundle("collections") + created_collections: Dict[str, strategies.Collection] + + def __init__(self, sysdb: SysDB): + super().__init__() + self.sysdb = sysdb + + @initialize() + def initialize(self) -> None: + self.sysdb.reset_state() + self.created_collections = {} + + @rule(target=collections, coll=strategies.collections()) + def create_collection( + self, coll: strategies.Collection + ) -> MultipleResults[strategies.Collection]: + # TODO: Convert collection views used in tests into actual Collections / Collection models + segments = ( + [ + Segment( + id=uuid4(), + type=SegmentType.SQLITE.value, + scope=SegmentScope.METADATA, + collection=coll.id, + metadata={}, + file_paths={}, + ), + Segment( + id=uuid4(), + type=SegmentType.HNSW_LOCAL_MEMORY.value, + scope=SegmentScope.VECTOR, + collection=coll.id, + metadata={}, + file_paths={}, + ), + ] + if NOT_CLUSTER_ONLY + else [ + Segment( + id=uuid4(), + type=SegmentType.BLOCKFILE_METADATA.value, + scope=SegmentScope.METADATA, + collection=coll.id, + metadata={}, + file_paths={}, + ), + Segment( + id=uuid4(), + type=SegmentType.BLOCKFILE_RECORD.value, + scope=SegmentScope.RECORD, + collection=coll.id, + metadata={}, + file_paths={}, + ), + Segment( + id=uuid4(), + type=SegmentType.HNSW_DISTRIBUTED.value, + scope=SegmentScope.VECTOR, + collection=coll.id, + metadata={}, + file_paths={}, + ), + ] + ) + if coll.name in self.created_collections: + with pytest.raises(Exception): + self.sysdb.create_collection( + coll.id, coll.name, CollectionConfigurationInternal(), segments + ) + else: + self.sysdb.create_collection( + coll.id, coll.name, CollectionConfigurationInternal(), segments + ) + self.created_collections[coll.name] = coll + return multiple(coll) + + @rule(coll=collections) + def get_collection(self, coll: strategies.Collection) -> None: + if ( + coll.name in self.created_collections + and coll.id == self.created_collections[coll.name].id + ): + fetched_collections = self.sysdb.get_collections(id=coll.id) + assert len(fetched_collections) == 1 + assert fetched_collections[0].name == coll.name + else: + assert len(self.sysdb.get_collections(id=coll.id)) == 0 + + @rule(coll=collections) + def get_collection_with_segments(self, coll: strategies.Collection) -> None: + if ( + coll.name in self.created_collections + and coll.id == self.created_collections[coll.name].id + ): + fetched_collection_and_segments = self.sysdb.get_collection_with_segments( + collection_id=coll.id + ) + assert fetched_collection_and_segments["collection"].name == coll.name + scopes = 
[] + for segment in fetched_collection_and_segments["segments"]: + assert segment["collection"] == coll.id + scopes.append(segment["scope"]) + if NOT_CLUSTER_ONLY: + assert len(scopes) == 2 + assert set(scopes) == {SegmentScope.METADATA, SegmentScope.VECTOR} + else: + assert len(scopes) == 3 + assert set(scopes) == { + SegmentScope.METADATA, + SegmentScope.RECORD, + SegmentScope.VECTOR, + } + else: + with pytest.raises(Exception): + self.sysdb.get_collection_with_segments(collection_id=coll.id) + + @rule(coll=consumes(collections)) + def delete_collection(self, coll: strategies.Collection) -> None: + if ( + coll.name in self.created_collections + and coll.id == self.created_collections[coll.name].id + ): + # TODO: Convert collection views used in tests into actual Collections / Collection models + self.sysdb.delete_collection(coll.id) + self.created_collections.pop(coll.name) + else: + with pytest.raises(Exception): + self.sysdb.delete_collection(id=coll.id) + + +def test_sysdb(caplog: pytest.LogCaptureFixture, system: System) -> None: + sysdb = next(sqlite()) if NOT_CLUSTER_ONLY else next(grpc_with_real_server()) + run_state_machine_as_test(lambda: SysDBStateMachine(sysdb=sysdb)) # type: ignore[no-untyped-call] diff --git a/chromadb/test/segment/distributed/test_protobuf_translation.py b/chromadb/test/segment/distributed/test_protobuf_translation.py index 6fd3777abec..d29fcad0365 100644 --- a/chromadb/test/segment/distributed/test_protobuf_translation.py +++ b/chromadb/test/segment/distributed/test_protobuf_translation.py @@ -1,68 +1,126 @@ import uuid - -from chromadb.config import Settings, System -from chromadb.segment.impl.metadata.grpc_segment import GrpcMetadataSegment +from chromadb.proto import convert +from chromadb.segment import SegmentType from chromadb.types import ( + Collection, + CollectionConfigurationInternal, Segment, SegmentScope, Where, WhereDocument, - MetadataEmbeddingRecord, ) import chromadb.proto.chroma_pb2 as pb +import chromadb.proto.query_executor_pb2 as query_pb +def test_collection_to_proto() -> None: + collection = Collection( + id=uuid.uuid4(), + name="test_collection", + configuration=CollectionConfigurationInternal(), + metadata={"hnsw_m": 128}, + dimension=512, + tenant="test_tenant", + database="test_database", + version=1, + log_position=42, + ) + + assert convert.to_proto_collection(collection) == pb.Collection( + id=collection.id.hex, + name="test_collection", + configuration_json_str=CollectionConfigurationInternal().to_json_str(), + metadata=pb.UpdateMetadata(metadata={"hnsw_m": pb.UpdateMetadataValue(int_value=128)}), + dimension=512, + tenant="test_tenant", + database="test_database", + version=1, + log_position=42, + ) -# Note: trying to start() this segment will cause it to error since it doesn't -# have a remote server to talk to. This is only suitable for testing the -# python <-> proto translation logic. 
-def unstarted_grpc_metadata_segment() -> GrpcMetadataSegment: - settings = Settings( - allow_reset=True, +def test_collection_from_proto() -> None: + proto = pb.Collection( + id=uuid.uuid4().hex, + name="test_collection", + configuration_json_str=CollectionConfigurationInternal().to_json_str(), + metadata=pb.UpdateMetadata(metadata={"hnsw_m": pb.UpdateMetadataValue(int_value=128)}), + dimension=512, + tenant="test_tenant", + database="test_database", + version=1, + log_position=42, ) - system = System(settings) + assert convert.from_proto_collection(proto) == Collection( + id=uuid.UUID(proto.id), + name="test_collection", + configuration=CollectionConfigurationInternal(), + metadata={"hnsw_m": 128}, + dimension=512, + tenant="test_tenant", + database="test_database", + version=1, + log_position=42, + ) + +def test_segment_to_proto() -> None: segment = Segment( id=uuid.uuid4(), - type="test", - scope=SegmentScope.METADATA, + type=SegmentType.HNSW_DISTRIBUTED.value, + scope=SegmentScope.VECTOR, collection=uuid.uuid4(), - metadata={ - "grpc_url": "test", - }, + metadata={"hnsw_m": 128}, + file_paths={"name": ["path_0", "path_1"]}, ) - grpc_metadata_segment = GrpcMetadataSegment( - system=system, - segment=segment, + assert convert.to_proto_segment(segment) == pb.Segment( + id=segment["id"].hex, + type=SegmentType.HNSW_DISTRIBUTED.value, + scope=pb.SegmentScope.VECTOR, + collection=segment["collection"].hex, + metadata=pb.UpdateMetadata(metadata={"hnsw_m": pb.UpdateMetadataValue(int_value=128)}), + file_paths={"name": pb.FilePaths(paths=["path_0", "path_1"])}, ) - return grpc_metadata_segment +def test_segment_from_proto() -> None: + proto = pb.Segment( + id=uuid.uuid4().hex, + type=SegmentType.HNSW_DISTRIBUTED.value, + scope=pb.SegmentScope.VECTOR, + collection=uuid.uuid4().hex, + metadata=pb.UpdateMetadata(metadata={"hnsw_m": pb.UpdateMetadataValue(int_value=128)}), + file_paths={"name": pb.FilePaths(paths=["path_0", "path_1"])}, + ) + assert convert.from_proto_segment(proto) == Segment( + id=uuid.UUID(proto.id), + type=SegmentType.HNSW_DISTRIBUTED.value, + scope=SegmentScope.VECTOR, + collection=uuid.UUID(proto.collection), + metadata={"hnsw_m": 128}, + file_paths={"name": ["path_0", "path_1"]}, + ) def test_where_document_to_proto_not_contains() -> None: - md_segment = unstarted_grpc_metadata_segment() where_document: WhereDocument = {"$not_contains": "test"} - proto = md_segment._where_document_to_proto(where_document) + proto = convert.to_proto_where_document(where_document) assert proto.HasField("direct") assert proto.direct.document == "test" assert proto.direct.operator == pb.WhereDocumentOperator.NOT_CONTAINS def test_where_document_to_proto_contains_to_proto() -> None: - md_segment = unstarted_grpc_metadata_segment() where_document: WhereDocument = {"$contains": "test"} - proto = md_segment._where_document_to_proto(where_document) + proto = convert.to_proto_where_document(where_document) assert proto.HasField("direct") assert proto.direct.document == "test" assert proto.direct.operator == pb.WhereDocumentOperator.CONTAINS def test_where_document_to_proto_and() -> None: - md_segment = unstarted_grpc_metadata_segment() where_document: WhereDocument = { "$and": [ {"$contains": "test"}, {"$not_contains": "test"}, ] } - proto = md_segment._where_document_to_proto(where_document) + proto = convert.to_proto_where_document(where_document) assert proto.HasField("children") children_pb = proto.children assert children_pb.operator == pb.BooleanOperator.AND @@ -78,14 +136,13 @@ def 
test_where_document_to_proto_and() -> None: def test_where_document_to_proto_or() -> None: - md_segment = unstarted_grpc_metadata_segment() where_document: WhereDocument = { "$or": [ {"$contains": "test"}, {"$not_contains": "test"}, ] } - proto = md_segment._where_document_to_proto(where_document) + proto = convert.to_proto_where_document(where_document) assert proto.HasField("children") children_pb = proto.children assert children_pb.operator == pb.BooleanOperator.OR @@ -101,7 +158,6 @@ def test_where_document_to_proto_or() -> None: def test_where_document_to_proto_nested_boolean_operators() -> None: - md_segment = unstarted_grpc_metadata_segment() where_document: WhereDocument = { "$and": [ { @@ -118,7 +174,7 @@ def test_where_document_to_proto_nested_boolean_operators() -> None: }, ] } - proto = md_segment._where_document_to_proto(where_document) + proto = convert.to_proto_where_document(where_document) assert proto.HasField("children") children_pb = proto.children assert children_pb.operator == pb.BooleanOperator.AND @@ -141,11 +197,10 @@ def test_where_document_to_proto_nested_boolean_operators() -> None: def test_where_to_proto_string_value() -> None: - md_segment = unstarted_grpc_metadata_segment() where: Where = { "test": "value", } - proto: pb.Where = md_segment._where_to_proto(where) + proto = convert.to_proto_where(where) assert proto.HasField("direct_comparison") d = proto.direct_comparison assert d.key == "test" @@ -154,11 +209,10 @@ def test_where_to_proto_string_value() -> None: def test_where_to_proto_int_value() -> None: - md_segment = unstarted_grpc_metadata_segment() where: Where = { "test": 1, } - proto = md_segment._where_to_proto(where) + proto = convert.to_proto_where(where) assert proto.HasField("direct_comparison") d = proto.direct_comparison assert d.key == "test" @@ -167,11 +221,10 @@ def test_where_to_proto_int_value() -> None: def test_where_to_proto_double_value() -> None: - md_segment = unstarted_grpc_metadata_segment() where: Where = { "test": 1.0, } - proto = md_segment._where_to_proto(where) + proto = convert.to_proto_where(where) assert proto.HasField("direct_comparison") d = proto.direct_comparison assert d.key == "test" @@ -180,14 +233,13 @@ def test_where_to_proto_double_value() -> None: def test_where_to_proto_and() -> None: - md_segment = unstarted_grpc_metadata_segment() where: Where = { "$and": [ {"test": 1}, {"test": "value"}, ] } - proto = md_segment._where_to_proto(where) + proto = convert.to_proto_where(where) assert proto.HasField("children") children_pb = proto.children assert children_pb.operator == pb.BooleanOperator.AND @@ -205,14 +257,13 @@ def test_where_to_proto_and() -> None: def test_where_to_proto_or() -> None: - md_segment = unstarted_grpc_metadata_segment() where: Where = { "$or": [ {"test": 1}, {"test": "value"}, ] } - proto = md_segment._where_to_proto(where) + proto = convert.to_proto_where(where) assert proto.HasField("children") children_pb = proto.children assert children_pb.operator == pb.BooleanOperator.OR @@ -230,7 +281,6 @@ def test_where_to_proto_or() -> None: def test_where_to_proto_nested_boolean_operators() -> None: - md_segment = unstarted_grpc_metadata_segment() where: Where = { "$and": [ { @@ -247,7 +297,7 @@ def test_where_to_proto_nested_boolean_operators() -> None: }, ] } - proto = md_segment._where_to_proto(where) + proto = convert.to_proto_where(where) assert proto.HasField("children") children_pb = proto.children assert children_pb.operator == pb.BooleanOperator.AND @@ -272,14 +322,13 @@ def 
test_where_to_proto_nested_boolean_operators() -> None: def test_where_to_proto_float_operator() -> None: - md_segment = unstarted_grpc_metadata_segment() where: Where = { "$and": [ {"test1": 1.0}, {"test2": 2.0}, ] } - proto = md_segment._where_to_proto(where) + proto = convert.to_proto_where(where) assert proto.HasField("children") children_pb = proto.children assert children_pb.operator == pb.BooleanOperator.AND @@ -299,89 +348,29 @@ def test_where_to_proto_float_operator() -> None: assert child_1.direct_comparison.single_double_operand.value == 2.0 -def test_metadata_embedding_record_string_from_proto() -> None: - md_segment = unstarted_grpc_metadata_segment() - val: pb.UpdateMetadataValue = pb.UpdateMetadataValue( - string_value="test_value", - ) - update: pb.UpdateMetadata = pb.UpdateMetadata( - metadata={"test_key": val}, - ) - record: pb.MetadataEmbeddingRecord = pb.MetadataEmbeddingRecord( - id="test_id", - metadata=update, - ) - - mdr: MetadataEmbeddingRecord = md_segment._from_proto(record) - assert mdr["id"] == "test_id" - assert mdr["metadata"] - assert mdr["metadata"]["test_key"] == "test_value" - - -def test_metadata_embedding_record_int_from_proto() -> None: - md_segment = unstarted_grpc_metadata_segment() - val: pb.UpdateMetadataValue = pb.UpdateMetadataValue( - int_value=1, - ) - update: pb.UpdateMetadata = pb.UpdateMetadata( - metadata={"test_key": val}, - ) - record: pb.MetadataEmbeddingRecord = pb.MetadataEmbeddingRecord( - id="test_id", - metadata=update, - ) - - mdr: MetadataEmbeddingRecord = md_segment._from_proto(record) - assert mdr["id"] == "test_id" - assert mdr["metadata"] - assert mdr["metadata"]["test_key"] == 1 - - -def test_metadata_embedding_record_double_from_proto() -> None: - md_segment = unstarted_grpc_metadata_segment() - val: pb.UpdateMetadataValue = pb.UpdateMetadataValue( +def test_projection_record_from_proto() -> None: + float_val: pb.UpdateMetadataValue = pb.UpdateMetadataValue( float_value=1.0, ) - update: pb.UpdateMetadata = pb.UpdateMetadata( - metadata={"test_key": val}, + int_val: pb.UpdateMetadataValue = pb.UpdateMetadataValue( + int_value=2, ) - record: pb.MetadataEmbeddingRecord = pb.MetadataEmbeddingRecord( - id="test_id", - metadata=update, - ) - - mdr: MetadataEmbeddingRecord = md_segment._from_proto(record) - assert mdr["id"] == "test_id" - assert mdr["metadata"] - assert mdr["metadata"]["test_key"] == 1.0 - - -def test_metadata_embedding_record_heterogeneous_from_proto() -> None: - md_segment = unstarted_grpc_metadata_segment() - val1: pb.UpdateMetadataValue = pb.UpdateMetadataValue( - string_value="test_value", - ) - val2: pb.UpdateMetadataValue = pb.UpdateMetadataValue( - int_value=1, - ) - val3: pb.UpdateMetadataValue = pb.UpdateMetadataValue( - float_value=1.0, + str_val: pb.UpdateMetadataValue = pb.UpdateMetadataValue( + string_value="three", ) update: pb.UpdateMetadata = pb.UpdateMetadata( - metadata={ - "test_key1": val1, - "test_key2": val2, - "test_key3": val3, - }, + metadata={"float_key": float_val, "int_key": int_val, "str_key": str_val}, ) - record: pb.MetadataEmbeddingRecord = pb.MetadataEmbeddingRecord( + record: query_pb.ProjectionRecord = query_pb.ProjectionRecord( id="test_id", + document="document", metadata=update, ) - mdr: MetadataEmbeddingRecord = md_segment._from_proto(record) - assert mdr["id"] == "test_id" - assert mdr["metadata"] - assert mdr["metadata"]["test_key1"] == "test_value" - assert mdr["metadata"]["test_key2"] == 1 - assert mdr["metadata"]["test_key3"] == 1.0 + projection_record = 
convert.from_proto_projection_record(record) + + assert projection_record["id"] == "test_id" + assert projection_record["metadata"] + assert projection_record["metadata"]["float_key"] == 1.0 + assert projection_record["metadata"]["int_key"] == 2 + assert projection_record["metadata"]["str_key"] == "three" diff --git a/chromadb/test/segment/test_metadata.py b/chromadb/test/segment/test_metadata.py index 50bab861800..bb0f40e4234 100644 --- a/chromadb/test/segment/test_metadata.py +++ b/chromadb/test/segment/test_metadata.py @@ -129,6 +129,7 @@ def _build_document(i: int) -> str: scope=SegmentScope.METADATA, collection=uuid.UUID(int=0), metadata=None, + file_paths={}, ) segment_definition2 = Segment( @@ -137,6 +138,7 @@ def _build_document(i: int) -> str: scope=SegmentScope.METADATA, collection=uuid.UUID(int=1), metadata=None, + file_paths={}, ) diff --git a/chromadb/test/segment/test_vector.py b/chromadb/test/segment/test_vector.py index 0d62c827461..87375940276 100644 --- a/chromadb/test/segment/test_vector.py +++ b/chromadb/test/segment/test_vector.py @@ -115,6 +115,7 @@ def create_random_segment_definition() -> Segment: scope=SegmentScope.VECTOR, collection=uuid.UUID(int=0), metadata=test_hnsw_config, + file_paths={}, ) diff --git a/chromadb/types.py b/chromadb/types.py index 2dc98b826fc..363f4cacf83 100644 --- a/chromadb/types.py +++ b/chromadb/types.py @@ -176,7 +176,11 @@ class Segment(TypedDict): scope: SegmentScope collection: UUID metadata: Optional[Metadata] + file_paths: Mapping[str, Sequence[str]] +class CollectionAndSegments(TypedDict): + collection: Collection + segments: Sequence[Segment] # SeqID can be one of three types of value in our current and future plans: # 1. A Pulsar MessageID encoded as a 192-bit integer - This is no longer used as we removed pulsar diff --git a/clients/js/src/ChromaClient.ts b/clients/js/src/ChromaClient.ts index 415d72193a1..ac6c9bc170d 100644 --- a/clients/js/src/ChromaClient.ts +++ b/clients/js/src/ChromaClient.ts @@ -290,7 +290,7 @@ export class ChromaClient { /** * Lists all collections. * - * @returns {Promise} A promise that resolves to a list of collection names. + * @returns {Promise} A promise that resolves to a list of collection names. * @param {PositiveInteger} [params.limit] - Optional limit on the number of items to get. * @param {PositiveInteger} [params.offset] - Optional offset on the items to get. * @throws {Error} If there is an issue listing the collections. 
@@ -304,16 +304,17 @@ export class ChromaClient { * ``` */ async listCollections({ limit, offset }: ListCollectionsParams = {}): Promise< - CollectionParams[] + string[] > { await this.init(); - return (await this.api.listCollections( + const collections = (await this.api.listCollections( this.tenant, this.database, limit, offset, this.api.options, - )) as CollectionParams[]; + )) as Collection[]; + return collections.map((collection: Collection) => collection.name); } /** diff --git a/clients/js/src/index.ts b/clients/js/src/index.ts index 36c439ead7b..fe9e64e2f18 100644 --- a/clients/js/src/index.ts +++ b/clients/js/src/index.ts @@ -1,6 +1,7 @@ export { ChromaClient } from "./ChromaClient"; export { AdminClient } from "./AdminClient"; export { CloudClient } from "./CloudClient"; +export { Collection } from "./Collection"; export type { IEmbeddingFunction } from "./embeddings/IEmbeddingFunction"; export { OpenAIEmbeddingFunction } from "./embeddings/OpenAIEmbeddingFunction"; export { CohereEmbeddingFunction } from "./embeddings/CohereEmbeddingFunction"; @@ -14,7 +15,6 @@ export { OllamaEmbeddingFunction } from "./embeddings/OllamaEmbeddingFunction"; export type { IncludeEnum, GetParams, - CollectionType, CollectionMetadata, Embedding, Embeddings, diff --git a/clients/js/src/types.ts b/clients/js/src/types.ts index 88ffa29da92..67487b73049 100644 --- a/clients/js/src/types.ts +++ b/clients/js/src/types.ts @@ -54,13 +54,6 @@ export type WhereDocument = { | WhereDocument[]; }; -export type CollectionType = { - name: string; - id: string; - metadata: Metadata | null; - configuration_json: any; -}; - export type MultiGetResponse = { ids: IDs; embeddings: Embeddings | null; diff --git a/clients/js/test/collection.client.test.ts b/clients/js/test/collection.client.test.ts index 22b3a3cbb89..22cb63edcec 100644 --- a/clients/js/test/collection.client.test.ts +++ b/clients/js/test/collection.client.test.ts @@ -3,13 +3,9 @@ import { test, beforeEach, describe, - afterAll, - beforeAll, } from "@jest/globals"; -import { DefaultEmbeddingFunction } from "../src/embeddings/DefaultEmbeddingFunction"; -import { StartedTestContainer } from "testcontainers"; -import { ChromaClient } from "../src/ChromaClient"; -import { startChromaContainer } from "./startChromaContainer"; +import { DefaultEmbeddingFunction } from "../src"; +import { ChromaClient } from "../src"; describe("collection operations", () => { // connects to the unauthenticated chroma instance started in @@ -42,23 +38,7 @@ describe("collection operations", () => { const [returnedCollection] = collections; - expect({ - ...returnedCollection, - configuration_json: undefined, - id: undefined, - }).toMatchInlineSnapshot(` - { - "configuration_json": undefined, - "database": "default_database", - "dimension": null, - "id": undefined, - "log_position": 0, - "metadata": null, - "name": "test", - "tenant": "default_tenant", - "version": 0, - } - `); + expect(returnedCollection).toEqual("test") expect([{ name: "test2", metadata: null }]).not.toEqual( expect.arrayContaining(collections), @@ -79,25 +59,7 @@ describe("collection operations", () => { const collections2 = await client.listCollections(); expect(collections2).toHaveLength(1); const [returnedCollection2] = collections2; - expect({ - ...returnedCollection2, - configuration_json: undefined, - id: undefined, - }).toMatchInlineSnapshot(` - { - "configuration_json": undefined, - "database": "default_database", - "dimension": null, - "id": undefined, - "log_position": 0, - "metadata": { - "test": 
"test", - }, - "name": "test2", - "tenant": "default_tenant", - "version": 0, - } - `); + expect(returnedCollection2).toEqual("test2"); }); test("it should get a collection", async () => { diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index d7116d9ed34..0a7d1f300c8 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -8,7 +8,7 @@ authors = [ ] description = "Chroma Client." readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", @@ -32,7 +32,7 @@ dependencies = [ [tool.black] line-length = 88 required-version = "23.3.0" # Black will refuse to run if it's not this version. -target-version = ['py38', 'py39', 'py310', 'py311'] +target-version = ['py39', 'py310', 'py311'] [tool.pytest.ini_options] pythonpath = ["."] diff --git a/deployments/aws/chroma.cf.json b/deployments/aws/chroma.cf.json index dc1b9b301ac..da1fdfeb8d6 100644 --- a/deployments/aws/chroma.cf.json +++ b/deployments/aws/chroma.cf.json @@ -16,7 +16,7 @@ "ChromaVersion": { "Description": "Chroma version to install", "Type": "String", - "Default": "0.5.23" + "Default": "0.6.0" }, "ChromaServerAuthCredentials": { "Description": "Chroma authentication credentials", diff --git a/deployments/azure/chroma.tfvars.tf b/deployments/azure/chroma.tfvars.tf index b1446a67379..8228b7d66e1 100644 --- a/deployments/azure/chroma.tfvars.tf +++ b/deployments/azure/chroma.tfvars.tf @@ -4,7 +4,7 @@ machine_type = "Standard_B1s" # Azure VM size ssh_public_key_path = "~/.ssh/id_rsa.pub" # Path to your SSH public key instance_name = "chroma-instance" -chroma_version = "0.5.23" +chroma_version = "0.6.0" chroma_server_auth_credentials = "" chroma_server_auth_provider = "" chroma_auth_token_transport_header = "" diff --git a/deployments/azure/main.tf b/deployments/azure/main.tf index 0228360e3bc..22cdff5b4e1 100644 --- a/deployments/azure/main.tf +++ b/deployments/azure/main.tf @@ -17,7 +17,7 @@ variable "machine_type" { variable "chroma_version" { description = "Chroma version to install" - default = "0.5.23" + default = "0.6.0" } variable "chroma_server_auth_credentials" { diff --git a/deployments/gcp/chroma.tfvars b/deployments/gcp/chroma.tfvars index 353691da87f..41fa7c5cf82 100644 --- a/deployments/gcp/chroma.tfvars +++ b/deployments/gcp/chroma.tfvars @@ -2,7 +2,7 @@ project_id = "your-gcp-project-id" region = "your-region" # e.g., "us-central1" zone = "your-zone" machine_type = "" -chroma_version = "0.5.23" +chroma_version = "0.6.0" chroma_server_auth_credentials = "" chroma_server_auth_provider = "" chroma_auth_token_transport_header = "" diff --git a/deployments/gcp/main.tf b/deployments/gcp/main.tf index f2576ebec9f..8b4dee0419e 100644 --- a/deployments/gcp/main.tf +++ b/deployments/gcp/main.tf @@ -23,7 +23,7 @@ variable "machine_type" { variable "chroma_version" { description = "Chroma version to install" - default = "0.5.23" + default = "0.6.0" } variable "chroma_server_auth_credentials" { diff --git a/docs/docs.trychroma.com/components/header/header.tsx b/docs/docs.trychroma.com/components/header/header.tsx index 9f56381612a..c04a6a6b278 100644 --- a/docs/docs.trychroma.com/components/header/header.tsx +++ b/docs/docs.trychroma.com/components/header/header.tsx @@ -8,6 +8,7 @@ import Link from "next/link"; import SearchBox from "@/components/header/search-box"; import SearchDocs from "@/components/header/search-docs"; + const Header: React.FC = () => { return (
diff --git a/docs/docs.trychroma.com/components/markdoc/markdoc-heading.tsx b/docs/docs.trychroma.com/components/markdoc/markdoc-heading.tsx
index 2812367c3e4..aa9508ed343 100644
--- a/docs/docs.trychroma.com/components/markdoc/markdoc-heading.tsx
+++ b/docs/docs.trychroma.com/components/markdoc/markdoc-heading.tsx
@@ -7,7 +7,6 @@ const generateId = (content: React.ReactNode): string => {
       .replaceAll("_", "-")
       .replace(/[^a-z0-9\s-]/g, "")
       .replace(/\s+/g, "-")
-      .trim();
   }
   return "";
diff --git a/docs/docs.trychroma.com/markdoc/content/docs/guides/embeddings-guide.md b/docs/docs.trychroma.com/markdoc/content/docs/guides/embeddings-guide.md
new file mode 100644
index 00000000000..5db0bcf7ddb
--- /dev/null
+++ b/docs/docs.trychroma.com/markdoc/content/docs/guides/embeddings-guide.md
@@ -0,0 +1,123 @@
+---
+{
+  "id": "embeddings-guide",
+  "title": "Embeddings",
+  "section": "Guides",
+  "order": 1
+}
+---
+
+# Embeddings
+
+Embeddings are the AI-native way to represent any kind of data, making them the perfect fit for working with all kinds of AI-powered tools and algorithms. They can represent text, images, and soon audio and video. There are many options for creating embeddings, whether locally using an installed library, or by calling an API.
+
+Chroma provides lightweight wrappers around popular embedding providers, making it easy to use them in your apps. You can set an embedding function when you create a Chroma collection, which will be used automatically, or you can call them directly yourself.
+
+{% special_table %}
+{% /special_table %}
+
+| | Python | JS |
+|--------------|-----------|---------------|
+| [OpenAI](/integrations/openai) | ✅ | ✅ |
+| [Google Generative AI](/integrations/google-gemini) | ✅ | ✅ |
+| [Cohere](/integrations/cohere) | ✅ | ✅ |
+| [Hugging Face](/integrations/hugging-face) | ✅ | ➖ |
+| [Instructor](/integrations/instructor) | ✅ | ➖ |
+| [Hugging Face Embedding Server](/integrations/hugging-face-server) | ✅ | ✅ |
+| [Jina AI](/integrations/jinaai) | ✅ | ✅ |
+
+We welcome pull requests to add new Embedding Functions to the community.
+
+***
+
+## Default: all-MiniLM-L6-v2
+
+By default, Chroma uses the [Sentence Transformers](https://www.sbert.net/) `all-MiniLM-L6-v2` model to create embeddings. This embedding model can create sentence and document embeddings that can be used for a wide variety of tasks. This embedding function runs locally on your machine, and may require you to download the model files (this will happen automatically).
+
+```python
+from chromadb.utils import embedding_functions
+default_ef = embedding_functions.DefaultEmbeddingFunction()
+```
+
+{% note type="default" %}
+Embedding functions can be linked to a collection and used whenever you call `add`, `update`, `upsert` or `query`. You can also use them directly, which can be handy for debugging.
+```py
+val = default_ef(["foo"])
+```
+-> [[0.05035809800028801, 0.0626462921500206, -0.061827320605516434...]]
+{% /note %}
+
+
+{% tabs group="code-lang" hideTabs=true %}
+{% tab label="Python" %}
+
+## Sentence Transformers
+
+Chroma can also use any [Sentence Transformers](https://www.sbert.net/) model to create embeddings.
+
+```python
+sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
+```
+
+You can pass in an optional `model_name` argument, which lets you choose which Sentence Transformers model to use. By default, Chroma uses `all-MiniLM-L6-v2`.
You can see a list of all available models [here](https://www.sbert.net/docs/pretrained_models.html).
+
+{% /tab %}
+{% tab label="Javascript" %}
+{% /tab %}
+{% /tabs %}
+
+
+***
+
+
+## Custom Embedding Functions
+
+{% tabs group="code-lang" hideContent=true %}
+
+{% tab label="Python" %}
+{% /tab %}
+
+{% tab label="Javascript" %}
+{% /tab %}
+
+{% /tabs %}
+
+{% tabs group="code-lang" hideTabs=true %}
+{% tab label="Python" %}
+
+You can create your own embedding function to use with Chroma; it just needs to implement the `EmbeddingFunction` protocol.
+
+```python
+from chromadb import Documents, EmbeddingFunction, Embeddings
+
+class MyEmbeddingFunction(EmbeddingFunction):
+    def __call__(self, input: Documents) -> Embeddings:
+        # embed the documents somehow
+        return embeddings
+```
+
+We welcome contributions! If you create an embedding function that you think would be useful to others, please consider [submitting a pull request](https://github.com/chroma-core/chroma) to add it to Chroma's `embedding_functions` module.
+
+
+{% /tab %}
+{% tab label="Javascript" %}
+
+You can create your own embedding function to use with Chroma; it just needs to implement the `EmbeddingFunction` protocol. A `.generate` method on your class is all that is required.
+
+```javascript
+class MyEmbeddingFunction {
+  private api_key: string;
+
+  constructor(api_key: string) {
+    this.api_key = api_key;
+  }
+
+  public async generate(texts: string[]): Promise<number[][]> {
+    // do things to turn texts into embeddings with an api_key perhaps
+    return embeddings;
+  }
+}
+```
+
+{% /tab %}
+{% /tabs %}
diff --git a/docs/docs.trychroma.com/markdoc/content/docs/guides/multimodal-guide.md b/docs/docs.trychroma.com/markdoc/content/docs/guides/multimodal-guide.md
new file mode 100644
index 00000000000..70d7b30e5dc
--- /dev/null
+++ b/docs/docs.trychroma.com/markdoc/content/docs/guides/multimodal-guide.md
@@ -0,0 +1,158 @@
+---
+{
+  "id": "multimodal-guide",
+  "title": "Multimodal",
+  "section": "Guides",
+  "order": 2
+}
+---
+
+# Multimodal
+
+{% tabs group="code-lang" hideContent=true %}
+
+{% tab label="Python" %}
+{% /tab %}
+
+{% tab label="Javascript" %}
+{% /tab %}
+
+{% /tabs %}
+
+---
+
+{% tabs group="code-lang" hideTabs=true %}
+{% tab label="Python" %}
+
+Chroma supports multimodal collections, i.e. collections which can store, and can be queried by, multiple modalities of data.
+
+Try it out in Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/chroma-core/chroma/blob/main/examples/multimodal/multimodal_retrieval.ipynb)
+
+## Multi-modal Embedding Functions
+
+Chroma supports multi-modal embedding functions, which can be used to embed data from multiple modalities into a single embedding space.
+
+Chroma has the OpenCLIP embedding function built in, which supports both text and images.
+
+```python
+from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
+embedding_function = OpenCLIPEmbeddingFunction()
+```
+
+## Data Loaders
+
+Chroma supports data loaders for storing and querying data stored outside Chroma itself, via URI. Chroma will not store this data, but will instead store the URI, and load the data from the URI when needed.
+
+Chroma has a built-in data loader for loading images from a filesystem.
+
+```python
+from chromadb.utils.data_loaders import ImageLoader
+data_loader = ImageLoader()
+```
+
+## Multi-modal Collections
+
+You can create a multi-modal collection by passing in a multi-modal embedding function.
In order to load data from a URI, you must also pass in a data loader. + +```python +import chromadb + +client = chromadb.Client() + +collection = client.create_collection( + name='multimodal_collection', + embedding_function=embedding_function, + data_loader=data_loader) + +``` + +### Adding data + +You can add data to a multi-modal collection by specifying the data modality. For now, images are supported: + +```python +collection.add( + ids=['id1', 'id2', 'id3'], + images=[...] # A list of numpy arrays representing images +) +``` + +Note that Chroma will not store the data for you, and you will have to maintain a mapping from IDs to data yourself. + +However, you can use Chroma in combination with data stored elsewhere, by adding it via URI. Note that this requires that you have specified a data loader when creating the collection. + +```python +collection.add( + ids=['id1', 'id2', 'id3'], + uris=[...] # A list of strings representing URIs to data +) +``` + +Since the embedding function is multi-modal, you can also add text to the same collection: + +```python +collection.add( + ids=['id4', 'id5', 'id6'], + documents=["This is a document", "This is another document", "This is a third document"] +) +``` + +### Querying + +You can query a multi-modal collection with any of the modalities that it supports. For example, you can query with images: + +```python +results = collection.query( + query_images=[...] # A list of numpy arrays representing images +) +``` + +Or with text: + +```python +results = collection.query( + query_texts=["This is a query document", "This is another query document"] +) +``` + +If a data loader is set for the collection, you can also query with URIs which reference data stored elsewhere of the supported modalities: + +```python +results = collection.query( + query_uris=[...] # A list of strings representing URIs to data +) +``` + +Additionally, if a data loader is set for the collection, and URIs are available, you can include the data in the results: + +```python +results = collection.query( + query_images=[...], # # list of numpy arrays representing images + includes=['data'] +) +``` + +This will automatically call the data loader for any available URIs, and include the data in the results. `uris` are also available as an `includes` field. + +### Updating + +You can update a multi-modal collection by specifying the data modality, in the same way as `add`. For now, images are supported: + +```python +collection.update( + ids=['id1', 'id2', 'id3'], + images=[...] # A list of numpy arrays representing images +) +``` + +Note that a given entry with a specific ID can only have one associated modality at a time. Updates will over-write the existing modality, so for example, an entry which originally has corresponding text and updated with an image, will no longer have that text after an update with images. + +{% /tab %} +{% tab label="Javascript" %} + +Support for multi-modal retrieval for Chroma's JavaScript client is coming soon! 
+ +{% /tab %} + +{% /tabs %} + diff --git a/docs/docs.trychroma.com/markdoc/content/docs/guides/usage-guide.md b/docs/docs.trychroma.com/markdoc/content/docs/guides/usage-guide.md new file mode 100644 index 00000000000..ce37ff89487 --- /dev/null +++ b/docs/docs.trychroma.com/markdoc/content/docs/guides/usage-guide.md @@ -0,0 +1,851 @@ +--- +{ + "id": "usage-guide", + "title": "Usage Guide", + "section": "Guides", + "order": 0 +} +--- + + +# Usage Guide + + +{% tabs group="code-lang" hideContent=true %} + +{% tab label="Python" %} +{% /tab %} + +{% tab label="Javascript" %} +{% /tab %} + +{% /tabs %} + +--- + +## Initiating a persistent Chroma client + +{% tabs group="code-lang" hideTabs=true %} +{% tab label="Python" %} + +```python +import chromadb +``` + +You can configure Chroma to save and load the database from your local machine. Data will be persisted automatically and loaded on start (if it exists). + +```python +client = chromadb.PersistentClient(path="/path/to/save/to") +``` + +The `path` is where Chroma will store its database files on disk, and load them on start. + +{% /tab %} +{% tab label="Javascript" %} + +```js +// CJS +const { ChromaClient } = require("chromadb"); + +// ESM +import { ChromaClient } from "chromadb"; +``` + +{% note type="note" title="Connecting to the backend" %} +To connect with the JS client, you must connect to a backend running Chroma. See [Running Chroma in client-server mode](#running-chroma-in-client-server-mode) for how to do this. +{% /note %} + +```js +const client = new ChromaClient(); +``` + +{% /tab %} + +{% /tabs %} + +{% tabs group="code-lang" hideTabs=true %} +{% tab label="Python" %} + +The client object has a few useful convenience methods. + +```python +client.heartbeat() # returns a nanosecond heartbeat. Useful for making sure the client remains connected. +client.reset() # Empties and completely resets the database. ⚠️ This is destructive and not reversible. +``` + +{% /tab %} +{% tab label="Javascript" %} + +The client object has a few useful convenience methods. + +```javascript +await client.reset() # Empties and completely resets the database. ⚠️ This is destructive and not reversible. +``` + +{% /tab %} + +{% /tabs %} + +## Running Chroma in client-server mode + +{% tabs group="code-lang" hideTabs=true %} +{% tab label="Python" %} + +Chroma can also be configured to run in client/server mode. In this mode, the Chroma client connects to a Chroma server running in a separate process. + +To start the Chroma server, run the following command: + +```bash +chroma run --path /db_path +``` + +Then use the Chroma HTTP client to connect to the server: + +```python +import chromadb +chroma_client = chromadb.HttpClient(host='localhost', port=8000) +``` + +That's it! Chroma's API will run in `client-server` mode with just this change. + +--- + +Chroma also provides an async HTTP client. The behaviors and method signatures are identical to the synchronous client, but all methods that would block are now async. To use it, call `AsyncHttpClient` instead: + +```python +import asyncio +import chromadb + +async def main(): + client = await chromadb.AsyncHttpClient() + collection = await client.create_collection(name="my_collection") + + await collection.add( + documents=["hello world"], + ids=["id1"] + ) + +asyncio.run(main()) +``` + + + +#### Using the Python HTTP-only client + +If you are running Chroma in client-server mode, you may not need the full Chroma library. Instead, you can use the lightweight client-only library. 
+In this case, you can install the `chromadb-client` package. This package is a lightweight HTTP client for the server with a minimal dependency footprint. + +```python +pip install chromadb-client +``` + +```python +import chromadb +# Example setup of the client to connect to your chroma server +client = chromadb.HttpClient(host='localhost', port=8000) + +# Or for async usage: +async def main(): + client = await chromadb.AsyncHttpClient(host='localhost', port=8000) +``` + +Note that the `chromadb-client` package is a subset of the full Chroma library and does not include all the dependencies. If you want to use the full Chroma library, you can install the `chromadb` package instead. +Most importantly, there is no default embedding function. If you add() documents without embeddings, you must have manually specified an embedding function and installed the dependencies for it. + +{% /tab %} +{% tab label="Javascript" %} + +To run Chroma in client server mode, first install the chroma library and CLI via pypi: + +```bash +pip install chromadb +``` + +Then start the Chroma server: + +```bash +chroma run --path /db_path +``` + +The JS client then talks to the chroma server backend. + +```js +// CJS +const { ChromaClient } = require("chromadb"); + +// ESM +import { ChromaClient } from "chromadb"; + +const client = new ChromaClient(); +``` + +You can also run the Chroma server in a docker container, or deployed to a cloud provider. See the [deployment docs](./deployment.md) for more information. + +{% /tab %} + +{% /tabs %} + +## Using collections + +Chroma lets you manage collections of embeddings, using the `collection` primitive. + +### Creating, inspecting, and deleting Collections + +Chroma uses collection names in the url, so there are a few restrictions on naming them: + +- The length of the name must be between 3 and 63 characters. +- The name must start and end with a lowercase letter or a digit, and it can contain dots, dashes, and underscores in between. +- The name must not contain two consecutive dots. +- The name must not be a valid IP address. + +Chroma collections are created with a name and an optional embedding function. If you supply an embedding function, you must supply it every time you get the collection. + +{% tabs group="code-lang" hideTabs=true %} +{% tab label="Python" %} + +```python +collection = client.create_collection(name="my_collection", embedding_function=emb_fn) +collection = client.get_collection(name="my_collection", embedding_function=emb_fn) +``` + +{% note type="caution" %} +If you later wish to `get_collection`, you MUST do so with the embedding function you supplied while creating the collection +{% /note %} + +The embedding function takes text as input, and performs tokenization and embedding. If no embedding function is supplied, Chroma will use [sentence transformer](https://www.sbert.net/index.html) as a default. + +{% /tab %} +{% tab label="Javascript" %} + +```js +// CJS +const { ChromaClient } = require("chromadb"); + +// ESM +import { ChromaClient } from "chromadb"; +``` + +The JS client talks to a chroma server backend. This can run on your local computer or be easily deployed to AWS. 
+ +```js +let collection = await client.createCollection({ + name: "my_collection", + embeddingFunction: emb_fn, +}); +let collection2 = await client.getCollection({ + name: "my_collection", + embeddingFunction: emb_fn, +}); +``` + +{% note type="caution" %} +If you later wish to `getCollection`, you MUST do so with the embedding function you supplied while creating the collection +{% /note %} + +The embedding function takes text as input, and performs tokenization and embedding. + +{% /tab %} + +{% /tabs %} + +You can learn more about [🧬 embedding functions](./guides/embeddings), and how to create your own. + +{% tabs group="code-lang" hideTabs=true %} +{% tab label="Python" %} + +Existing collections can be retrieved by name with `.get_collection`, and deleted with `.delete_collection`. You can also use `.get_or_create_collection` to get a collection if it exists, or create it if it doesn't. + +```python +collection = client.get_collection(name="test") # Get a collection object from an existing collection, by name. Will raise an exception if it's not found. +collection = client.get_or_create_collection(name="test") # Get a collection object from an existing collection, by name. If it doesn't exist, create it. +client.delete_collection(name="my_collection") # Delete a collection and all associated embeddings, documents, and metadata. ⚠️ This is destructive and not reversible +``` + +{% /tab %} +{% tab label="Javascript" %} + +Existing collections can be retrieved by name with `.getCollection`, and deleted with `.deleteCollection`. + +```javascript +const collection = await client.getCollection({ name: "test" }); // Get a collection object from an existing collection, by name. Will raise an exception of it's not found. +collection = await client.getOrCreateCollection({ name: "test" }); // Get a collection object from an existing collection, by name. If it doesn't exist, create it. +await client.deleteCollection(collection); // Delete a collection and all associated embeddings, documents, and metadata. ⚠️ This is destructive and not reversible +``` + +{% /tab %} + +{% /tabs %} + +{% tabs group="code-lang" hideTabs=true %} +{% tab label="Python" %} + +Collections have a few useful convenience methods. + +```python +collection.peek() # returns a list of the first 10 items in the collection +collection.count() # returns the number of items in the collection +collection.modify(name="new_name") # Rename the collection +``` + +{% /tab %} +{% tab label="Javascript" %} + +There are a few useful convenience methods for working with Collections. + +```javascript +await collection.peek(); // returns a list of the first 10 items in the collection +await collection.count(); // returns the number of items in the collection +``` + +{% /tab %} + +{% /tabs %} + +### Changing the distance function + +{% tabs group="code-lang" hideTabs=true %} +{% tab label="Python" %} + +`create_collection` also takes an optional `metadata` argument which can be used to customize the distance method of the embedding space by setting the value of `hnsw:space`. 
+ +```python + collection = client.create_collection( + name="collection_name", + metadata={"hnsw:space": "cosine"} # l2 is the default + ) +``` + +{% /tab %} +{% tab label="Javascript" %} + +`createCollection` also takes an optional `metadata` argument which can be used to customize the distance method of the embedding space by setting the value of `hnsw:space` + +```js +let collection = client.createCollection({ + name: "collection_name", + metadata: { "hnsw:space": "cosine" }, +}); +``` + +{% /tab %} + +{% /tabs %} + +Valid options for `hnsw:space` are "l2", "ip, "or "cosine". The **default** is "l2" which is the squared L2 norm. + +{% special_table %} +{% /special_table %} + +| Distance | parameter | Equation | +| ----------------- | :-------: | ------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| Squared L2 | `l2` | {% math latexText="d = \\sum\\left(A_i-B_i\\right)^2" %}{% /math %} | +| Inner product | `ip` | {% math latexText="d = 1.0 - \\sum\\left(A_i \\times B_i\\right) " %}{% /math %} | +| Cosine similarity | `cosine` | {% math latexText="d = 1.0 - \\frac{\\sum\\left(A_i \\times B_i\\right)}{\\sqrt{\\sum\\left(A_i^2\\right)} \\cdot \\sqrt{\\sum\\left(B_i^2\\right)}}" %}{% /math %} | + +### Adding data to a Collection + +{% tabs group="code-lang" hideTabs=true %} +{% tab label="Python" %} + +Add data to Chroma with `.add`. + +Raw documents: + +```python +collection.add( + documents=["lorem ipsum...", "doc2", "doc3", ...], + metadatas=[{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}, ...], + ids=["id1", "id2", "id3", ...] +) +``` + +{% /tab %} +{% tab label="Javascript" %} + +Add data to Chroma with `.addRecords`. + +Raw documents: + +```javascript +await collection.add({ + ids: ["id1", "id2", "id3", ...], + metadatas: [{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}, ...], + documents: ["lorem ipsum...", "doc2", "doc3", ...], +}) +// input order +// ids - required +// embeddings - optional +// metadata - optional +// documents - optional +``` + +{% /tab %} + +{% /tabs %} + +If Chroma is passed a list of `documents`, it will automatically tokenize and embed them with the collection's embedding function (the default will be used if none was supplied at collection creation). Chroma will also store the `documents` themselves. If the documents are too large to embed using the chosen embedding function, an exception will be raised. + +Each document must have a unique associated `id`. Trying to `.add` the same ID twice will result in only the initial value being stored. An optional list of `metadata` dictionaries can be supplied for each document, to store additional information and enable filtering. + +Alternatively, you can supply a list of document-associated `embeddings` directly, and Chroma will store the associated documents without embedding them itself. + +{% tabs group="code-lang" hideTabs=true %} +{% tab label="Python" %} + +```python +collection.add( + documents=["doc1", "doc2", "doc3", ...], + embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2], ...], + metadatas=[{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}, ...], + ids=["id1", "id2", "id3", ...] 
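+    # embeddings are supplied directly here, so Chroma stores the documents without embedding them itself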
+) +``` + +{% /tab %} +{% tab label="Javascript" %} + +```javascript +await collection.add({ + ids: ["id1", "id2", "id3", ...], + embeddings: [[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2], ...], + metadatas: [{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}, ...], + documents: ["lorem ipsum...", "doc2", "doc3", ...], +}) + +``` + +{% /tab %} + +{% /tabs %} + +If the supplied `embeddings` are not the same dimension as the collection, an exception will be raised. + +You can also store documents elsewhere, and just supply a list of `embeddings` and `metadata` to Chroma. You can use the `ids` to associate the embeddings with your documents stored elsewhere. + +{% tabs group="code-lang" hideTabs=true %} +{% tab label="Python" %} + +```python +collection.add( + ids=["id1", "id2", "id3", ...], + embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2], ...], + metadatas=[{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}, ...] +) +``` + +{% /tab %} +{% tab label="Javascript" %} + +```javascript +await collection.add({ + ids: ["id1", "id2", "id3", ...], + embeddings: [[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2], ...], + metadatas: [{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}, ...], +}) +``` + +{% /tab %} + +{% /tabs %} + +### Querying a Collection + +You can query by a set of `query_embeddings`. + +{% tabs group="code-lang" hideTabs=true %} +{% tab label="Python" %} + +Chroma collections can be queried in a variety of ways, using the `.query` method. + +```python +collection.query( + query_embeddings=[[11.1, 12.1, 13.1],[1.1, 2.3, 3.2], ...], + n_results=10, + where={"metadata_field": "is_equal_to_this"}, + where_document={"$contains":"search_string"} +) +``` + +{% /tab %} +{% tab label="Javascript" %} + +Chroma collections can be queried in a variety of ways, using the `.queryRecords` method. + +```javascript +const result = await collection.query({ + queryEmbeddings: [[11.1, 12.1, 13.1],[1.1, 2.3, 3.2], ...], + nResults: 10, + where: {"metadata_field": "is_equal_to_this"}, +}) +// input order +// queryEmbeddings - optional, exactly one of queryEmbeddings and queryTexts must be provided +// queryTexts - optional +// n_results - required +// where - optional +``` + +{% /tab %} + +{% /tabs %} + +The query will return the `n_results` closest matches to each `query_embedding`, in order. +An optional `where` filter dictionary can be supplied to filter by the `metadata` associated with each document. +Additionally, an optional `where_document` filter dictionary can be supplied to filter by contents of the document. + +If the supplied `query_embeddings` are not the same dimension as the collection, an exception will be raised. + +You can also query by a set of `query_texts`. Chroma will first embed each `query_text` with the collection's embedding function, and then perform the query with the generated embedding. + +{% tabs group="code-lang" hideTabs=true %} +{% tab label="Python" %} + +```python +collection.query( + query_texts=["doc10", "thus spake zarathustra", ...], + n_results=10, + where={"metadata_field": "is_equal_to_this"}, + where_document={"$contains":"search_string"} +) +``` + +You can also retrieve items from a collection by `id` using `.get`. 
+ +```python +collection.get( + ids=["id1", "id2", "id3", ...], + where={"style": "style1"} +) +``` + +{% /tab %} +{% tab label="Javascript" %} + +```javascript +await collection.query({ + nResults: 10, // n_results + where: {"metadata_field": "is_equal_to_this"}, // where + queryTexts: ["doc10", "thus spake zarathustra", ...], // query_text +}) +``` + +You can also retrieve records from a collection by `id` using `.getRecords`. + +```javascript +await collection.get( { + ids: ["id1", "id2", "id3", ...], //ids + where: {"style": "style1"} // where +}) +``` + +{% /tab %} + +{% /tabs %} + +`.get` also supports the `where` and `where_document` filters. If no `ids` are supplied, it will return all items in the collection that match the `where` and `where_document` filters. + +##### Choosing which data is returned + +When using get or query you can use the include parameter to specify which data you want returned - any of `embeddings`, `documents`, `metadatas`, and for query, `distances`. By default, Chroma will return the `documents`, `metadatas` and in the case of query, the `distances` of the results. `embeddings` are excluded by default for performance and the `ids` are always returned. You can specify which of these you want returned by passing an array of included field names to the includes parameter of the query or get method. Note that embeddings will be returned as a 2-d numpy array in `.get` and a python list of 2-d numpy arrays in `.query`. + +{% tabs group="code-lang" hideTabs=true %} +{% tab label="Python" %} + +```python +# Only get documents and ids +collection.get( + include=["documents"] +) + +collection.query( + query_embeddings=[[11.1, 12.1, 13.1],[1.1, 2.3, 3.2], ...], + include=["documents"] +) +``` + +{% /tab %} +{% tab label="Javascript" %} + +```javascript +# Only get documents and ids +collection.get( + {include=["documents"]} +) + +collection.get({ + queryEmbeddings=[[11.1, 12.1, 13.1],[1.1, 2.3, 3.2], ...], + include=["documents"] +}) +``` + +{% /tab %} + +{% /tabs %} + +### Using Where filters + +Chroma supports filtering queries by `metadata` and `document` contents. The `where` filter is used to filter by `metadata`, and the `where_document` filter is used to filter by `document` contents. + +##### Filtering by metadata + +In order to filter on metadata, you must supply a `where` filter dictionary to the query. The dictionary must have the following structure: + +```python +{ + "metadata_field": { + : + } +} +``` + +Filtering metadata supports the following operators: + +- `$eq` - equal to (string, int, float) +- `$ne` - not equal to (string, int, float) +- `$gt` - greater than (int, float) +- `$gte` - greater than or equal to (int, float) +- `$lt` - less than (int, float) +- `$lte` - less than or equal to (int, float) + +Using the $eq operator is equivalent to using the `where` filter. + +```python +{ + "metadata_field": "search_string" +} + +# is equivalent to + +{ + "metadata_field": { + "$eq": "search_string" + } +} +``` + +{% note type="note" %} +Where filters only search embeddings where the key exists. If you search `collection.get(where={"version": {"$ne": 1}})`. Metadata that does not have the key `version` will not be returned. +{% /note %} + +##### Filtering by document contents + +In order to filter on document contents, you must supply a `where_document` filter dictionary to the query. We support two filtering keys: `$contains` and `$not_contains`. 
The dictionary must have the following structure: + +```python +# Filtering for a search_string +{ + "$contains": "search_string" +} +``` + +```python +# Filtering for not contains +{ + "$not_contains": "search_string" +} +``` + +##### Using logical operators + +You can also use the logical operators `$and` and `$or` to combine multiple filters. + +An `$and` operator will return results that match all of the filters in the list. + +```python +{ + "$and": [ + { + "metadata_field": { + : + } + }, + { + "metadata_field": { + : + } + } + ] +} +``` + +An `$or` operator will return results that match any of the filters in the list. + +```python +{ + "$or": [ + { + "metadata_field": { + : + } + }, + { + "metadata_field": { + : + } + } + ] +} +``` + +##### Using inclusion operators (`$in` and `$nin`) + +The following inclusion operators are supported: + +- `$in` - a value is in predefined list (string, int, float, bool) +- `$nin` - a value is not in predefined list (string, int, float, bool) + +An `$in` operator will return results where the metadata attribute is part of a provided list: + +```json +{ + "metadata_field": { + "$in": ["value1", "value2", "value3"] + } +} +``` + +An `$nin` operator will return results where the metadata attribute is not part of a provided list: + +```json +{ + "metadata_field": { + "$nin": ["value1", "value2", "value3"] + } +} +``` + +{% tabs group="code-lang" hideTabs=true %} +{% tab label="Python" %} + +{% note type="note" title="Practical examples" %} +For additional examples and a demo how to use the inclusion operators, please see provided notebook [here](https://github.com/chroma-core/chroma/blob/main/examples/basic_functionality/in_not_in_filtering.ipynb) +{% /note %} + +{% /tab %} +{% tab label="Javascript" %} +{% /tab %} + +{% /tabs %} + +### Updating data in a collection + +{% tabs group="code-lang" hideTabs=true %} +{% tab label="Python" %} + +Any property of records in a collection can be updated using `.update`. + +```python +collection.update( + ids=["id1", "id2", "id3", ...], + embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2], ...], + metadatas=[{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}, ...], + documents=["doc1", "doc2", "doc3", ...], +) +``` + +{% /tab %} +{% tab label="Javascript" %} + +Any property of records in a collection can be updated using `.updateRecords`. + +```javascript +collection.update( + { + ids: ["id1", "id2", "id3", ...], + embeddings: [[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2], ...], + metadatas: [{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}, ...], + documents: ["doc1", "doc2", "doc3", ...], + }, +) +``` + +{% /tab %} + +{% /tabs %} + +If an `id` is not found in the collection, an error will be logged and the update will be ignored. If `documents` are supplied without corresponding `embeddings`, the embeddings will be recomputed with the collection's embedding function. + +If the supplied `embeddings` are not the same dimension as the collection, an exception will be raised. + +Chroma also supports an `upsert` operation, which updates existing items, or adds them if they don't yet exist. 
+ +{% tabs group="code-lang" hideTabs=true %} +{% tab label="Python" %} + +```python +collection.upsert( + ids=["id1", "id2", "id3", ...], + embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2], ...], + metadatas=[{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}, ...], + documents=["doc1", "doc2", "doc3", ...], +) +``` + +{% /tab %} +{% tab label="Javascript" %} + +```javascript +await collection.upsert({ + ids: ["id1", "id2", "id3"], + embeddings: [ + [1.1, 2.3, 3.2], + [4.5, 6.9, 4.4], + [1.1, 2.3, 3.2], + ], + metadatas: [ + { chapter: "3", verse: "16" }, + { chapter: "3", verse: "5" }, + { chapter: "29", verse: "11" }, + ], + documents: ["doc1", "doc2", "doc3"], +}); +``` + +{% /tab %} + +{% /tabs %} + +If an `id` is not present in the collection, the corresponding items will be created as per `add`. Items with existing `id`s will be updated as per `update`. + +### Deleting data from a collection + +Chroma supports deleting items from a collection by `id` using `.delete`. The embeddings, documents, and metadata associated with each item will be deleted. +⚠️ Naturally, this is a destructive operation, and cannot be undone. + +{% tabs group="code-lang" hideTabs=true %} +{% tab label="Python" %} + +```python +collection.delete( + ids=["id1", "id2", "id3",...], + where={"chapter": "20"} +) +``` + +{% /tab %} +{% tab label="Javascript" %} + +```javascript +await collection.delete({ + ids: ["id1", "id2", "id3",...], //ids + where: {"chapter": "20"} //where +}) +``` + +{% /tab %} + +{% /tabs %} + +`.delete` also supports the `where` filter. If no `ids` are supplied, it will delete all items in the collection that match the `where` filter. diff --git a/go/pkg/sysdb/coordinator/coordinator.go b/go/pkg/sysdb/coordinator/coordinator.go index b6378e73a1e..e440c0a07c9 100644 --- a/go/pkg/sysdb/coordinator/coordinator.go +++ b/go/pkg/sysdb/coordinator/coordinator.go @@ -103,6 +103,10 @@ func (s *Coordinator) GetCollections(ctx context.Context, collectionID types.Uni return s.catalog.GetCollections(ctx, collectionID, collectionName, tenantID, databaseName, limit, offset) } +func (s *Coordinator) GetCollectionWithSegments(ctx context.Context, collectionID types.UniqueID) (*model.Collection, []*model.Segment, error) { + return s.catalog.GetCollectionWithSegments(ctx, collectionID) +} + func (s *Coordinator) CheckCollection(ctx context.Context, collectionID types.UniqueID) (bool, error) { return s.catalog.CheckCollection(ctx, collectionID) } diff --git a/go/pkg/sysdb/coordinator/coordinator_test.go b/go/pkg/sysdb/coordinator/coordinator_test.go index 624dfb2767d..16d4bdb42a7 100644 --- a/go/pkg/sysdb/coordinator/coordinator_test.go +++ b/go/pkg/sysdb/coordinator/coordinator_test.go @@ -306,6 +306,21 @@ func (suite *APIsTestSuite) TestCreateCollectionAndSegments() { suite.Equal(segment.ID, segmentResult[0].ID) } + // The same information should be returned by the GetCollectionWithSegments endpoint + collection, collection_segments, error := suite.coordinator.GetCollectionWithSegments(ctx, newCollection.ID) + suite.NoError(error) + suite.Equal(newCollection.ID, collection.ID) + suite.Equal(newCollection.Name, collection.Name) + expected_ids, actual_ids := []types.UniqueID{}, []types.UniqueID{} + for _, segment := range segments { + expected_ids = append(expected_ids, segment.ID) + } + for _, segment := range collection_segments { + suite.Equal(collection.ID, segment.CollectionID) + actual_ids = append(actual_ids, segment.ID) + } + 
suite.ElementsMatch(expected_ids, actual_ids) + // Attempt to create a duplicate collection (should fail) _, _, err = suite.coordinator.CreateCollectionAndSegments(ctx, newCollection, segments) suite.Error(err) diff --git a/go/pkg/sysdb/coordinator/table_catalog.go b/go/pkg/sysdb/coordinator/table_catalog.go index 684376f5a8c..94d8ef2cd2c 100644 --- a/go/pkg/sysdb/coordinator/table_catalog.go +++ b/go/pkg/sysdb/coordinator/table_catalog.go @@ -338,6 +338,43 @@ func (tc *Catalog) GetCollections(ctx context.Context, collectionID types.Unique return collections, nil } +func (tc *Catalog) GetCollectionWithSegments(ctx context.Context, collectionID types.UniqueID) (*model.Collection, []*model.Segment, error) { + tracer := otel.Tracer + if tracer != nil { + _, span := tracer.Start(ctx, "Catalog.GetCollections") + defer span.End() + } + + var collection *model.Collection + var segments []*model.Segment + + err := tc.txImpl.Transaction(ctx, func(txCtx context.Context) error { + collections, e := tc.GetCollections(ctx, collectionID, nil, "", "", nil, nil) + if e != nil { + return e + } + if len(collections) == 0 { + return common.ErrCollectionNotFound + } + if len(collections) > 1 { + return common.ErrCollectionUniqueConstraintViolation + } + collection = collections[0] + + segments, e = tc.GetSegments(ctx, types.NilUniqueID(), nil, nil, collectionID) + if e != nil { + return e + } + + return nil + }) + if err != nil { + return nil, nil, err + } + + return collection, segments, nil +} + func (tc *Catalog) DeleteCollection(ctx context.Context, deleteCollection *model.DeleteCollection, softDelete bool) error { if softDelete { return tc.softDeleteCollection(ctx, deleteCollection) diff --git a/go/pkg/sysdb/grpc/collection_service.go b/go/pkg/sysdb/grpc/collection_service.go index 89d92a6283f..7d7abd69647 100644 --- a/go/pkg/sysdb/grpc/collection_service.go +++ b/go/pkg/sysdb/grpc/collection_service.go @@ -3,6 +3,7 @@ package grpc import ( "context" "encoding/json" + "fmt" "github.com/chroma-core/chroma/go/pkg/grpcutils" @@ -169,6 +170,50 @@ func (s *Server) CheckCollections(ctx context.Context, req *coordinatorpb.CheckC return res, nil } +func (s *Server) GetCollectionWithSegments(ctx context.Context, req *coordinatorpb.GetCollectionWithSegmentsRequest) (*coordinatorpb.GetCollectionWithSegmentsResponse, error) { + collectionID := req.Id + + res := &coordinatorpb.GetCollectionWithSegmentsResponse{} + + parsedCollectionID, err := types.ToUniqueID(&collectionID) + if err != nil { + log.Error("GetCollectionWithSegments failed. collection id format error", zap.Error(err), zap.String("collection_id", collectionID)) + return res, grpcutils.BuildInternalGrpcError(err.Error()) + } + + collection, segments, err := s.coordinator.GetCollectionWithSegments(ctx, parsedCollectionID) + if err != nil { + log.Error("GetCollectionWithSegments failed. ", zap.Error(err), zap.String("collection_id", collectionID)) + return res, grpcutils.BuildInternalGrpcError(err.Error()) + } + + res.Collection = convertCollectionToProto(collection) + segmentpbList := make([]*coordinatorpb.Segment, 0, len(segments)) + scopeToSegmentMap := map[coordinatorpb.SegmentScope]*coordinatorpb.Segment{} + for _, segment := range segments { + segmentpb := convertSegmentToProto(segment) + scopeToSegmentMap[segmentpb.GetScope()] = segmentpb + segmentpbList = append(segmentpbList, segmentpb) + } + + if len(segmentpbList) != 3 { + log.Error("GetCollectionWithSegments failed. 
Unexpected number of collection segments", zap.String("collection_id", collectionID)) + return res, grpcutils.BuildInternalGrpcError(fmt.Sprintf("Unexpected number of segments for collection %s: %d", collectionID, len(segmentpbList))) + } + + scopes := []coordinatorpb.SegmentScope{coordinatorpb.SegmentScope_METADATA, coordinatorpb.SegmentScope_RECORD, coordinatorpb.SegmentScope_VECTOR} + + for _, scope := range scopes { + if _, exists := scopeToSegmentMap[scope]; !exists { + log.Error("GetCollectionWithSegments failed. Collection segment scope not found", zap.String("collection_id", collectionID), zap.String("missing_scope", scope.String())) + return res, grpcutils.BuildInternalGrpcError(fmt.Sprintf("Missing segment scope for collection %s: %s", collectionID, scope.String())) + } + } + + res.Segments = segmentpbList + return res, nil +} + func (s *Server) DeleteCollection(ctx context.Context, req *coordinatorpb.DeleteCollectionRequest) (*coordinatorpb.DeleteCollectionResponse, error) { collectionID := req.GetId() res := &coordinatorpb.DeleteCollectionResponse{} diff --git a/go/pkg/sysdb/grpc/collection_service_test.go b/go/pkg/sysdb/grpc/collection_service_test.go index a88db3b91d1..a003e59621b 100644 --- a/go/pkg/sysdb/grpc/collection_service_test.go +++ b/go/pkg/sysdb/grpc/collection_service_test.go @@ -81,16 +81,19 @@ func testCollection(t *rapid.T) { var collectionsWithErrors []*coordinatorpb.Collection t.Repeat(map[string]func(*rapid.T){ - "create_collection": func(t *rapid.T) { + "create_get_collection": func(t *rapid.T) { stringValue := generateStringMetadataValue(t) intValue := generateInt64MetadataValue(t) floatValue := generateFloat64MetadataValue(t) getOrCreate := false + collectionId := rapid.StringMatching(`[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}`).Draw(t, "collection_id") + collectionName := rapid.String().Draw(t, "collection_name") + createCollectionRequest := rapid.Custom[*coordinatorpb.CreateCollectionRequest](func(t *rapid.T) *coordinatorpb.CreateCollectionRequest { return &coordinatorpb.CreateCollectionRequest{ - Id: rapid.StringMatching(`[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}`).Draw(t, "collection_id"), - Name: rapid.String().Draw(t, "collection_name"), + Id: collectionId, + Name: collectionName, Metadata: &coordinatorpb.UpdateMetadata{ Metadata: map[string]*coordinatorpb.UpdateMetadataValue{ "string_value": stringValue, @@ -99,6 +102,26 @@ func testCollection(t *rapid.T) { }, }, GetOrCreate: &getOrCreate, + Segments: []*coordinatorpb.Segment{ + { + Id: rapid.StringMatching(`[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}`).Draw(t, "metadata_segment_id"), + Type: "metadata_segment_type", + Scope: coordinatorpb.SegmentScope_METADATA, + Collection: collectionId, + }, + { + Id: rapid.StringMatching(`[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}`).Draw(t, "record_segment_id"), + Type: "record_segment_type", + Scope: coordinatorpb.SegmentScope_RECORD, + Collection: collectionId, + }, + { + Id: rapid.StringMatching(`[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}`).Draw(t, "vector_segment_id"), + Type: "vector_segment_type", + Scope: coordinatorpb.SegmentScope_VECTOR, + Collection: collectionId, + }, + }, } }).Draw(t, "create_collection_request") @@ -114,29 +137,57 @@ func testCollection(t *rapid.T) { } } - getCollectionsRequest := coordinatorpb.GetCollectionsRequest{ - Id: &createCollectionRequest.Id, - } if err == nil { + getCollectionsRequest := 
coordinatorpb.GetCollectionsRequest{ + Id: &createCollectionRequest.Id, + } // verify the correctness - GetCollectionsResponse, err := s.GetCollections(ctx, &getCollectionsRequest) + getCollectionsResponse, err := s.GetCollections(ctx, &getCollectionsRequest) if err != nil { t.Fatalf("error getting collections: %v", err) } - collectionList := GetCollectionsResponse.GetCollections() + collectionList := getCollectionsResponse.GetCollections() if len(collectionList) != 1 { - t.Fatalf("More than 1 collection with the same collection id") + t.Fatalf("there should be exactly one matching collection given the collection id") + } + if collectionList[0].Id != createCollectionRequest.Id { + t.Fatalf("collection id mismatch") + } + + getCollectionWithSegmentsRequest := coordinatorpb.GetCollectionWithSegmentsRequest{ + Id: createCollectionRequest.Id, + } + + getCollectionWithSegmentsResponse, err := s.GetCollectionWithSegments(ctx, &getCollectionWithSegmentsRequest) + if err != nil { + t.Fatalf("error getting collection with segments: %v", err) + } + + if getCollectionWithSegmentsResponse.Collection.Id != res.Collection.Id { + t.Fatalf("collection id mismatch") + } + + if len(getCollectionWithSegmentsResponse.Segments) != 3 { + t.Fatalf("unexpected number of segments in collection: %v", getCollectionWithSegmentsResponse.Segments) } - for _, collection := range collectionList { - if collection.Id != createCollectionRequest.Id { - t.Fatalf("collection id is the right value") + + scopeToSegmentMap := map[coordinatorpb.SegmentScope]*coordinatorpb.Segment{} + for _, segment := range getCollectionWithSegmentsResponse.Segments { + if segment.Collection != res.Collection.Id { + t.Fatalf("invalid collection id in segment") + } + scopeToSegmentMap[segment.GetScope()] = segment + } + scopes := []coordinatorpb.SegmentScope{coordinatorpb.SegmentScope_METADATA, coordinatorpb.SegmentScope_RECORD, coordinatorpb.SegmentScope_VECTOR} + for _, scope := range scopes { + if _, exists := scopeToSegmentMap[scope]; !exists { + t.Fatalf("collection segment scope not found: %s", scope.String()) } } + state = append(state, res.Collection) } }, - "get_collections": func(t *rapid.T) { - }, }) } diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto index 1e41a474046..564b2e501b2 100644 --- a/idl/chromadb/proto/coordinator.proto +++ b/idl/chromadb/proto/coordinator.proto @@ -143,6 +143,15 @@ message GetCollectionsResponse { reserved "status"; } +message GetCollectionWithSegmentsRequest { + string id = 1; +} + +message GetCollectionWithSegmentsResponse { + Collection collection = 1; + repeated Segment segments = 2; +} + message CheckCollectionsRequest { repeated string collection_ids = 1; } @@ -219,6 +228,7 @@ service SysDB { rpc CreateCollection(CreateCollectionRequest) returns (CreateCollectionResponse) {} rpc DeleteCollection(DeleteCollectionRequest) returns (DeleteCollectionResponse) {} rpc GetCollections(GetCollectionsRequest) returns (GetCollectionsResponse) {} + rpc GetCollectionWithSegments(GetCollectionWithSegmentsRequest) returns (GetCollectionWithSegmentsResponse) {} rpc CheckCollections(CheckCollectionsRequest) returns (CheckCollectionsResponse) {} rpc UpdateCollection(UpdateCollectionRequest) returns (UpdateCollectionResponse) {} rpc ResetState(google.protobuf.Empty) returns (ResetStateResponse) {} diff --git a/idl/chromadb/proto/query_executor.proto b/idl/chromadb/proto/query_executor.proto index 0149767217c..f434070c9cd 100644 --- a/idl/chromadb/proto/query_executor.proto +++ 
b/idl/chromadb/proto/query_executor.proto @@ -6,9 +6,11 @@ import "chromadb/proto/chroma.proto"; message ScanOperator { Collection collection = 1; - string knn_id = 2; - string metadata_id = 3; - string record_id = 4; + // Reserve for deprecated fields + reserved 2, 3, 4; + Segment knn = 5; + Segment metadata = 6; + Segment record = 7; } message FilterOperator { diff --git a/k8s/test/minio.yaml b/k8s/test/minio.yaml index ab5037ea59f..6bde69eb009 100644 --- a/k8s/test/minio.yaml +++ b/k8s/test/minio.yaml @@ -23,6 +23,7 @@ spec: args: - server - /storage + - "--console-address=:9005" # Fixed port for MinIO Console env: - name: MINIO_ACCESS_KEY value: "minio" diff --git a/pyproject.toml b/pyproject.toml index 9fb9759dabd..fdc51b199cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ authors = [ ] description = "Chroma." readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", @@ -28,7 +28,7 @@ dependencies = [ 'opentelemetry-exporter-otlp-proto-grpc>=1.2.0', 'opentelemetry-instrumentation-fastapi>=0.41b0', 'opentelemetry-sdk>=1.2.0', - 'tokenizers >= 0.13.2, <= 0.20.3', + 'tokenizers >= 0.13.2', 'pypika >= 0.48.9', 'tqdm >= 4.65.0', 'overrides >= 7.3.1', @@ -49,7 +49,7 @@ dependencies = [ [tool.black] line-length = 88 required-version = "23.3.0" # Black will refuse to run if it's not this version. -target-version = ['py38', 'py39', 'py310', 'py311'] +target-version = ['py39', 'py310', 'py311'] [tool.pytest.ini_options] pythonpath = ["."] diff --git a/requirements.txt b/requirements.txt index 19b079af0eb..422c8062ac5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ opentelemetry-api>=1.2.0 opentelemetry-exporter-otlp-proto-grpc>=1.24.0 opentelemetry-instrumentation-fastapi>=0.41b0 opentelemetry-sdk>=1.2.0 -orjson>=3.9.12, < 3.10.6 # 3.10.7 is currently missing a wheel for x86 glibc +orjson>=3.9.12 overrides>=7.3.1 posthog>=2.4.0 pydantic>=1.9 @@ -21,7 +21,7 @@ pypika>=0.48.9 PyYAML>=6.0.0 rich>=10.11.0 tenacity>=8.2.3 -tokenizers>=0.13.2,<=0.20.3 +tokenizers>=0.13.2 tqdm>=4.65.0 typer>=0.9.0 typing_extensions>=4.5.0 diff --git a/rust/benchmark/src/datasets/gist.rs b/rust/benchmark/src/datasets/gist.rs index d511f298d79..ad26934e50b 100644 --- a/rust/benchmark/src/datasets/gist.rs +++ b/rust/benchmark/src/datasets/gist.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use anyhow::Ok; +use anyhow::Result; use tokio::io::{AsyncReadExt, BufReader}; use super::{ @@ -20,7 +20,7 @@ impl RecordDataset for GistDataset { const DISPLAY_NAME: &'static str = "Gist"; const NAME: &'static str = "gist"; - async fn init() -> anyhow::Result { + async fn init() -> Result { // TODO(Sanket): Download file if it doesn't exist. // move file from downloads to cached path. 
let current_path = "/Users/sanketkedia/Downloads/siftsmall/siftsmall_base.fvecs"; diff --git a/rust/benchmark/src/datasets/mod.rs b/rust/benchmark/src/datasets/mod.rs index 60d5e35e486..3ceef402298 100644 --- a/rust/benchmark/src/datasets/mod.rs +++ b/rust/benchmark/src/datasets/mod.rs @@ -4,4 +4,5 @@ pub mod util; pub mod gist; pub mod ms_marco_queries; pub mod scidocs; +pub mod sift; pub mod wikipedia; diff --git a/rust/benchmark/src/datasets/sift.rs b/rust/benchmark/src/datasets/sift.rs new file mode 100644 index 00000000000..d55fe1eaa04 --- /dev/null +++ b/rust/benchmark/src/datasets/sift.rs @@ -0,0 +1,197 @@ +use std::{ + io::SeekFrom, + ops::{Bound, RangeBounds}, +}; + +use anyhow::{anyhow, Ok, Result}; +use tokio::{ + fs::File, + io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt, BufReader}, +}; + +use super::util::get_or_populate_cached_dataset_file; + +pub struct Sift1MData { + pub base: BufReader, + pub query: BufReader, + pub ground: BufReader, +} + +impl Sift1MData { + pub async fn init() -> Result { + let base = get_or_populate_cached_dataset_file( + "sift1m", + "base.fvecs", + None, + |mut writer| async move { + let client = reqwest::Client::new(); + let response = client + .get( + "https://huggingface.co/datasets/qbo-odp/sift1m/resolve/main/sift_base.fvecs", + ) + .send() + .await?; + + if !response.status().is_success() { + return Err(anyhow!( + "Failed to download Sift1M base data, got status code {}", + response.status() + )); + } + + writer.write_all(&response.bytes().await?).await?; + + Ok(()) + }, + ).await?; + let query = get_or_populate_cached_dataset_file( + "sift1m", + "query.fvecs", + None, + |mut writer| async move { + let client = reqwest::Client::new(); + let response = client + .get( + "https://huggingface.co/datasets/qbo-odp/sift1m/resolve/main/sift_query.fvecs", + ) + .send() + .await?; + + if !response.status().is_success() { + return Err(anyhow!( + "Failed to download Sift1M query data, got status code {}", + response.status() + )); + } + + writer.write_all(&response.bytes().await?).await?; + + Ok(()) + }, + ).await?; + let ground = get_or_populate_cached_dataset_file( + "sift1m", + "groundtruth.ivecs", + None, + |mut writer| async move { + let client = reqwest::Client::new(); + let response = client + .get( + "https://huggingface.co/datasets/qbo-odp/sift1m/resolve/main/sift_groundtruth.ivecs", + ) + .send() + .await?; + + if !response.status().is_success() { + return Err(anyhow!( + "Failed to download Sift1M ground data, got status code {}", + response.status() + )); + } + + writer.write_all(&response.bytes().await?).await?; + + Ok(()) + }, + ).await?; + Ok(Self { + base: BufReader::new(File::open(base).await?), + query: BufReader::new(File::open(query).await?), + ground: BufReader::new(File::open(ground).await?), + }) + } + + pub fn collection_size() -> usize { + 1000000 + } + + pub fn query_size() -> usize { + 10000 + } + + pub fn dimension() -> usize { + 128 + } + + pub fn k() -> usize { + 100 + } + + pub async fn data_range(&mut self, range: impl RangeBounds) -> Result>> { + let lower_bound = match range.start_bound() { + Bound::Included(include) => *include, + Bound::Excluded(exclude) => exclude + 1, + Bound::Unbounded => 0, + }; + let upper_bound = match range.end_bound() { + Bound::Included(include) => include + 1, + Bound::Excluded(exclude) => *exclude, + Bound::Unbounded => usize::MAX, + } + .min(Self::collection_size()); + + if lower_bound >= upper_bound { + return Ok(Vec::new()); + } + + let vector_size = size_of::() + Self::dimension() * 
size_of::(); + + let start = SeekFrom::Start((lower_bound * vector_size) as u64); + self.base.seek(start).await?; + let batch_size = upper_bound - lower_bound; + let mut base_bytes = vec![0; batch_size * vector_size]; + self.base.read_exact(&mut base_bytes).await?; + read_raw_vec(&base_bytes, |bytes| { + Ok(f32::from_le_bytes(bytes.try_into()?)) + }) + } + + pub async fn query(&mut self) -> Result, Vec)>> { + let mut query_bytes = Vec::new(); + self.query.read_to_end(&mut query_bytes).await?; + let queries = read_raw_vec(&query_bytes, |bytes| { + Ok(f32::from_le_bytes(bytes.try_into()?)) + })?; + + let mut ground_bytes = Vec::new(); + self.ground.read_to_end(&mut ground_bytes).await?; + let grounds = read_raw_vec(&ground_bytes, |bytes| { + Ok(u32::from_le_bytes(bytes.try_into()?)) + })?; + if queries.len() != grounds.len() { + return Err(anyhow!( + "Queries and grounds count mismatch: {} != {}", + queries.len(), + grounds.len() + )); + } + Ok(queries.into_iter().zip(grounds).collect()) + } +} + +fn read_raw_vec( + raw_bytes: &[u8], + convert_from_bytes: impl Fn(&[u8]) -> Result, +) -> Result>> { + let mut result = Vec::new(); + let mut bytes = raw_bytes; + while !bytes.is_empty() { + let (dimension_bytes, rem_bytes) = bytes.split_at(size_of::()); + let dimension = u32::from_le_bytes(dimension_bytes.try_into()?); + let (embedding_bytes, rem_bytes) = rem_bytes.split_at(dimension as usize * size_of::()); + let embedding = embedding_bytes + .chunks(size_of::()) + .map(&convert_from_bytes) + .collect::>>()?; + if embedding.len() != dimension as usize { + return Err(anyhow!( + "Embedding dimension mismatch: {} != {}", + embedding.len(), + dimension + )); + } + result.push(embedding); + bytes = rem_bytes; + } + Ok(result) +} diff --git a/rust/blockstore/src/arrow/block/types.rs b/rust/blockstore/src/arrow/block/types.rs index 6764591c86f..0a85d16591e 100644 --- a/rust/blockstore/src/arrow/block/types.rs +++ b/rust/blockstore/src/arrow/block/types.rs @@ -277,7 +277,7 @@ impl Block { /// Finds the partition point of the prefix and key. /// Returns the index of the first element that matches the target prefix and key. If no element matches, returns the index at which the target prefix and key could be inserted to maintain sorted order. 
#[inline] - fn binary_search_prefix_key<'me, K: ArrowReadableKey<'me>>( + pub(crate) fn binary_search_prefix_key<'me, K: ArrowReadableKey<'me>>( &'me self, prefix: &str, key: &K, @@ -418,31 +418,6 @@ impl Block { }) } - /// Get all the values for a given prefix in the block where the key is between the given keys - /// ### Notes - /// - Returns a tuple of (prefix, key, value) - /// - Returns None if the requested index is out of bounds - /// ### Panics - /// - If the underlying data types are not the same as the types specified in the function signature - pub fn get_at_index<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>( - &'me self, - index: usize, - ) -> Option<(&'me str, K, V)> { - if index >= self.data.num_rows() { - return None; - } - let prefix_arr = self - .data - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); - let prefix = prefix_arr.value(index); - let key = K::get(self.data.column(1), index); - let value = V::get(self.data.column(2), index); - Some((prefix, key, value)) - } - /* ===== Block Metadata ===== */ diff --git a/rust/blockstore/src/arrow/blockfile.rs b/rust/blockstore/src/arrow/blockfile.rs index 44e2db95587..6be8b5a2fe3 100644 --- a/rust/blockstore/src/arrow/blockfile.rs +++ b/rust/blockstore/src/arrow/blockfile.rs @@ -11,7 +11,6 @@ use crate::arrow::root::CURRENT_VERSION; use crate::arrow::sparse_index::SparseIndexWriter; use crate::key::CompositeKey; use crate::key::KeyWrapper; -use crate::BlockfileError; use chroma_error::ChromaError; use chroma_error::ErrorCodes; use futures::future::join_all; @@ -460,63 +459,6 @@ impl<'me, K: ArrowReadableKey<'me> + Into, V: ArrowReadableValue<'me } } - pub(crate) async fn get_at_index( - &'me self, - index: usize, - ) -> Result<(&'me str, K, V), Box> { - let mut block_offset = 0; - let mut block = None; - let sparse_index_len = self.root.sparse_index.len(); - for i in 0..sparse_index_len { - // This unwrap is safe because we are iterating over the sparse index - // within its len. The sparse index reader is immutable and cannot be modified - let uuid = self - .root - .sparse_index - .data - .forward - .iter() - .nth(i) - .unwrap() - .1 - .id; - block = match self.get_block(uuid).await { - Ok(Some(block)) => Some(block), - Ok(None) => { - tracing::error!("Block with id {:?} not found", uuid); - return Err(Box::new(ArrowBlockfileError::BlockNotFound)); - } - Err(e) => { - return Err(Box::new(e)); - } - }; - match block { - Some(b) => { - if block_offset + b.len() > index { - break; - } - block_offset += b.len(); - } - None => { - tracing::error!("Block id {:?} not found", uuid); - return Err(Box::new(ArrowBlockfileError::BlockNotFound)); - } - } - } - let block = block.unwrap(); - let res = block.get_at_index::<'me, K, V>(index - block_offset); - match res { - Some((prefix, key, value)) => Ok((prefix, key, value)), - _ => { - tracing::error!( - "Value not found at index {:?} for block", - index - block_offset, - ); - Err(Box::new(BlockfileError::NotFoundError)) - } - } - } - // Returns all Arrow records in the specified range. 
pub(crate) fn get_range_stream<'prefix, PrefixRange, KeyRange>( &'me self, @@ -657,6 +599,56 @@ impl<'me, K: ArrowReadableKey<'me> + Into, V: ArrowReadableValue<'me self.root.id } + /// Returns the number of elements strictly less than the given prefix-key pair in the blockfile + /// In other words, the rank is the position where the given prefix-key pair can be inserted while maintaining the order of the blockfile + pub(crate) async fn rank( + &'me self, + prefix: &'me str, + key: K, + ) -> Result> { + let mut rank = 0; + + // This should be sorted by offset id ranges + let block_ids = self + .root + .sparse_index + .get_block_ids_range(..=prefix, ..=key.clone()); + + // The block that may contain the prefix-key pair + if let Some(last_id) = block_ids.last() { + if self.root.version >= Version::V1_1 { + rank += self + .root + .sparse_index + .data + .forward + .values() + .take(block_ids.len() - 1) + .map(|meta| meta.count) + .sum::() as usize; + } else { + self.load_blocks(&block_ids).await; + for block_id in block_ids.iter().take(block_ids.len() - 1) { + let block = + self.get_block(*block_id) + .await + .map_err(|e| Box::new(e) as Box)? + .ok_or(Box::new(ArrowBlockfileError::BlockNotFound) + as Box)?; + rank += block.len(); + } + } + let last_block = self + .get_block(*last_id) + .await + .map_err(|e| Box::new(e) as Box)? + .ok_or(Box::new(ArrowBlockfileError::BlockNotFound) as Box)?; + rank += last_block.binary_search_prefix_key(prefix, &key); + } + + Ok(rank) + } + /// Check if the blockfile is valid. /// Validates that the sparse index is valid and that no block exceeds the max block size. pub async fn is_valid(&self) -> bool { @@ -1509,7 +1501,7 @@ mod tests { } #[tokio::test] - async fn test_get_at_index() { + async fn test_rank() { let tmp_dir = tempfile::tempdir().unwrap(); let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); let block_cache = new_cache_for_test(); @@ -1541,12 +1533,9 @@ mod tests { .unwrap(); for i in 0..n { - let expected_key = format!("{:04}", i); - let expected_value = vec![i]; - let res = reader.get_at_index(i as usize).await.unwrap(); - assert_eq!(res.0, "key"); - assert_eq!(res.1, expected_key); - assert_eq!(res.2, expected_value); + let rank_key = format!("{:04}", i); + let rank = reader.rank("key", &rank_key).await.unwrap(); + assert_eq!(rank, i as usize); } } diff --git a/rust/blockstore/src/arrow/sparse_index.rs b/rust/blockstore/src/arrow/sparse_index.rs index 3d4124ab447..efb89b1d1db 100644 --- a/rust/blockstore/src/arrow/sparse_index.rs +++ b/rust/blockstore/src/arrow/sparse_index.rs @@ -311,6 +311,8 @@ impl SparseIndexReader { } /// Get the number of keys in the sparse index + /// Used in unit test + #[allow(dead_code)] pub(super) fn len(&self) -> usize { self.data.forward.len() } diff --git a/rust/blockstore/src/memory/reader_writer.rs b/rust/blockstore/src/memory/reader_writer.rs index 82f0520a8a8..282957dae3c 100644 --- a/rust/blockstore/src/memory/reader_writer.rs +++ b/rust/blockstore/src/memory/reader_writer.rs @@ -123,18 +123,6 @@ impl< .map(|(key, value)| (K::try_from(&key.key).unwrap(), value))) } - pub(crate) fn get_at_index( - &'storage self, - index: usize, - ) -> Result<(&'storage str, K, V), Box> { - let res = V::get_at_index(&self.storage, index); - let (key, value) = match res { - Some((key, value)) => (key, value), - None => return Err(Box::new(BlockfileError::NotFoundError)), - }; - Ok((key.prefix.as_str(), K::try_from(&key.key).unwrap(), value)) - } - pub(crate) fn count(&self) -> Result> { 
V::count(&self.storage) } @@ -146,6 +134,10 @@ impl< pub(crate) fn id(&self) -> uuid::Uuid { self.storage.id } + + pub(crate) fn rank(&'storage self, prefix: &'storage str, key: K) -> usize { + V::rank(prefix, key.into(), &self.storage) + } } #[cfg(test)] @@ -860,7 +852,7 @@ mod tests { } #[tokio::test] - async fn test_get_by_index() { + async fn test_rank() { let storage_manager = StorageManager::new(); let writer = MemoryBlockfileWriter::new(storage_manager.clone()); let id = writer.id; @@ -876,13 +868,9 @@ mod tests { let reader: MemoryBlockfileReader<&str, &str> = MemoryBlockfileReader::open(id, storage_manager.clone()); for i in 0..n { - let expected_key = format!("key{:04}", i); - let expected_value = format!("value{:04}", i); - let (prefix, key, value) = - MemoryBlockfileReader::<&str, &str>::get_at_index(&reader, i).unwrap(); - assert_eq!(prefix, "prefix"); - assert_eq!(key, expected_key.as_str()); - assert_eq!(value, expected_value.as_str()); + let rank_key = format!("key{:04}", i); + let rank = MemoryBlockfileReader::<&str, &str>::rank(&reader, "prefix", &rank_key); + assert_eq!(rank, i); } } } diff --git a/rust/blockstore/src/memory/storage.rs b/rust/blockstore/src/memory/storage.rs index 0b0a297b3c5..fd6cd91d263 100644 --- a/rust/blockstore/src/memory/storage.rs +++ b/rust/blockstore/src/memory/storage.rs @@ -29,14 +29,11 @@ pub trait Readable<'referred_data>: Sized { PrefixRange: std::ops::RangeBounds<&'prefix str>, KeyRange: std::ops::RangeBounds; - fn get_at_index( - storage: &'referred_data Storage, - index: usize, - ) -> Option<(&'referred_data CompositeKey, Self)>; - fn count(storage: &Storage) -> Result>; fn contains(prefix: &str, key: KeyWrapper, storage: &'referred_data Storage) -> bool; + + fn rank(prefix: &str, key: KeyWrapper, storage: &'referred_data Storage) -> usize; } impl Writeable for String { @@ -102,17 +99,6 @@ impl<'referred_data> Readable<'referred_data> for &'referred_data str { .collect() } - fn get_at_index( - storage: &'referred_data Storage, - index: usize, - ) -> Option<(&'referred_data CompositeKey, Self)> { - storage - .string_value_storage - .iter() - .nth(index) - .map(|(k, v)| (k, v.as_str())) - } - fn count(storage: &Storage) -> Result> { Ok(storage.string_value_storage.iter().len()) } @@ -126,6 +112,18 @@ impl<'referred_data> Readable<'referred_data> for &'referred_data str { }) .is_some() } + + fn rank(prefix: &str, key: KeyWrapper, storage: &'referred_data Storage) -> usize { + storage + .string_value_storage + .range( + ..CompositeKey { + prefix: prefix.to_string(), + key, + }, + ) + .count() + } } // TODO: remove this and make this all use a unified storage so we don't have two impls @@ -192,17 +190,6 @@ impl<'referred_data> Readable<'referred_data> for &'referred_data [u32] { .collect() } - fn get_at_index( - storage: &'referred_data Storage, - index: usize, - ) -> Option<(&'referred_data CompositeKey, Self)> { - storage - .uint32_array_storage - .iter() - .nth(index) - .map(|(k, v)| (k, v.as_slice())) - } - fn count(storage: &Storage) -> Result> { Ok(storage.uint32_array_storage.iter().len()) } @@ -216,6 +203,18 @@ impl<'referred_data> Readable<'referred_data> for &'referred_data [u32] { }) .is_some() } + + fn rank(prefix: &str, key: KeyWrapper, storage: &'referred_data Storage) -> usize { + storage + .uint32_array_storage + .range( + ..CompositeKey { + prefix: prefix.to_string(), + key, + }, + ) + .count() + } } impl Writeable for RoaringBitmap { @@ -277,17 +276,6 @@ impl<'referred_data> Readable<'referred_data> for RoaringBitmap { 
.collect() } - fn get_at_index( - storage: &'referred_data Storage, - index: usize, - ) -> Option<(&'referred_data CompositeKey, Self)> { - storage - .roaring_bitmap_storage - .iter() - .nth(index) - .map(|(k, v)| (k, v.clone())) - } - fn count(storage: &Storage) -> Result> { Ok(storage.roaring_bitmap_storage.iter().len()) } @@ -301,6 +289,18 @@ impl<'referred_data> Readable<'referred_data> for RoaringBitmap { }) .is_some() } + + fn rank(prefix: &str, key: KeyWrapper, storage: &'referred_data Storage) -> usize { + storage + .roaring_bitmap_storage + .range( + ..CompositeKey { + prefix: prefix.to_string(), + key, + }, + ) + .count() + } } impl Writeable for f32 { @@ -357,13 +357,6 @@ impl<'referred_data> Readable<'referred_data> for f32 { .collect() } - fn get_at_index( - storage: &'referred_data Storage, - index: usize, - ) -> Option<(&'referred_data CompositeKey, Self)> { - storage.f32_storage.iter().nth(index).map(|(k, v)| (k, *v)) - } - fn count(storage: &Storage) -> Result> { Ok(storage.f32_storage.iter().len()) } @@ -377,6 +370,18 @@ impl<'referred_data> Readable<'referred_data> for f32 { }) .is_some() } + + fn rank(prefix: &str, key: KeyWrapper, storage: &'referred_data Storage) -> usize { + storage + .f32_storage + .range( + ..CompositeKey { + prefix: prefix.to_string(), + key, + }, + ) + .count() + } } impl Writeable for u32 { @@ -433,13 +438,6 @@ impl<'referred_data> Readable<'referred_data> for u32 { .collect() } - fn get_at_index( - storage: &'referred_data Storage, - index: usize, - ) -> Option<(&'referred_data CompositeKey, Self)> { - storage.u32_storage.iter().nth(index).map(|(k, v)| (k, *v)) - } - fn count(storage: &Storage) -> Result> { Ok(storage.u32_storage.iter().len()) } @@ -453,6 +451,18 @@ impl<'referred_data> Readable<'referred_data> for u32 { }) .is_some() } + + fn rank(prefix: &str, key: KeyWrapper, storage: &'referred_data Storage) -> usize { + storage + .u32_storage + .range( + ..CompositeKey { + prefix: prefix.to_string(), + key, + }, + ) + .count() + } } impl Writeable for bool { @@ -513,13 +523,6 @@ impl<'referred_data> Readable<'referred_data> for bool { .collect() } - fn get_at_index( - storage: &'referred_data Storage, - index: usize, - ) -> Option<(&'referred_data CompositeKey, Self)> { - storage.bool_storage.iter().nth(index).map(|(k, v)| (k, *v)) - } - fn count(storage: &Storage) -> Result> { Ok(storage.bool_storage.iter().len()) } @@ -533,6 +536,18 @@ impl<'referred_data> Readable<'referred_data> for bool { }) .is_some() } + + fn rank(prefix: &str, key: KeyWrapper, storage: &'referred_data Storage) -> usize { + storage + .bool_storage + .range( + ..CompositeKey { + prefix: prefix.to_string(), + key, + }, + ) + .count() + } } impl Writeable for &DataRecord<'_> { @@ -647,24 +662,6 @@ impl<'referred_data> Readable<'referred_data> for DataRecord<'referred_data> { .collect() } - fn get_at_index( - storage: &'referred_data Storage, - index: usize, - ) -> Option<(&'referred_data CompositeKey, Self)> { - let (k, v) = storage.data_record_id_storage.iter().nth(index).unwrap(); - let embedding = storage.data_record_embedding_storage.get(k).unwrap(); - let id = v; - Some(( - k, - DataRecord { - id, - embedding, - metadata: None, - document: None, - }, - )) - } - fn count(storage: &Storage) -> Result> { Ok(storage.data_record_id_storage.iter().len()) } @@ -678,6 +675,18 @@ impl<'referred_data> Readable<'referred_data> for DataRecord<'referred_data> { }) .is_some() } + + fn rank(prefix: &str, key: KeyWrapper, storage: &'referred_data Storage) -> usize { + storage + 
.data_record_id_storage + .range( + ..CompositeKey { + prefix: prefix.to_string(), + key, + }, + ) + .count() + } } impl<'referred_data> Readable<'referred_data> for SpannPostingList<'referred_data> { @@ -697,18 +706,15 @@ impl<'referred_data> Readable<'referred_data> for SpannPostingList<'referred_dat todo!() } - fn get_at_index( - _: &'referred_data Storage, - _: usize, - ) -> Option<(&'referred_data CompositeKey, Self)> { + fn count(_: &Storage) -> Result> { todo!() } - fn count(_: &Storage) -> Result> { + fn contains(_: &str, _: KeyWrapper, _: &'referred_data Storage) -> bool { todo!() } - fn contains(_: &str, _: KeyWrapper, _: &'referred_data Storage) -> bool { + fn rank(_: &str, _: KeyWrapper, _: &'referred_data Storage) -> usize { todo!() } } diff --git a/rust/blockstore/src/types/reader.rs b/rust/blockstore/src/types/reader.rs index 43311d979d4..238537a8523 100644 --- a/rust/blockstore/src/types/reader.rs +++ b/rust/blockstore/src/types/reader.rs @@ -106,16 +106,6 @@ impl< } } - pub async fn get_at_index( - &'referred_data self, - index: usize, - ) -> Result<(&'referred_data str, K, V), Box> { - match self { - BlockfileReader::MemoryBlockfileReader(reader) => reader.get_at_index(index), - BlockfileReader::ArrowBlockfileReader(reader) => reader.get_at_index(index).await, - } - } - pub fn id(&self) -> uuid::Uuid { match self { BlockfileReader::MemoryBlockfileReader(reader) => reader.id(), @@ -131,4 +121,15 @@ impl< } } } + + pub async fn rank( + &'referred_data self, + prefix: &'referred_data str, + key: K, + ) -> Result> { + match self { + BlockfileReader::MemoryBlockfileReader(reader) => Ok(reader.rank(prefix, key)), + BlockfileReader::ArrowBlockfileReader(reader) => reader.rank(prefix, key).await, + } + } } diff --git a/rust/load/Cargo.toml b/rust/load/Cargo.toml index 48e248cce67..71e6912a49b 100644 --- a/rust/load/Cargo.toml +++ b/rust/load/Cargo.toml @@ -24,7 +24,7 @@ opentelemetry_sdk = { workspace = true } # Unlikely to be used in the workspace. 
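The memory-blockstore rank() implementations above all reduce to the same idea: count how many (prefix, key) entries sort strictly before the probe pair, which is also the position where that pair could be inserted without breaking the order. A minimal standalone sketch of that semantics over a plain BTreeMap (the (String, String) tuple stands in for the crate's CompositeKey; none of this is code from the PR):

use std::collections::BTreeMap;

fn main() {
    // Ordered map standing in for the blockfile's (prefix, key) -> value storage.
    let mut storage: BTreeMap<(String, String), u32> = BTreeMap::new();
    for i in 0..5u32 {
        storage.insert(("key".to_string(), format!("{:04}", i)), i);
    }

    // rank(prefix, key): number of entries strictly less than (prefix, key),
    // i.e. the insertion position that keeps the storage ordered.
    let rank = |prefix: &str, key: &str| -> usize {
        storage
            .range(..(prefix.to_string(), key.to_string()))
            .count()
    };

    assert_eq!(rank("key", "0000"), 0); // nothing sorts before the first pair
    assert_eq!(rank("key", "0003"), 3); // "0000".."0002" are strictly smaller
    assert_eq!(rank("key", "9999"), 5); // past the end: rank == number of entries
    println!("rank sketch ok");
}
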
axum = "0.7" -chromadb = { git = "https://github.com/rescrv/chromadb-rs", rev = "e364e35c34c660d4e8e862436ea600ddc2f46a1e" } +chromadb = { git = "https://github.com/rescrv/chromadb-rs", rev = "3b2a9c96bf99cd0f9bd4e09ea983df335d6bbf68" } guacamole = { version = "0.9", default-features = false } tower-http = { version = "0.6.2", features = ["trace"] } reqwest = { version = "0.12", features = ["json"] } diff --git a/rust/load/src/bin/chroma-load-start.rs b/rust/load/src/bin/chroma-load-start.rs index c6b54208362..2ddfac9f882 100644 --- a/rust/load/src/bin/chroma-load-start.rs +++ b/rust/load/src/bin/chroma-load-start.rs @@ -14,6 +14,8 @@ struct Args { #[arg(long)] expires: String, #[arg(long)] + delay: Option, + #[arg(long)] data_set: String, #[arg(long)] workload: String, @@ -84,11 +86,20 @@ async fn main() { let args = Args::parse(); let client = reqwest::Client::new(); let throughput = args.throughput(); + let mut workload = Workload::ByName(args.workload); + if let Some(delay) = args.delay { + let delay = humanize_expires(&delay).expect("delay must be humanizable"); + let delay = delay.parse().expect("delay must be a date time"); + workload = Workload::Delay { + after: delay, + wrap: Box::new(workload), + }; + } let req = StartRequest { name: args.name, expires: humanize_expires(&args.expires).unwrap_or(args.expires), data_set: args.data_set, - workload: Workload::ByName(args.workload), + workload, throughput, }; match client @@ -116,10 +127,13 @@ async fn main() { ); } else { eprintln!( - "Failed to start workload on {}: {}", + "Categorically failed to start workload on {}: {}", args.host, resp.status() ); + if let Ok(text) = resp.text().await { + eprintln!("{}", text.trim()); + } } } Err(e) => eprintln!("Failed to start workload on {}: {}", args.host, e), diff --git a/rust/load/src/bin/chroma-load-stop-all.rs b/rust/load/src/bin/chroma-load-stop-all.rs new file mode 100644 index 00000000000..68d5e9dc78c --- /dev/null +++ b/rust/load/src/bin/chroma-load-stop-all.rs @@ -0,0 +1,60 @@ +//! Stop all workloads on the chroma-load server. +//! +//! If you are looking to stop traffic for a SEV, see chroma-load-inhibit. + +use clap::Parser; + +use chroma_load::rest::StopRequest; + +#[derive(Parser, Debug)] +struct Args { + #[arg(long)] + host: String, +} + +#[tokio::main] +async fn main() { + let args = Args::parse(); + let client = reqwest::Client::new(); + match client + .get(format!("{}/", args.host)) + .header(reqwest::header::ACCEPT, "application/json") + .send() + .await + { + Ok(resp) => { + let resp = resp.error_for_status().expect("Failed to get status"); + let resp = resp + .json::() + .await + .expect("Failed to parse status"); + for workload in resp.running { + let req = StopRequest { + uuid: workload.uuid, + }; + match client + .post(format!("{}/stop", args.host)) + .json(&req) + .send() + .await + { + Ok(resp) => { + if resp.status().is_success() { + println!("Stopped workload on {}", args.host); + } else { + eprintln!( + "Failed to stop workload on {}: {}", + args.host, + resp.status() + ); + } + } + Err(e) => eprintln!("Failed to stop workload on {}: {}", args.host, e), + } + } + } + Err(e) => { + eprintln!("Failed to get status: {}", e); + } + } +} diff --git a/rust/load/src/bit_difference.rs b/rust/load/src/bit_difference.rs index 4fcbabc3cd0..8607bc9adf1 100644 --- a/rust/load/src/bit_difference.rs +++ b/rust/load/src/bit_difference.rs @@ -25,14 +25,14 @@ //! Internally, guacamole provides primitives that make it easy to manage the set of seeds to get a //! 
variety of data sets out of the synthetic data. -use chromadb::v2::collection::{CollectionEntries, GetOptions, QueryOptions}; -use chromadb::v2::ChromaClient; +use chromadb::collection::{CollectionEntries, GetOptions, QueryOptions}; +use chromadb::ChromaClient; use guacamole::combinators::*; use guacamole::{FromGuacamole, Guacamole, Zipf}; use siphasher::sip::SipHasher24; use tracing::Instrument; -use crate::words::WORDS; +use crate::words::MANY_WORDS; use crate::{DataSet, GetQuery, KeySelector, QueryQuery, Skew, UpsertQuery}; const EMBEDDING_BYTES: usize = 128; @@ -101,7 +101,7 @@ impl Document { pub fn embedding(&self) -> Vec { let mut result = vec![]; let words = self.content.split_whitespace().collect::>(); - for word in WORDS.iter() { + for word in MANY_WORDS.iter() { if words.contains(word) { result.push(1.0); } else { @@ -114,7 +114,7 @@ impl Document { impl From<[u8; EMBEDDING_BYTES]> for Document { fn from(embedding: [u8; EMBEDDING_BYTES]) -> Document { - let document = WORDS + let document = MANY_WORDS .iter() .enumerate() .filter_map(|(idx, word)| { @@ -388,7 +388,7 @@ mod tests { #[test] fn constants() { - assert_eq!(EMBEDDING_SIZE, WORDS.len()); + assert_eq!(EMBEDDING_SIZE, MANY_WORDS.len()); } mod synthethic { diff --git a/rust/load/src/data_sets.rs b/rust/load/src/data_sets.rs index 931b813379e..9fbf1364c95 100644 --- a/rust/load/src/data_sets.rs +++ b/rust/load/src/data_sets.rs @@ -1,7 +1,7 @@ use std::sync::Arc; -use chromadb::v2::collection::{GetOptions, QueryOptions}; -use chromadb::v2::ChromaClient; +use chromadb::collection::{GetOptions, QueryOptions}; +use chromadb::ChromaClient; use guacamole::combinators::*; use guacamole::Guacamole; use tracing::Instrument; diff --git a/rust/load/src/lib.rs b/rust/load/src/lib.rs index 388fa594c15..ee508eb0470 100644 --- a/rust/load/src/lib.rs +++ b/rust/load/src/lib.rs @@ -24,8 +24,8 @@ use axum::http::StatusCode; use axum::response::IntoResponse; use axum::routing::{get, post}; use axum::{Json, Router}; -use chromadb::v2::client::{ChromaAuthMethod, ChromaClientOptions, ChromaTokenHeader}; -use chromadb::v2::ChromaClient; +use chromadb::client::{ChromaAuthMethod, ChromaClientOptions, ChromaTokenHeader}; +use chromadb::ChromaClient; use guacamole::combinators::*; use guacamole::{Guacamole, Zipf}; use opentelemetry::global; @@ -113,29 +113,33 @@ pub struct Metrics { /// Instantiate a new Chroma client. This will use the CHROMA_HOST environment variable (or /// http://localhost:8000 when unset) as the argument to [client_for_url]. -pub fn client() -> ChromaClient { +pub async fn client() -> ChromaClient { let url = std::env::var("CHROMA_HOST").unwrap_or_else(|_| "http://localhost:8000".into()); - client_for_url(url) + client_for_url(url).await } /// Create a new Chroma client for the given URL. This will use the CHROMA_TOKEN environment /// variable if set, or no authentication if unset. 
-pub fn client_for_url(url: String) -> ChromaClient { +pub async fn client_for_url(url: String) -> ChromaClient { if let Ok(auth) = std::env::var("CHROMA_TOKEN") { ChromaClient::new(ChromaClientOptions { - url, + url: Some(url), auth: ChromaAuthMethod::TokenAuth { token: auth, header: ChromaTokenHeader::XChromaToken, }, - database: Some("hf-tiny-stories".to_string()), + database: "hf-tiny-stories".to_string(), }) + .await + .unwrap() } else { ChromaClient::new(ChromaClientOptions { - url, + url: Some(url), auth: ChromaAuthMethod::None, - database: Some("hf-tiny-stories".to_string()), + database: "hf-tiny-stories".to_string(), }) + .await + .unwrap() } } @@ -306,6 +310,9 @@ pub enum WhereMixin { /// A raw metadata query simply copies the provided filter spec. #[serde(rename = "query")] Constant(serde_json::Value), + /// Search for a word from the provided set of words with skew. + #[serde(rename = "fts")] + FullTextSearch(Skew), /// The tiny stories workload. The way these collections were setup, there are three fields /// each of integer, float, and string. The integer fields are named i1, i2, and i3. The /// float fields are named f1, f2, and f3. The string fields are named s1, s2, and s3. @@ -325,6 +332,17 @@ impl WhereMixin { pub fn to_json(&self, guac: &mut Guacamole) -> serde_json::Value { match self { Self::Constant(query) => query.clone(), + Self::FullTextSearch(skew) => { + const WORDS: &[&str] = words::FEW_WORDS; + let word = match skew { + Skew::Uniform => WORDS[uniform(0, WORDS.len() as u64)(guac) as usize], + Skew::Zipf { theta } => { + let z = Zipf::from_alpha(WORDS.len() as u64, *theta); + WORDS[z.next(guac) as usize] + } + }; + serde_json::json!({"$contains": word.to_string()}) + } Self::TinyStories(mixin) => mixin.to_json(guac), Self::Select(select) => { let scale: f64 = any(guac); @@ -478,7 +496,7 @@ impl Workload { if let Some(workload) = workloads.get(name) { *self = workload.clone(); } else { - return Err(Error::InvalidRequest(format!("workload not found: {name}"))); + return Err(Error::NotFound(format!("workload not found: {name}"))); } } Workload::Get(_) => {} @@ -1093,7 +1111,7 @@ impl LoadService { inhibit: Arc, spec: RunningWorkload, ) { - let client = Arc::new(client()); + let client = Arc::new(client().await); let mut guac = Guacamole::new(spec.expires.timestamp_millis() as u64); let mut next_op = Instant::now(); let (tx, mut rx) = tokio::sync::mpsc::channel(1000); @@ -1412,7 +1430,6 @@ pub async fn entrypoint() { "http_request", method = ?request.method(), matched_path, - some_other_field = tracing::field::Empty, ) }), ) diff --git a/rust/load/src/opentelemetry_config.rs b/rust/load/src/opentelemetry_config.rs index 708a19220cf..26fb68613a9 100644 --- a/rust/load/src/opentelemetry_config.rs +++ b/rust/load/src/opentelemetry_config.rs @@ -132,32 +132,7 @@ pub(crate) fn init_otel_tracing(service_name: &String, otel_endpoint: &String) { .with_filter(tracing_subscriber::filter::LevelFilter::INFO); // global filter layer. Don't filter anything at above trace at the global layer for chroma. // And enable errors for every other library. 
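The new WhereMixin::FullTextSearch variant in rust/load/src/lib.rs above turns a skewed pick from a fixed word list into a "$contains" document filter. A rough, self-contained sketch of that shape, using a hand-rolled CDF-inversion sampler instead of guacamole's Zipf and a tiny stand-in word list rather than the much longer FEW_WORDS; only the resulting JSON mirrors the PR:

fn zipf_pick<'a>(words: &[&'a str], theta: f64, u: f64) -> &'a str {
    // Weight rank r (1-based) by 1 / r^theta and invert the CDF at u in [0, 1).
    let weights: Vec<f64> = (1..=words.len())
        .map(|r| 1.0 / (r as f64).powf(theta))
        .collect();
    let total: f64 = weights.iter().sum();
    let mut acc = 0.0;
    for (w, word) in weights.iter().zip(words.iter().copied()) {
        acc += w / total;
        if u < acc {
            return word;
        }
    }
    words[words.len() - 1]
}

fn main() {
    // Tiny stand-in for FEW_WORDS.
    const WORDS: &[&str] = &["the", "of", "to", "and", "a"];
    let word = zipf_pick(WORDS, 0.99, 0.42);
    // Same document-filter shape the workload emits for a full-text search.
    let clause = serde_json::json!({ "$contains": word });
    println!("{clause}"); // {"$contains":"the"} for this u
}
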
- let global_layer = EnvFilter::new(std::env::var("RUST_LOG").unwrap_or_else(|_| { - "info,".to_string() - + &vec![ - "chroma", - "chroma-blockstore", - "chroma-config", - "chroma-cache", - "chroma-distance", - "chroma-error", - "chroma-index", - "chroma-load", - "chroma-storage", - "chroma-test", - "chroma-types", - "compaction_service", - "distance_metrics", - "full_text", - "metadata_filtering", - "query_service", - "worker", - ] - .into_iter() - .map(|s| s.to_string() + "=trace") - .collect::>() - .join(",") - })); + let global_layer = EnvFilter::new(std::env::var("RUST_LOG").unwrap_or("error".to_string())); // Create subscriber. let subscriber = tracing_subscriber::registry() diff --git a/rust/load/src/words.rs b/rust/load/src/words.rs index f77f185c7b1..77f9023c224 100644 --- a/rust/load/src/words.rs +++ b/rust/load/src/words.rs @@ -1,4 +1,4 @@ -pub const WORDS: &[&str] = &[ +pub const MANY_WORDS: &[&str] = &[ "man’s", "sought", "touch", @@ -1024,3 +1024,1006 @@ pub const WORDS: &[&str] = &[ "and", "the", ]; + +pub const FEW_WORDS: &[&str] = &[ + "the", + "of", + "to", + "and", + "a", + "in", + "is", + "it", + "you", + "that", + "he", + "was", + "for", + "on", + "are", + "with", + "as", + "I", + "his", + "they", + "be", + "at", + "one", + "have", + "this", + "from", + "or", + "had", + "by", + "not", + "word", + "but", + "what", + "some", + "we", + "can", + "out", + "other", + "were", + "all", + "there", + "when", + "up", + "use", + "your", + "how", + "said", + "an", + "each", + "she", + "which", + "do", + "their", + "time", + "if", + "will", + "way", + "about", + "many", + "then", + "them", + "write", + "would", + "like", + "so", + "these", + "her", + "long", + "make", + "thing", + "see", + "him", + "two", + "has", + "look", + "more", + "day", + "could", + "go", + "come", + "did", + "number", + "sound", + "no", + "most", + "people", + "my", + "over", + "know", + "water", + "than", + "call", + "first", + "who", + "may", + "down", + "side", + "been", + "now", + "find", + "any", + "new", + "work", + "part", + "take", + "get", + "place", + "made", + "live", + "where", + "after", + "back", + "little", + "only", + "round", + "man", + "year", + "came", + "show", + "every", + "good", + "me", + "give", + "our", + "under", + "name", + "very", + "through", + "just", + "form", + "sentence", + "great", + "think", + "say", + "help", + "low", + "line", + "differ", + "turn", + "cause", + "much", + "mean", + "before", + "move", + "right", + "boy", + "old", + "too", + "same", + "tell", + "does", + "set", + "three", + "want", + "air", + "well", + "also", + "play", + "small", + "end", + "put", + "home", + "read", + "hand", + "port", + "large", + "spell", + "add", + "even", + "land", + "here", + "must", + "big", + "high", + "such", + "follow", + "act", + "why", + "ask", + "men", + "change", + "went", + "light", + "kind", + "off", + "need", + "house", + "picture", + "try", + "us", + "again", + "animal", + "point", + "mother", + "world", + "near", + "build", + "self", + "earth", + "father", + "head", + "stand", + "own", + "page", + "should", + "country", + "found", + "answer", + "school", + "grow", + "study", + "still", + "learn", + "plant", + "cover", + "food", + "sun", + "four", + "between", + "state", + "keep", + "eye", + "never", + "last", + "let", + "thought", + "city", + "tree", + "cross", + "farm", + "hard", + "start", + "might", + "story", + "saw", + "far", + "sea", + "draw", + "left", + "late", + "run", + "don't", + "while", + "press", + "close", + "night", + "real", + "life", + "few", + "north", 
+ "open", + "seem", + "together", + "next", + "white", + "children", + "begin", + "got", + "walk", + "example", + "ease", + "paper", + "group", + "always", + "music", + "those", + "both", + "mark", + "often", + "letter", + "until", + "mile", + "river", + "car", + "feet", + "care", + "second", + "book", + "carry", + "took", + "science", + "eat", + "room", + "friend", + "began", + "idea", + "fish", + "mountain", + "stop", + "once", + "base", + "hear", + "horse", + "cut", + "sure", + "watch", + "color", + "face", + "wood", + "main", + "enough", + "plain", + "girl", + "usual", + "young", + "ready", + "above", + "ever", + "red", + "list", + "though", + "feel", + "talk", + "bird", + "soon", + "body", + "dog", + "family", + "direct", + "pose", + "leave", + "song", + "measure", + "door", + "product", + "black", + "short", + "numeral", + "class", + "wind", + "question", + "happen", + "complete", + "ship", + "area", + "half", + "rock", + "order", + "fire", + "south", + "problem", + "piece", + "told", + "knew", + "pass", + "since", + "top", + "whole", + "king", + "space", + "heard", + "best", + "hour", + "better", + "true", + "during", + "hundred", + "five", + "remember", + "step", + "early", + "hold", + "west", + "ground", + "interest", + "reach", + "fast", + "verb", + "sing", + "listen", + "six", + "table", + "travel", + "less", + "morning", + "ten", + "simple", + "several", + "vowel", + "toward", + "war", + "lay", + "against", + "pattern", + "slow", + "center", + "love", + "person", + "money", + "serve", + "appear", + "road", + "map", + "rain", + "rule", + "govern", + "pull", + "cold", + "notice", + "voice", + "unit", + "power", + "town", + "fine", + "certain", + "fly", + "fall", + "lead", + "cry", + "dark", + "machine", + "note", + "wait", + "plan", + "figure", + "star", + "box", + "noun", + "field", + "rest", + "correct", + "able", + "pound", + "done", + "beauty", + "drive", + "stood", + "contain", + "front", + "teach", + "week", + "final", + "gave", + "green", + "oh", + "quick", + "develop", + "ocean", + "warm", + "free", + "minute", + "strong", + "special", + "mind", + "behind", + "clear", + "tail", + "produce", + "fact", + "street", + "inch", + "multiply", + "nothing", + "course", + "stay", + "wheel", + "full", + "force", + "blue", + "object", + "decide", + "surface", + "deep", + "moon", + "island", + "foot", + "system", + "busy", + "test", + "record", + "boat", + "common", + "gold", + "possible", + "plane", + "stead", + "dry", + "wonder", + "laugh", + "thousand", + "ago", + "ran", + "check", + "game", + "shape", + "equate", + "hot", + "miss", + "brought", + "heat", + "snow", + "tire", + "bring", + "yes", + "distant", + "fill", + "east", + "paint", + "language", + "among", + "grand", + "ball", + "yet", + "wave", + "drop", + "heart", + "am", + "present", + "heavy", + "dance", + "engine", + "position", + "arm", + "wide", + "sail", + "material", + "size", + "vary", + "settle", + "speak", + "weight", + "general", + "ice", + "matter", + "circle", + "pair", + "include", + "divide", + "syllable", + "felt", + "perhaps", + "pick", + "sudden", + "count", + "square", + "reason", + "length", + "represent", + "art", + "subject", + "region", + "energy", + "hunt", + "probable", + "bed", + "brother", + "egg", + "ride", + "cell", + "believe", + "fraction", + "forest", + "sit", + "race", + "window", + "store", + "summer", + "train", + "sleep", + "prove", + "lone", + "leg", + "exercise", + "wall", + "catch", + "mount", + "wish", + "sky", + "board", + "joy", + "winter", + "sat", + "written", + "wild", + 
"instrument", + "kept", + "glass", + "grass", + "cow", + "job", + "edge", + "sign", + "visit", + "past", + "soft", + "fun", + "bright", + "gas", + "weather", + "month", + "million", + "bear", + "finish", + "happy", + "hope", + "flower", + "clothe", + "strange", + "gone", + "jump", + "baby", + "eight", + "village", + "meet", + "root", + "buy", + "raise", + "solve", + "metal", + "whether", + "push", + "seven", + "paragraph", + "third", + "shall", + "held", + "hair", + "describe", + "cook", + "floor", + "either", + "result", + "burn", + "hill", + "safe", + "cat", + "century", + "consider", + "type", + "law", + "bit", + "coast", + "copy", + "phrase", + "silent", + "tall", + "sand", + "soil", + "roll", + "temperature", + "finger", + "industry", + "value", + "fight", + "lie", + "beat", + "excite", + "natural", + "view", + "sense", + "ear", + "else", + "quite", + "broke", + "case", + "middle", + "kill", + "son", + "lake", + "moment", + "scale", + "loud", + "spring", + "observe", + "child", + "straight", + "consonant", + "nation", + "dictionary", + "milk", + "speed", + "method", + "organ", + "pay", + "age", + "section", + "dress", + "cloud", + "surprise", + "quiet", + "stone", + "tiny", + "climb", + "cool", + "design", + "poor", + "lot", + "experiment", + "bottom", + "key", + "iron", + "single", + "stick", + "flat", + "twenty", + "skin", + "smile", + "crease", + "hole", + "trade", + "melody", + "trip", + "office", + "receive", + "row", + "mouth", + "exact", + "symbol", + "die", + "least", + "trouble", + "shout", + "except", + "wrote", + "seed", + "tone", + "join", + "suggest", + "clean", + "break", + "lady", + "yard", + "rise", + "bad", + "blow", + "oil", + "blood", + "touch", + "grew", + "cent", + "mix", + "team", + "wire", + "cost", + "lost", + "brown", + "wear", + "garden", + "equal", + "sent", + "choose", + "fell", + "fit", + "flow", + "fair", + "bank", + "collect", + "save", + "control", + "decimal", + "gentle", + "woman", + "captain", + "practice", + "separate", + "difficult", + "doctor", + "please", + "protect", + "noon", + "whose", + "locate", + "ring", + "character", + "insect", + "caught", + "period", + "indicate", + "radio", + "spoke", + "atom", + "human", + "history", + "effect", + "electric", + "expect", + "crop", + "modern", + "element", + "hit", + "student", + "corner", + "party", + "supply", + "bone", + "rail", + "imagine", + "provide", + "agree", + "thus", + "capital", + "won't", + "chair", + "danger", + "fruit", + "rich", + "thick", + "soldier", + "process", + "operate", + "guess", + "necessary", + "sharp", + "wing", + "create", + "neighbor", + "wash", + "bat", + "rather", + "crowd", + "corn", + "compare", + "poem", + "string", + "bell", + "depend", + "meat", + "rub", + "tube", + "famous", + "dollar", + "stream", + "fear", + "sight", + "thin", + "triangle", + "planet", + "hurry", + "chief", + "colony", + "clock", + "mine", + "tie", + "enter", + "major", + "fresh", + "search", + "send", + "yellow", + "gun", + "allow", + "print", + "dead", + "spot", + "desert", + "suit", + "current", + "lift", + "rose", + "continue", + "block", + "chart", + "hat", + "sell", + "success", + "company", + "subtract", + "event", + "particular", + "deal", + "swim", + "term", + "opposite", + "wife", + "shoe", + "shoulder", + "spread", + "arrange", + "camp", + "invent", + "cotton", + "born", + "determine", + "quart", + "nine", + "truck", + "noise", + "level", + "chance", + "gather", + "shop", + "stretch", + "throw", + "shine", + "property", + "column", + "molecule", + "select", + "wrong", + "gray", + 
"repeat", + "require", + "broad", + "prepare", + "salt", + "nose", + "plural", + "anger", + "claim", + "continent", + "oxygen", + "sugar", + "death", + "pretty", + "skill", + "women", + "season", + "solution", + "magnet", + "silver", + "thank", + "branch", + "match", + "suffix", + "especially", + "fig", + "afraid", + "huge", + "sister", + "steel", + "discuss", + "forward", + "similar", + "guide", + "experience", + "score", + "apple", + "bought", + "led", + "pitch", + "coat", + "mass", + "card", + "band", + "rope", + "slip", + "win", + "dream", + "evening", + "condition", + "feed", + "tool", + "total", + "basic", + "smell", + "valley", + "nor", + "double", + "seat", + "arrive", + "master", + "track", + "parent", + "shore", + "division", + "sheet", + "substance", + "favor", + "connect", + "post", + "spend", + "chord", + "fat", + "glad", + "original", + "share", + "station", + "dad", + "bread", + "charge", + "proper", + "bar", + "offer", + "segment", + "slave", + "duck", + "instant", + "market", + "degree", + "populate", + "chick", + "dear", + "enemy", + "reply", + "drink", + "occur", + "support", + "speech", + "nature", + "range", + "steam", + "motion", + "path", + "liquid", + "log", + "meant", + "quotient", + "teeth", + "shell", + "neck", +]; diff --git a/rust/load/src/workloads.rs b/rust/load/src/workloads.rs index f51c199bcc7..06aca191c8f 100644 --- a/rust/load/src/workloads.rs +++ b/rust/load/src/workloads.rs @@ -22,9 +22,7 @@ pub fn all_workloads() -> HashMap { skew: Skew::Zipf { theta: 0.999 }, limit: Distribution::Constant(10), metadata: None, - document: Some(WhereMixin::Constant( - serde_json::json!({"$contains": "the"}), - )), + document: Some(WhereMixin::FullTextSearch(Skew::Zipf { theta: 0.99 })), }), ), ( @@ -56,9 +54,7 @@ pub fn all_workloads() -> HashMap { skew: Skew::Zipf { theta: 0.999 }, limit: Distribution::Constant(10), metadata: None, - document: Some(WhereMixin::Constant( - serde_json::json!({"$contains": "the"}), - )), + document: Some(WhereMixin::FullTextSearch(Skew::Zipf { theta: 0.99 })), }), ), ( @@ -83,9 +79,7 @@ pub fn all_workloads() -> HashMap { skew: Skew::Zipf { theta: 0.999 }, limit: Distribution::Constant(10), metadata: None, - document: Some(WhereMixin::Constant( - serde_json::json!({"$contains": "the"}), - )), + document: Some(WhereMixin::FullTextSearch(Skew::Zipf { theta: 0.99 })), }), ), ( diff --git a/rust/types/src/collection.rs b/rust/types/src/collection.rs index 23e55a0a588..9c7bd8914f2 100644 --- a/rust/types/src/collection.rs +++ b/rust/types/src/collection.rs @@ -1,5 +1,5 @@ use super::{Metadata, MetadataValueConversionError}; -use crate::chroma_proto; +use crate::{chroma_proto, ConversionError, Segment}; use chroma_error::{ChromaError, ErrorCodes}; use thiserror::Error; use uuid::Uuid; @@ -89,6 +89,43 @@ impl TryFrom for Collection { } } +#[derive(Clone, Debug)] +pub struct CollectionAndSegments { + pub collection: Collection, + pub metadata_segment: Segment, + pub record_segment: Segment, + pub vector_segment: Segment, +} + +impl TryFrom for CollectionAndSegments { + type Error = ConversionError; + + fn try_from(value: chroma_proto::ScanOperator) -> Result { + Ok(Self { + collection: value + .collection + .ok_or(ConversionError::DecodeError)? + .try_into() + .map_err(|_| ConversionError::DecodeError)?, + metadata_segment: value + .metadata + .ok_or(ConversionError::DecodeError)? + .try_into() + .map_err(|_| ConversionError::DecodeError)?, + record_segment: value + .record + .ok_or(ConversionError::DecodeError)? 
+ .try_into() + .map_err(|_| ConversionError::DecodeError)?, + vector_segment: value + .knn + .ok_or(ConversionError::DecodeError)? + .try_into() + .map_err(|_| ConversionError::DecodeError)?, + }) + } +} + #[cfg(test)] mod test { use super::*; diff --git a/rust/types/src/segment.rs b/rust/types/src/segment.rs index 29fba56af84..8858365554a 100644 --- a/rust/types/src/segment.rs +++ b/rust/types/src/segment.rs @@ -6,6 +6,7 @@ use crate::chroma_proto; use chroma_error::{ChromaError, ErrorCodes}; use std::{collections::HashMap, str::FromStr}; use thiserror::Error; +use tonic::Status; use uuid::Uuid; /// SegmentUuid is a wrapper around Uuid to provide a type for the segment id. @@ -106,6 +107,12 @@ impl ChromaError for SegmentConversionError { } } +impl From for Status { + fn from(value: SegmentConversionError) -> Self { + Status::invalid_argument(value.to_string()) + } +} + impl TryFrom for Segment { type Error = SegmentConversionError; diff --git a/rust/types/src/types.rs b/rust/types/src/types.rs index ddd07d91b89..cc81189f5b6 100644 --- a/rust/types/src/types.rs +++ b/rust/types/src/types.rs @@ -28,7 +28,7 @@ pub enum ConversionError { impl ChromaError for ConversionError { fn code(&self) -> ErrorCodes { match self { - ConversionError::DecodeError => ErrorCodes::Internal, + ConversionError::DecodeError => ErrorCodes::InvalidArgument, } } } diff --git a/rust/worker/Cargo.toml b/rust/worker/Cargo.toml index d0541f65869..800093d306b 100644 --- a/rust/worker/Cargo.toml +++ b/rust/worker/Cargo.toml @@ -61,13 +61,14 @@ chroma-distance = { workspace = true } random-port = "0.1.1" serial_test = "3.2.0" +criterion = { workspace = true } +indicatif = { workspace = true } +proptest = { workspace = true } +proptest-state-machine = { workspace = true } +shuttle = { workspace = true } rand = { workspace = true } rand_xorshift = { workspace = true } tempfile = { workspace = true } -shuttle = { workspace = true } -proptest = { workspace = true } -proptest-state-machine = { workspace = true } -criterion = { workspace = true } chroma-benchmark = { workspace = true } @@ -75,10 +76,18 @@ chroma-benchmark = { workspace = true } name = "filter" harness = false +[[bench]] +name = "get" +harness = false + [[bench]] name = "limit" harness = false +[[bench]] +name = "query" +harness = false + [[bench]] name = "spann" harness = false diff --git a/rust/worker/benches/filter.rs b/rust/worker/benches/filter.rs index 4accb804818..57f54644b26 100644 --- a/rust/worker/benches/filter.rs +++ b/rust/worker/benches/filter.rs @@ -9,7 +9,7 @@ use criterion::Criterion; use criterion::{criterion_group, criterion_main}; use worker::execution::operator::Operator; use worker::execution::operators::filter::{FilterInput, FilterOperator}; -use worker::log::test::{upsert_generator, LogGenerator}; +use worker::log::test::upsert_generator; use worker::segment::test::TestSegment; fn baseline_where_clauses() -> Vec<(&'static str, Option)> { @@ -71,14 +71,13 @@ fn baseline_where_clauses() -> Vec<(&'static str, Option)> { fn bench_filter(criterion: &mut Criterion) { let runtime = tokio_multi_thread(); - let logen = LogGenerator { - generator: upsert_generator, - }; for record_count in [1000, 10000, 100000] { let test_segment = runtime.block_on(async { let mut segment = TestSegment::default(); - segment.populate_with_generator(record_count, &logen).await; + segment + .populate_with_generator(record_count, upsert_generator) + .await; segment }); diff --git a/rust/worker/benches/get.rs b/rust/worker/benches/get.rs new file mode 100644 index 
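The From<SegmentConversionError> for Status impl and the switch of ConversionError::DecodeError to ErrorCodes::InvalidArgument above let proto-decoding failures surface to gRPC callers as invalid-argument instead of internal errors. A self-contained sketch of the same pattern with stand-in types (MySegmentError, parse_segment, and handler are illustrations, not code from this PR):

use tonic::{Code, Status};

#[derive(Debug)]
struct MySegmentError(String);

impl From<MySegmentError> for Status {
    fn from(value: MySegmentError) -> Self {
        Status::invalid_argument(value.0)
    }
}

// Stand-in for a proto -> domain conversion that can fail.
fn parse_segment(raw: &str) -> Result<u32, MySegmentError> {
    raw.parse()
        .map_err(|_| MySegmentError(format!("invalid segment id: {raw}")))
}

// A handler returning tonic::Status can bubble the conversion error with `?`
// thanks to the From impl above.
fn handler(raw: &str) -> Result<u32, Status> {
    Ok(parse_segment(raw)?)
}

fn main() {
    assert!(handler("42").is_ok());
    assert_eq!(handler("nope").unwrap_err().code(), Code::InvalidArgument);
}
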
00000000000..263f4d07f80 --- /dev/null +++ b/rust/worker/benches/get.rs @@ -0,0 +1,214 @@ +#[allow(dead_code)] +mod load; + +use chroma_benchmark::benchmark::{bench_run, tokio_multi_thread}; +use chroma_config::Configurable; +use criterion::{criterion_group, criterion_main, Criterion}; +use load::{ + all_projection, always_false_filter_for_modulo_metadata, + always_true_filter_for_modulo_metadata, empty_fetch_log, offset_limit, sift1m_segments, + trivial_filter, trivial_limit, trivial_projection, +}; +use worker::{ + config::RootConfig, + execution::{ + dispatcher::Dispatcher, + orchestration::{get::GetOrchestrator, orchestrator::Orchestrator}, + }, + segment::test::TestSegment, + system::{ComponentHandle, System}, +}; + +fn trivial_get( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> GetOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + GetOrchestrator::new( + blockfile_provider, + dispatcher_handle, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + trivial_filter(), + trivial_limit(), + trivial_projection(), + ) +} + +fn get_false_filter( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> GetOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + GetOrchestrator::new( + blockfile_provider, + dispatcher_handle, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + always_false_filter_for_modulo_metadata(), + trivial_limit(), + trivial_projection(), + ) +} + +fn get_true_filter( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> GetOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + GetOrchestrator::new( + blockfile_provider, + dispatcher_handle, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + always_true_filter_for_modulo_metadata(), + trivial_limit(), + trivial_projection(), + ) +} + +fn get_true_filter_limit( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> GetOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + GetOrchestrator::new( + blockfile_provider, + dispatcher_handle, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + always_true_filter_for_modulo_metadata(), + offset_limit(), + trivial_projection(), + ) +} + +fn get_true_filter_limit_projection( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> GetOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + GetOrchestrator::new( + blockfile_provider, + dispatcher_handle, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + always_true_filter_for_modulo_metadata(), + offset_limit(), + all_projection(), + ) +} + +async fn bench_routine(input: (System, GetOrchestrator, Vec)) { + let (system, orchestrator, expected_ids) = input; + let output = orchestrator + .run(system) + .await + .expect("Orchestrator should not fail"); + assert_eq!( + output + .records + .into_iter() + .map(|record| record.id) + .collect::>(), + expected_ids + ); +} + +fn bench_get(criterion: &mut Criterion) { + let runtime = tokio_multi_thread(); + let test_segments = 
runtime.block_on(sift1m_segments()); + + let config = RootConfig::default(); + let system = System::default(); + let dispatcher = runtime + .block_on(Dispatcher::try_from_config( + &config.query_service.dispatcher, + )) + .expect("Should be able to initialize dispatcher"); + let dispatcher_handle = runtime.block_on(async { system.start_component(dispatcher) }); + + let trivial_get_setup = || { + ( + system.clone(), + trivial_get(test_segments.clone(), dispatcher_handle.clone()), + (0..100).map(|id| id.to_string()).collect(), + ) + }; + let get_false_filter_setup = || { + ( + system.clone(), + get_false_filter(test_segments.clone(), dispatcher_handle.clone()), + Vec::new(), + ) + }; + let get_true_filter_setup = || { + ( + system.clone(), + get_true_filter(test_segments.clone(), dispatcher_handle.clone()), + (0..100).map(|id| id.to_string()).collect(), + ) + }; + let get_true_filter_limit_setup = || { + ( + system.clone(), + get_true_filter_limit(test_segments.clone(), dispatcher_handle.clone()), + (100..200).map(|id| id.to_string()).collect(), + ) + }; + let get_true_filter_limit_projection_setup = || { + ( + system.clone(), + get_true_filter_limit_projection(test_segments.clone(), dispatcher_handle.clone()), + (100..200).map(|id| id.to_string()).collect(), + ) + }; + + bench_run( + "test-trivial-get", + criterion, + &runtime, + trivial_get_setup, + bench_routine, + ); + bench_run( + "test-get-false-filter", + criterion, + &runtime, + get_false_filter_setup, + bench_routine, + ); + bench_run( + "test-get-true-filter", + criterion, + &runtime, + get_true_filter_setup, + bench_routine, + ); + bench_run( + "test-get-true-filter-limit", + criterion, + &runtime, + get_true_filter_limit_setup, + bench_routine, + ); + bench_run( + "test-get-true-filter-limit-projection", + criterion, + &runtime, + get_true_filter_limit_projection_setup, + bench_routine, + ); +} +criterion_group!(benches, bench_get); +criterion_main!(benches); diff --git a/rust/worker/benches/limit.rs b/rust/worker/benches/limit.rs index a19e481712c..7a13c174acc 100644 --- a/rust/worker/benches/limit.rs +++ b/rust/worker/benches/limit.rs @@ -4,21 +4,20 @@ use criterion::Criterion; use criterion::{criterion_group, criterion_main}; use worker::execution::operator::Operator; use worker::execution::operators::limit::{LimitInput, LimitOperator}; -use worker::log::test::{upsert_generator, LogGenerator}; +use worker::log::test::upsert_generator; use worker::segment::test::TestSegment; const FETCH: usize = 100; fn bench_limit(criterion: &mut Criterion) { let runtime = tokio_multi_thread(); - let logen = LogGenerator { - generator: upsert_generator, - }; for record_count in [1000, 10000, 100000] { let test_segment = runtime.block_on(async { let mut segment = TestSegment::default(); - segment.populate_with_generator(record_count, &logen).await; + segment + .populate_with_generator(record_count, upsert_generator) + .await; segment }); diff --git a/rust/worker/benches/load.rs b/rust/worker/benches/load.rs new file mode 100644 index 00000000000..aef7f8e0551 --- /dev/null +++ b/rust/worker/benches/load.rs @@ -0,0 +1,148 @@ +use chroma_benchmark::datasets::sift::Sift1MData; +use chroma_types::{ + Chunk, CollectionUuid, DirectWhereComparison, LogRecord, MetadataSetValue, Operation, + OperationRecord, SetOperator, Where, WhereComparison, +}; +use indicatif::ProgressIterator; +use worker::{ + execution::operators::{ + fetch_log::FetchLogOperator, filter::FilterOperator, limit::LimitOperator, + projection::ProjectionOperator, + }, + log::{ + 
log::{InMemoryLog, Log}, + test::modulo_metadata, + }, + segment::test::TestSegment, +}; + +const DATA_CHUNK_SIZE: usize = 10000; + +pub async fn sift1m_segments() -> TestSegment { + let mut segments = TestSegment::default(); + let mut sift1m = Sift1MData::init() + .await + .expect("Should be able to download Sift1M data"); + + for chunk_start in (0..Sift1MData::collection_size()) + .step_by(DATA_CHUNK_SIZE) + .progress() + .with_message("Loading Sift1M Data") + { + let embedding_chunk = sift1m + .data_range(chunk_start..(chunk_start + DATA_CHUNK_SIZE)) + .await + .expect("Should be able to decode data chunk"); + + let log_records = embedding_chunk + .into_iter() + .enumerate() + .map(|(index, embedding)| LogRecord { + log_offset: (chunk_start + index) as i64, + record: OperationRecord { + id: (chunk_start + index).to_string(), + embedding: Some(embedding), + encoding: None, + metadata: Some(modulo_metadata(chunk_start + index)), + document: None, + operation: Operation::Add, + }, + }) + .collect::>(); + segments + .compact_log(Chunk::new(log_records.into()), chunk_start) + .await; + } + segments +} + +pub fn empty_fetch_log(collection_uuid: CollectionUuid) -> FetchLogOperator { + FetchLogOperator { + log_client: Log::InMemory(InMemoryLog::default()).into(), + batch_size: 100, + start_log_offset_id: 0, + maximum_fetch_count: Some(0), + collection_uuid, + } +} + +pub fn trivial_filter() -> FilterOperator { + FilterOperator { + query_ids: None, + where_clause: None, + } +} + +pub fn always_false_filter_for_modulo_metadata() -> FilterOperator { + FilterOperator { + query_ids: None, + where_clause: Some(Where::disjunction(vec![ + Where::DirectWhereComparison(DirectWhereComparison { + key: "is_even".to_string(), + comparison: WhereComparison::Set( + SetOperator::NotIn, + MetadataSetValue::Bool(vec![false, true]), + ), + }), + Where::DirectWhereComparison(DirectWhereComparison { + key: "modulo_3".to_string(), + comparison: WhereComparison::Set( + SetOperator::NotIn, + MetadataSetValue::Int(vec![0, 1, 2]), + ), + }), + ])), + } +} + +pub fn always_true_filter_for_modulo_metadata() -> FilterOperator { + FilterOperator { + query_ids: None, + where_clause: Some(Where::conjunction(vec![ + Where::DirectWhereComparison(DirectWhereComparison { + key: "is_even".to_string(), + comparison: WhereComparison::Set( + SetOperator::In, + MetadataSetValue::Bool(vec![false, true]), + ), + }), + Where::DirectWhereComparison(DirectWhereComparison { + key: "modulo_3".to_string(), + comparison: WhereComparison::Set( + SetOperator::In, + MetadataSetValue::Int(vec![0, 1, 2]), + ), + }), + ])), + } +} + +pub fn trivial_limit() -> LimitOperator { + LimitOperator { + skip: 0, + fetch: Some(100), + } +} + +pub fn offset_limit() -> LimitOperator { + LimitOperator { + skip: 100, + fetch: Some(100), + } +} + +pub fn trivial_projection() -> ProjectionOperator { + ProjectionOperator { + document: false, + embedding: false, + metadata: false, + } +} + +pub fn all_projection() -> ProjectionOperator { + ProjectionOperator { + document: true, + embedding: true, + metadata: true, + } +} diff --git a/rust/worker/benches/query.rs b/rust/worker/benches/query.rs new file mode 100644 index 00000000000..5fac8d05829 --- /dev/null +++ b/rust/worker/benches/query.rs @@ -0,0 +1,248 @@ +#[allow(dead_code)] +mod load; + +use chroma_benchmark::{ + benchmark::{bench_run, tokio_multi_thread}, + datasets::sift::Sift1MData, +}; +use chroma_config::Configurable; +use criterion::{criterion_group, criterion_main, Criterion}; +use futures::{stream, 
StreamExt, TryStreamExt}; +use load::{ + all_projection, always_false_filter_for_modulo_metadata, + always_true_filter_for_modulo_metadata, empty_fetch_log, sift1m_segments, trivial_filter, +}; +use rand::{seq::SliceRandom, thread_rng}; +use worker::{ + config::RootConfig, + execution::{ + dispatcher::Dispatcher, + operators::{knn::KnnOperator, knn_projection::KnnProjectionOperator}, + orchestration::{ + knn::KnnOrchestrator, + knn_filter::{KnnFilterOrchestrator, KnnFilterOutput}, + orchestrator::Orchestrator, + }, + }, + segment::test::TestSegment, + system::{ComponentHandle, System}, +}; + +fn trivial_knn_filter( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> KnnFilterOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let hnsw_provider = test_segments.hnsw_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + KnnFilterOrchestrator::new( + blockfile_provider, + dispatcher_handle, + hnsw_provider, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + trivial_filter(), + ) +} + +fn always_true_knn_filter( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> KnnFilterOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let hnsw_provider = test_segments.hnsw_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + KnnFilterOrchestrator::new( + blockfile_provider, + dispatcher_handle, + hnsw_provider, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + always_true_filter_for_modulo_metadata(), + ) +} + +fn always_false_knn_filter( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, +) -> KnnFilterOrchestrator { + let blockfile_provider = test_segments.blockfile_provider.clone(); + let hnsw_provider = test_segments.hnsw_provider.clone(); + let collection_uuid = test_segments.collection.collection_id; + KnnFilterOrchestrator::new( + blockfile_provider, + dispatcher_handle, + hnsw_provider, + 1000, + test_segments.into(), + empty_fetch_log(collection_uuid), + always_false_filter_for_modulo_metadata(), + ) +} + +fn knn( + test_segments: TestSegment, + dispatcher_handle: ComponentHandle, + knn_filter_output: KnnFilterOutput, + query: Vec, +) -> KnnOrchestrator { + KnnOrchestrator::new( + test_segments.blockfile_provider.clone(), + dispatcher_handle.clone(), + 1000, + knn_filter_output.clone(), + KnnOperator { + embedding: query, + fetch: Sift1MData::k() as u32, + }, + KnnProjectionOperator { + projection: all_projection(), + distance: true, + }, + ) +} + +async fn bench_routine( + input: ( + System, + KnnFilterOrchestrator, + impl Fn(KnnFilterOutput) -> Vec<(KnnOrchestrator, Vec)>, + ), +) { + let (system, knn_filter, knn_constructor) = input; + let knn_filter_output = knn_filter + .run(system.clone()) + .await + .expect("Orchestrator should not fail"); + let (knns, _expected): (Vec<_>, Vec<_>) = + knn_constructor(knn_filter_output).into_iter().unzip(); + let _results = stream::iter(knns.into_iter().map(|knn| knn.run(system.clone()))) + .buffered(32) + .try_collect::>() + .await + .expect("Orchestrators should not fail"); + // TODO: verify recall +} + +fn bench_query(criterion: &mut Criterion) { + let runtime = tokio_multi_thread(); + let test_segments = runtime.block_on(sift1m_segments()); + + let config = RootConfig::default(); + let system = System::default(); + let dispatcher = runtime + .block_on(Dispatcher::try_from_config( + &config.query_service.dispatcher, + )) + 
.expect("Should be able to initialize dispatcher"); + let dispatcher_handle = runtime.block_on(async { system.start_component(dispatcher) }); + + let mut sift1m = runtime + .block_on(Sift1MData::init()) + .expect("Should be able to download Sift1M data"); + let mut sift1m_queries = runtime + .block_on(sift1m.query()) + .expect("Should be able to load Sift1M queries"); + + sift1m_queries.as_mut_slice().shuffle(&mut thread_rng()); + + let trivial_knn_setup = || { + ( + system.clone(), + trivial_knn_filter(test_segments.clone(), dispatcher_handle.clone().clone()), + |knn_filter_output: KnnFilterOutput| { + sift1m_queries + .iter() + .take(4) + .map(|(query, expected)| { + ( + knn( + test_segments.clone(), + dispatcher_handle.clone(), + knn_filter_output.clone(), + query.clone(), + ), + expected.clone(), + ) + }) + .collect() + }, + ) + }; + + let true_filter_knn_setup = || { + ( + system.clone(), + always_true_knn_filter(test_segments.clone(), dispatcher_handle.clone().clone()), + |knn_filter_output: KnnFilterOutput| { + sift1m_queries + .iter() + .take(4) + .map(|(query, expected)| { + ( + knn( + test_segments.clone(), + dispatcher_handle.clone(), + knn_filter_output.clone(), + query.clone(), + ), + expected.clone(), + ) + }) + .collect() + }, + ) + }; + + let false_filter_knn_setup = || { + ( + system.clone(), + always_false_knn_filter(test_segments.clone(), dispatcher_handle.clone().clone()), + |knn_filter_output: KnnFilterOutput| { + sift1m_queries + .iter() + .take(4) + .map(|(query, _)| { + ( + knn( + test_segments.clone(), + dispatcher_handle.clone(), + knn_filter_output.clone(), + query.clone(), + ), + Vec::new(), + ) + }) + .collect() + }, + ) + }; + + bench_run( + "test-trivial-knn", + criterion, + &runtime, + trivial_knn_setup, + bench_routine, + ); + + bench_run( + "test-true-filter-knn", + criterion, + &runtime, + true_filter_knn_setup, + bench_routine, + ); + + bench_run( + "test-false-filter-knn", + criterion, + &runtime, + false_filter_knn_setup, + bench_routine, + ); +} +criterion_group!(benches, bench_query); +criterion_main!(benches); diff --git a/rust/worker/src/compactor/compaction_manager.rs b/rust/worker/src/compactor/compaction_manager.rs index 380cb5a978b..d69d29ec68d 100644 --- a/rust/worker/src/compactor/compaction_manager.rs +++ b/rust/worker/src/compactor/compaction_manager.rs @@ -4,6 +4,7 @@ use crate::compactor::types::CompactionJob; use crate::compactor::types::ScheduleMessage; use crate::config::CompactionServiceConfig; use crate::execution::dispatcher::Dispatcher; +use crate::execution::orchestration::orchestrator::Orchestrator; use crate::execution::orchestration::CompactOrchestrator; use crate::execution::orchestration::CompactionResponse; use crate::log::log::Log; @@ -115,7 +116,6 @@ impl CompactionManager { Some(ref system) => { let orchestrator = CompactOrchestrator::new( compaction_job.clone(), - system.clone(), compaction_job.collection_id, self.log.clone(), self.sysdb.clone(), @@ -129,14 +129,14 @@ impl CompactionManager { self.max_partition_size, ); - match orchestrator.run().await { + match orchestrator.run(system.clone()).await { Ok(result) => { tracing::info!("Compaction Job completed: {:?}", result); return Ok(result); } Err(e) => { tracing::error!("Compaction Job failed: {:?}", e); - return Err(e); + return Err(Box::new(e)); } } } @@ -280,7 +280,7 @@ impl Component for CompactionManager { self.compaction_manager_queue_size } - async fn on_start(&mut self, ctx: &crate::system::ComponentContext) -> () { + async fn start(&mut self, ctx: 
&crate::system::ComponentContext) -> () { println!("Starting CompactionManager"); ctx.scheduler .schedule(ScheduleMessage {}, self.compaction_interval, ctx, || { diff --git a/rust/worker/src/config.rs b/rust/worker/src/config.rs index 79107db3b1d..63b9b524e8b 100644 --- a/rust/worker/src/config.rs +++ b/rust/worker/src/config.rs @@ -12,7 +12,7 @@ const DEFAULT_CONFIG_PATH: &str = "./chroma_config.yaml"; /// variables take precedence over values in the YAML file. /// By default, it is read from the current working directory, /// with the filename chroma_config.yaml. -pub(crate) struct RootConfig { +pub struct RootConfig { // The root config object wraps the worker config object so that // we can share the same config file between multiple services. pub query_service: QueryServiceConfig, @@ -78,6 +78,12 @@ impl RootConfig { } } +impl Default for RootConfig { + fn default() -> Self { + Self::load() + } +} + #[derive(Deserialize)] /// # Description /// The primary config for the worker service. @@ -89,7 +95,7 @@ impl RootConfig { /// For example, to set my_ip, you would set CHROMA_WORKER__MY_IP. /// Each submodule that needs to be configured from the config object should implement the Configurable trait and /// have its own field in this struct for its Config struct. -pub(crate) struct QueryServiceConfig { +pub struct QueryServiceConfig { pub(crate) service_name: String, pub(crate) otel_endpoint: String, #[allow(dead_code)] @@ -102,7 +108,7 @@ pub(crate) struct QueryServiceConfig { pub(crate) sysdb: crate::sysdb::config::SysDbConfig, pub(crate) storage: chroma_storage::config::StorageConfig, pub(crate) log: crate::log::config::LogConfig, - pub(crate) dispatcher: crate::execution::config::DispatcherConfig, + pub dispatcher: crate::execution::config::DispatcherConfig, pub(crate) blockfile_provider: chroma_blockstore::config::BlockfileProviderConfig, pub(crate) hnsw_provider: chroma_index::config::HnswProviderConfig, } @@ -118,7 +124,7 @@ pub(crate) struct QueryServiceConfig { /// For example, to set my_ip, you would set CHROMA_COMPACTOR__MY_IP. /// Each submodule that needs to be configured from the config object should implement the Configurable trait and /// have its own field in this struct for its Config struct. -pub(crate) struct CompactionServiceConfig { +pub struct CompactionServiceConfig { pub(crate) service_name: String, pub(crate) otel_endpoint: String, pub(crate) my_member_id: String, diff --git a/rust/worker/src/execution/config.rs b/rust/worker/src/execution/config.rs index d8550dc41bc..ada5bf1b6ad 100644 --- a/rust/worker/src/execution/config.rs +++ b/rust/worker/src/execution/config.rs @@ -1,7 +1,7 @@ use serde::Deserialize; #[derive(Deserialize)] -pub(crate) struct DispatcherConfig { +pub struct DispatcherConfig { pub(crate) num_worker_threads: usize, pub(crate) dispatcher_queue_size: usize, pub(crate) worker_queue_size: usize, diff --git a/rust/worker/src/execution/dispatcher.rs b/rust/worker/src/execution/dispatcher.rs index 3ce1b57444f..a7230f121f6 100644 --- a/rust/worker/src/execution/dispatcher.rs +++ b/rust/worker/src/execution/dispatcher.rs @@ -51,7 +51,7 @@ use tracing::{trace_span, Instrument, Span}; coarser work-stealing, and other optimizations. 
*/ #[derive(Debug)] -pub(crate) struct Dispatcher { +pub struct Dispatcher { task_queue: Vec, waiters: Vec, n_worker_threads: usize, @@ -188,7 +188,7 @@ impl Component for Dispatcher { self.queue_size } - async fn on_start(&mut self, ctx: &ComponentContext) { + async fn start(&mut self, ctx: &ComponentContext) { self.spawn_workers(&mut ctx.system.clone(), ctx.receiver()); } } @@ -314,7 +314,7 @@ mod tests { 1000 } - async fn on_start(&mut self, ctx: &ComponentContext) { + async fn start(&mut self, ctx: &ComponentContext) { // dispatch a new task every DISPATCH_FREQUENCY_MS for DISPATCH_COUNT times let duration = std::time::Duration::from_millis(DISPATCH_FREQUENCY_MS); ctx.scheduler @@ -377,7 +377,7 @@ mod tests { 1000 } - async fn on_start(&mut self, ctx: &ComponentContext) { + async fn start(&mut self, ctx: &ComponentContext) { // dispatch a new task every DISPATCH_FREQUENCY_MS for DISPATCH_COUNT times let duration = std::time::Duration::from_millis(DISPATCH_FREQUENCY_MS); ctx.scheduler diff --git a/rust/worker/src/execution/mod.rs b/rust/worker/src/execution/mod.rs index 5c6ca5567b7..91a2f089e27 100644 --- a/rust/worker/src/execution/mod.rs +++ b/rust/worker/src/execution/mod.rs @@ -1,8 +1,8 @@ pub(crate) mod config; -pub(crate) mod dispatcher; -pub(crate) mod orchestration; mod worker_thread; // Required for benchmark +pub mod dispatcher; pub mod operator; pub mod operators; +pub mod orchestration; diff --git a/rust/worker/src/execution/operator.rs b/rust/worker/src/execution/operator.rs index d82aaaec501..9bc0e631492 100644 --- a/rust/worker/src/execution/operator.rs +++ b/rust/worker/src/execution/operator.rs @@ -51,15 +51,6 @@ where } } -impl TaskError -where - Err: Debug + ChromaError + 'static, -{ - pub(super) fn boxed(self) -> Box { - Box::new(self) - } -} - /// A task result is a wrapper around the result of a task. /// It contains the task id for tracking purposes. #[derive(Debug)] @@ -94,12 +85,12 @@ where } /// A message type used by the dispatcher to send tasks to worker threads. -pub(crate) type TaskMessage = Box; +pub type TaskMessage = Box; /// A task wrapper is a trait that can be used to run a task. We use it to /// erase the I, O types from the Task struct so that tasks. 
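The TaskWrapper trait above exists for type erasure: a Task is generic over an operator's input and output types, but once boxed as a TaskMessage the dispatcher and worker threads can queue and run tasks over different operators uniformly. The following is a minimal, self-contained sketch of that pattern, not the crate's real definitions; the real Task (built via wrap(operator, input, receiver) in this file) also carries the operator, a task id, and a reply receiver.

use std::fmt::Debug;

use async_trait::async_trait;

// Object-safe view of a task: callers only need a name and a way to run it,
// so the concrete input/output types can be erased behind a trait object.
#[async_trait]
trait TaskWrapper: Send + Debug {
    fn get_name(&self) -> &'static str;
    async fn run(&self);
}

// A concrete task still knows its input and output types...
#[derive(Debug)]
struct Task<I, O> {
    name: &'static str,
    input: I,
    op: fn(&I) -> O,
}

#[async_trait]
impl<I, O> TaskWrapper for Task<I, O>
where
    I: Send + Sync + Debug,
    O: Send + Debug,
{
    fn get_name(&self) -> &'static str {
        self.name
    }

    async fn run(&self) {
        // ...but anything holding a Box<dyn TaskWrapper> never sees I or O.
        let _output = (self.op)(&self.input);
    }
}

// Heterogeneous tasks can therefore share a single queue or channel.
type TaskMessage = Box<dyn TaskWrapper>;

fn count(v: &Vec<i32>) -> usize { v.len() }
fn shout(s: &String) -> String { s.to_uppercase() }

fn main() {
    let tasks: Vec<TaskMessage> = vec![
        Box::new(Task::<Vec<i32>, usize> { name: "count", input: vec![1, 2, 3], op: count }),
        Box::new(Task::<String, String> { name: "shout", input: "hi".to_string(), op: shout }),
    ];
    for task in &tasks {
        println!("{}", task.get_name());
    }
}

Making TaskMessage pub (as this hunk does) exposes only the erased Box<dyn TaskWrapper> form, so downstream code such as the new Orchestrator trait can dispatch tasks without depending on any operator's generic parameters.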
#[async_trait] -pub(crate) trait TaskWrapper: Send + Debug { +pub trait TaskWrapper: Send + Debug { fn get_name(&self) -> &'static str; async fn run(&self); #[allow(dead_code)] @@ -264,7 +255,7 @@ mod tests { 1000 } - async fn on_start(&mut self, ctx: &ComponentContext) { + async fn start(&mut self, ctx: &ComponentContext) { let task = wrap(Box::new(MockOperator {}), (), ctx.receiver()); self.dispatcher.send(task, None).await.unwrap(); } diff --git a/rust/worker/src/execution/operators/count_records.rs b/rust/worker/src/execution/operators/count_records.rs index cf67764c2dd..ae468218985 100644 --- a/rust/worker/src/execution/operators/count_records.rs +++ b/rust/worker/src/execution/operators/count_records.rs @@ -2,12 +2,12 @@ use crate::{ execution::operator::Operator, segment::record_segment::{RecordSegmentReader, RecordSegmentReaderCreationError}, }; +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::{ChromaError, ErrorCodes}; use chroma_types::{Chunk, LogRecord, Operation, Segment}; use std::collections::HashSet; use thiserror::Error; -use tonic::async_trait; #[derive(Debug)] pub(crate) struct CountRecordsOperator {} diff --git a/rust/worker/src/execution/operators/fetch_log.rs b/rust/worker/src/execution/operators/fetch_log.rs index 22e3198929d..e49f06ae0e3 100644 --- a/rust/worker/src/execution/operators/fetch_log.rs +++ b/rust/worker/src/execution/operators/fetch_log.rs @@ -1,9 +1,9 @@ use std::time::{SystemTime, SystemTimeError, UNIX_EPOCH}; +use async_trait::async_trait; use chroma_error::{ChromaError, ErrorCodes}; use chroma_types::{Chunk, CollectionUuid, LogRecord}; use thiserror::Error; -use tonic::async_trait; use tracing::trace; use crate::{ @@ -32,7 +32,7 @@ use crate::{ /// It should be run at the start of an orchestrator to get the latest data of a collection #[derive(Clone, Debug)] pub struct FetchLogOperator { - pub(crate) log_client: Box, + pub log_client: Box, pub batch_size: u32, pub start_log_offset_id: u32, pub maximum_fetch_count: Option, @@ -126,20 +126,20 @@ mod tests { fn setup_in_memory_log() -> (CollectionUuid, Box) { let collection_id = CollectionUuid::new(); let mut in_memory_log = InMemoryLog::new(); - let generator = LogGenerator { - generator: upsert_generator, - }; - generator.generate_vec(0..10).into_iter().for_each(|log| { - in_memory_log.add_log( - collection_id, - InternalLogRecord { + upsert_generator + .generate_vec(0..10) + .into_iter() + .for_each(|log| { + in_memory_log.add_log( collection_id, - log_offset: log.log_offset, - log_ts: log.log_offset, - record: log, - }, - ) - }); + InternalLogRecord { + collection_id, + log_offset: log.log_offset, + log_ts: log.log_offset, + record: log, + }, + ) + }); (collection_id, Box::new(Log::InMemory(in_memory_log))) } diff --git a/rust/worker/src/execution/operators/fetch_segment.rs b/rust/worker/src/execution/operators/fetch_segment.rs deleted file mode 100644 index d7b49f17e53..00000000000 --- a/rust/worker/src/execution/operators/fetch_segment.rs +++ /dev/null @@ -1,137 +0,0 @@ -use chroma_error::{ChromaError, ErrorCodes}; -use chroma_types::{Collection, CollectionUuid, Segment, SegmentScope, SegmentType, SegmentUuid}; -use thiserror::Error; -use tonic::async_trait; -use tracing::trace; - -use crate::{ - execution::operator::{Operator, OperatorType}, - sysdb::sysdb::{GetCollectionsError, GetSegmentsError, SysDb}, -}; - -/// The `FetchSegmentOperator` fetches collection and segment information from SysDB -/// -/// # Parameters -/// - `sysdb`: The SysDB 
reader -/// - `*_uuid`: The uuids of the collection and segments -/// - `collection_version`: The version of the collection to verify against -/// -/// # Inputs -/// - No input is required -/// -/// # Outputs -/// - `collection`: The collection information -/// - `*_segment`: The segment information -/// -/// # Usage -/// It should be run at the start of an orchestrator to get the latest data of a collection -#[derive(Clone, Debug)] -pub struct FetchSegmentOperator { - pub(crate) sysdb: Box, - pub collection_uuid: CollectionUuid, - pub collection_version: u32, - pub metadata_uuid: SegmentUuid, - pub record_uuid: SegmentUuid, - pub vector_uuid: SegmentUuid, -} - -type FetchSegmentInput = (); - -#[derive(Clone, Debug)] -pub struct FetchSegmentOutput { - pub collection: Collection, - pub metadata_segment: Segment, - pub record_segment: Segment, - pub vector_segment: Segment, -} - -#[derive(Error, Debug)] -pub enum FetchSegmentError { - #[error("Error when getting collection: {0}")] - GetCollection(#[from] GetCollectionsError), - #[error("Error when getting segment: {0}")] - GetSegment(#[from] GetSegmentsError), - #[error("No collection found")] - NoCollection, - #[error("No segment found")] - NoSegment, - // The frontend relies on ths content of the error message here to detect version mismatch - // TODO: Refactor frontend to properly detect version mismatch - #[error("Collection version mismatch")] - VersionMismatch, -} - -impl ChromaError for FetchSegmentError { - fn code(&self) -> ErrorCodes { - match self { - FetchSegmentError::GetCollection(e) => e.code(), - FetchSegmentError::GetSegment(e) => e.code(), - FetchSegmentError::NoCollection => ErrorCodes::NotFound, - FetchSegmentError::NoSegment => ErrorCodes::NotFound, - FetchSegmentError::VersionMismatch => ErrorCodes::VersionMismatch, - } - } -} - -impl FetchSegmentOperator { - async fn get_collection(&self) -> Result { - let collection = self - .sysdb - .clone() - .get_collections(Some(self.collection_uuid), None, None, None) - .await? - .pop() - .ok_or(FetchSegmentError::NoCollection)?; - if collection.version != self.collection_version as i32 { - Err(FetchSegmentError::VersionMismatch) - } else { - Ok(collection) - } - } - async fn get_segment(&self, scope: SegmentScope) -> Result { - let segment_type = match scope { - SegmentScope::METADATA => SegmentType::BlockfileMetadata, - SegmentScope::RECORD => SegmentType::BlockfileRecord, - SegmentScope::SQLITE => unimplemented!("Unexpected Sqlite segment"), - SegmentScope::VECTOR => SegmentType::HnswDistributed, - }; - let segment_id = match scope { - SegmentScope::METADATA => self.metadata_uuid, - SegmentScope::RECORD => self.record_uuid, - SegmentScope::SQLITE => unimplemented!("Unexpected Sqlite segment"), - SegmentScope::VECTOR => self.vector_uuid, - }; - self.sysdb - .clone() - .get_segments( - Some(segment_id), - Some(segment_type.into()), - Some(scope), - self.collection_uuid, - ) - .await? 
- // Each scope should have a single segment - .pop() - .ok_or(FetchSegmentError::NoSegment) - } -} - -#[async_trait] -impl Operator for FetchSegmentOperator { - type Error = FetchSegmentError; - - fn get_type(&self) -> OperatorType { - OperatorType::IO - } - - async fn run(&self, _: &FetchSegmentInput) -> Result { - trace!("[{}]: {:?}", self.get_name(), self); - - Ok(FetchSegmentOutput { - collection: self.get_collection().await?, - metadata_segment: self.get_segment(SegmentScope::METADATA).await?, - record_segment: self.get_segment(SegmentScope::RECORD).await?, - vector_segment: self.get_segment(SegmentScope::VECTOR).await?, - }) - } -} diff --git a/rust/worker/src/execution/operators/filter.rs b/rust/worker/src/execution/operators/filter.rs index 635636b1137..408f7dc19f8 100644 --- a/rust/worker/src/execution/operators/filter.rs +++ b/rust/worker/src/execution/operators/filter.rs @@ -3,6 +3,7 @@ use std::{ ops::{BitAnd, BitOr, Bound}, }; +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::metadata::types::MetadataIndexError; @@ -13,7 +14,6 @@ use chroma_types::{ }; use roaring::RoaringBitmap; use thiserror::Error; -use tonic::async_trait; use tracing::{trace, Instrument, Span}; use crate::{ @@ -515,12 +515,11 @@ mod tests { /// - Compacted: Delete [1..=10] deletion, add [11..=50] async fn setup_filter_input() -> FilterInput { let mut test_segment = TestSegment::default(); - let generator = LogGenerator { - generator: add_delete_generator, - }; - test_segment.populate_with_generator(60, &generator).await; + test_segment + .populate_with_generator(60, add_delete_generator) + .await; FilterInput { - logs: generator.generate_chunk(61..=120), + logs: add_delete_generator.generate_chunk(61..=120), blockfile_provider: test_segment.blockfile_provider, metadata_segment: test_segment.metadata_segment, record_segment: test_segment.record_segment, diff --git a/rust/worker/src/execution/operators/knn_hnsw.rs b/rust/worker/src/execution/operators/knn_hnsw.rs index 2618ce7cd4d..4e80971d026 100644 --- a/rust/worker/src/execution/operators/knn_hnsw.rs +++ b/rust/worker/src/execution/operators/knn_hnsw.rs @@ -1,8 +1,8 @@ +use async_trait::async_trait; use chroma_distance::{normalize, DistanceFunction}; use chroma_error::{ChromaError, ErrorCodes}; use chroma_types::SignedRoaringBitmap; use thiserror::Error; -use tonic::async_trait; use crate::{ execution::operator::Operator, segment::distributed_hnsw_segment::DistributedHNSWSegmentReader, diff --git a/rust/worker/src/execution/operators/knn_log.rs b/rust/worker/src/execution/operators/knn_log.rs index aee11654529..ef333fdb44b 100644 --- a/rust/worker/src/execution/operators/knn_log.rs +++ b/rust/worker/src/execution/operators/knn_log.rs @@ -1,11 +1,11 @@ use std::collections::BinaryHeap; +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_distance::{normalize, DistanceFunction}; use chroma_error::ChromaError; use chroma_types::{MaterializedLogOperation, Segment, SignedRoaringBitmap}; use thiserror::Error; -use tonic::async_trait; use crate::{ execution::operator::Operator, @@ -143,11 +143,8 @@ mod tests { log_offset_ids: SignedRoaringBitmap, ) -> KnnLogInput { let test_segment = TestSegment::default(); - let generator = LogGenerator { - generator: upsert_generator, - }; KnnLogInput { - logs: generator.generate_chunk(1..=100), + logs: upsert_generator.generate_chunk(1..=100), blockfile_provider: 
test_segment.blockfile_provider, record_segment: test_segment.record_segment, distance_function: metric, diff --git a/rust/worker/src/execution/operators/knn_merge.rs b/rust/worker/src/execution/operators/knn_merge.rs index 545589a511d..fa3981328bb 100644 --- a/rust/worker/src/execution/operators/knn_merge.rs +++ b/rust/worker/src/execution/operators/knn_merge.rs @@ -1,4 +1,4 @@ -use tonic::async_trait; +use async_trait::async_trait; use crate::execution::operator::Operator; diff --git a/rust/worker/src/execution/operators/knn_projection.rs b/rust/worker/src/execution/operators/knn_projection.rs index ee39aaf59ed..7883006320f 100644 --- a/rust/worker/src/execution/operators/knn_projection.rs +++ b/rust/worker/src/execution/operators/knn_projection.rs @@ -1,8 +1,8 @@ +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::ChromaError; use chroma_types::Segment; use thiserror::Error; -use tonic::async_trait; use tracing::trace; use crate::execution::{operator::Operator, operators::projection::ProjectionInput}; @@ -145,12 +145,11 @@ mod tests { record_distances: Vec, ) -> KnnProjectionInput { let mut test_segment = TestSegment::default(); - let generator = LogGenerator { - generator: upsert_generator, - }; - test_segment.populate_with_generator(100, &generator).await; + test_segment + .populate_with_generator(100, upsert_generator) + .await; KnnProjectionInput { - logs: generator.generate_chunk(81..=120), + logs: upsert_generator.generate_chunk(81..=120), blockfile_provider: test_segment.blockfile_provider, record_segment: test_segment.record_segment, record_distances, diff --git a/rust/worker/src/execution/operators/limit.rs b/rust/worker/src/execution/operators/limit.rs index b5709dcddbd..eb2d2b287e4 100644 --- a/rust/worker/src/execution/operators/limit.rs +++ b/rust/worker/src/execution/operators/limit.rs @@ -1,11 +1,12 @@ use std::{cmp::Ordering, num::TryFromIntError, sync::atomic}; +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::{ChromaError, ErrorCodes}; use chroma_types::{Chunk, LogRecord, MaterializedLogOperation, Segment, SignedRoaringBitmap}; +use futures::StreamExt; use roaring::RoaringBitmap; use thiserror::Error; -use tonic::async_trait; use tracing::{trace, Instrument, Span}; use crate::{ @@ -141,44 +142,38 @@ impl<'me> SeekScanner<'me> { // Seek the start in the log and record segment, then scan for the specified number of offset ids async fn seek_and_scan(&self, skip: u64, mut fetch: u64) -> Result { - let record_count = self.record_segment.count().await?; let starting_offset = self.seek_starting_offset(skip).await?; let mut log_index = self.log_offset_ids.rank(starting_offset) - self.log_offset_ids.contains(starting_offset) as u64; - let mut record_index = self - .record_segment - .get_offset_id_rank(starting_offset) - .await?; + let mut log_offset_id = self.log_offset_ids.select(u32::try_from(log_index)?); + let mut record_offset_stream = self.record_segment.get_offset_stream(starting_offset..); + let mut record_offset_id = record_offset_stream.next().await.transpose()?; let mut merged_result = Vec::new(); while fetch > 0 { - let log_offset_id = self.log_offset_ids.select(u32::try_from(log_index)?); - let record_offset_id = (record_index < record_count).then_some( - self.record_segment - .get_offset_id_at_index(record_index) - .await?, - ); match (log_offset_id, record_offset_id) { (_, Some(oid)) if self.mask.contains(oid) => { - record_index += 1; + record_offset_id = 
record_offset_stream.next().await.transpose()?; continue; } (Some(log_oid), Some(record_oid)) => { if log_oid < record_oid { merged_result.push(log_oid); log_index += 1; + log_offset_id = self.log_offset_ids.select(u32::try_from(log_index)?); } else { merged_result.push(record_oid); - record_index += 1; + record_offset_id = record_offset_stream.next().await.transpose()?; } } (None, Some(oid)) => { merged_result.push(oid); - record_index += 1; + record_offset_id = record_offset_stream.next().await.transpose()?; } (Some(oid), None) => { merged_result.push(oid); log_index += 1; + log_offset_id = self.log_offset_ids.select(u32::try_from(log_index)?); } _ => break, }; @@ -304,12 +299,11 @@ mod tests { compact_offset_ids: SignedRoaringBitmap, ) -> LimitInput { let mut test_segment = TestSegment::default(); - let generator = LogGenerator { - generator: upsert_generator, - }; - test_segment.populate_with_generator(100, &generator).await; + test_segment + .populate_with_generator(100, upsert_generator) + .await; LimitInput { - logs: generator.generate_chunk(31..=60), + logs: upsert_generator.generate_chunk(31..=60), blockfile_provider: test_segment.blockfile_provider, record_segment: test_segment.record_segment, log_offset_ids, diff --git a/rust/worker/src/execution/operators/mod.rs b/rust/worker/src/execution/operators/mod.rs index f0e02207c65..76eccedaab9 100644 --- a/rust/worker/src/execution/operators/mod.rs +++ b/rust/worker/src/execution/operators/mod.rs @@ -10,7 +10,6 @@ pub(super) mod write_segments; // Required for benchmark pub mod fetch_log; -pub mod fetch_segment; pub mod filter; pub mod knn; pub mod knn_hnsw; diff --git a/rust/worker/src/execution/operators/prefetch_record.rs b/rust/worker/src/execution/operators/prefetch_record.rs index b97594d778b..8a5deee479e 100644 --- a/rust/worker/src/execution/operators/prefetch_record.rs +++ b/rust/worker/src/execution/operators/prefetch_record.rs @@ -1,10 +1,8 @@ use std::collections::HashSet; -use chroma_blockstore::provider::BlockfileProvider; +use async_trait::async_trait; use chroma_error::{ChromaError, ErrorCodes}; -use chroma_types::{Chunk, LogRecord, Segment}; use thiserror::Error; -use tonic::async_trait; use tracing::{trace, Instrument, Span}; use crate::{ @@ -16,16 +14,15 @@ use crate::{ }, }; +use super::projection::ProjectionInput; + /// The `PrefetchRecordOperator` prefetches the relevant records from the record segments to the cache /// /// # Parameters /// None /// /// # Inputs -/// - `logs`: The latest logs of the collection -/// - `blockfile_provider`: The blockfile provider -/// - `record_segment`: The record segment information -/// - `offset_ids`: The offset ids of the records to prefetch +/// Identical to ProjectionInput /// /// # Outputs /// None @@ -35,13 +32,7 @@ use crate::{ #[derive(Debug)] pub struct PrefetchRecordOperator {} -#[derive(Debug)] -pub struct PrefetchRecordInput { - pub logs: Chunk, - pub blockfile_provider: BlockfileProvider, - pub record_segment: Segment, - pub offset_ids: Vec, -} +pub type PrefetchRecordInput = ProjectionInput; pub type PrefetchRecordOutput = (); diff --git a/rust/worker/src/execution/operators/projection.rs b/rust/worker/src/execution/operators/projection.rs index f02f89acbfe..b555520b4de 100644 --- a/rust/worker/src/execution/operators/projection.rs +++ b/rust/worker/src/execution/operators/projection.rs @@ -42,7 +42,7 @@ pub struct ProjectionOperator { pub metadata: bool, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct ProjectionInput { pub logs: Chunk, pub 
blockfile_provider: BlockfileProvider, @@ -184,12 +184,11 @@ mod tests { /// - Compacted: Upsert [1..=100] async fn setup_projection_input(offset_ids: Vec) -> ProjectionInput { let mut test_segment = TestSegment::default(); - let generator = LogGenerator { - generator: upsert_generator, - }; - test_segment.populate_with_generator(100, &generator).await; + test_segment + .populate_with_generator(100, upsert_generator) + .await; ProjectionInput { - logs: generator.generate_chunk(81..=120), + logs: upsert_generator.generate_chunk(81..=120), blockfile_provider: test_segment.blockfile_provider, record_segment: test_segment.record_segment, offset_ids, diff --git a/rust/worker/src/execution/operators/spann_bf_pl.rs b/rust/worker/src/execution/operators/spann_bf_pl.rs index fdad8676eec..c274717f664 100644 --- a/rust/worker/src/execution/operators/spann_bf_pl.rs +++ b/rust/worker/src/execution/operators/spann_bf_pl.rs @@ -1,11 +1,11 @@ use std::collections::BinaryHeap; +use async_trait::async_trait; use chroma_distance::DistanceFunction; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::spann::types::SpannPosting; use chroma_types::SignedRoaringBitmap; use thiserror::Error; -use tonic::async_trait; use crate::execution::operator::Operator; diff --git a/rust/worker/src/execution/operators/spann_centers_search.rs b/rust/worker/src/execution/operators/spann_centers_search.rs index 064e5a9b762..6dc6d9edf06 100644 --- a/rust/worker/src/execution/operators/spann_centers_search.rs +++ b/rust/worker/src/execution/operators/spann_centers_search.rs @@ -1,8 +1,8 @@ +use async_trait::async_trait; use chroma_distance::DistanceFunction; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::spann::utils::rng_query; use thiserror::Error; -use tonic::async_trait; use crate::{ execution::operator::Operator, @@ -28,7 +28,7 @@ pub(crate) struct SpannCentersSearchOutput { } #[derive(Error, Debug)] -pub(crate) enum SpannCentersSearchError { +pub enum SpannCentersSearchError { #[error("Error creating spann segment reader")] SpannSegmentReaderCreationError, #[error("Error querying RNG")] diff --git a/rust/worker/src/execution/operators/spann_fetch_pl.rs b/rust/worker/src/execution/operators/spann_fetch_pl.rs index 16eaaee975e..0732ef364c4 100644 --- a/rust/worker/src/execution/operators/spann_fetch_pl.rs +++ b/rust/worker/src/execution/operators/spann_fetch_pl.rs @@ -1,7 +1,7 @@ +use async_trait::async_trait; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::spann::types::SpannPosting; use thiserror::Error; -use tonic::async_trait; use crate::{ execution::operator::{Operator, OperatorType}, @@ -22,7 +22,7 @@ pub(crate) struct SpannFetchPlOutput { } #[derive(Error, Debug)] -pub(crate) enum SpannFetchPlError { +pub enum SpannFetchPlError { #[error("Error creating spann segment reader")] SpannSegmentReaderCreationError, #[error("Error querying reader")] diff --git a/rust/worker/src/execution/operators/spann_knn_merge.rs b/rust/worker/src/execution/operators/spann_knn_merge.rs index c1ac147a04e..85b6fd42320 100644 --- a/rust/worker/src/execution/operators/spann_knn_merge.rs +++ b/rust/worker/src/execution/operators/spann_knn_merge.rs @@ -1,6 +1,6 @@ use std::{cmp::Ordering, collections::BinaryHeap}; -use tonic::async_trait; +use async_trait::async_trait; use crate::execution::operator::Operator; diff --git a/rust/worker/src/execution/orchestration/common.rs b/rust/worker/src/execution/orchestration/common.rs deleted file mode 100644 index ff72803f3e8..00000000000 --- 
a/rust/worker/src/execution/orchestration/common.rs +++ /dev/null @@ -1,32 +0,0 @@ -use crate::system::{Component, ComponentContext}; -use chroma_error::ChromaError; - -/// Terminate the orchestrator with an error -/// This function sends an error to the result channel and cancels the orchestrator -/// so it stops processing -/// # Arguments -/// * `result_channel` - The result channel to send the error to -/// * `error` - The error to send -/// * `ctx` - The component context -/// # Panics -/// This function panics if the result channel is not set -pub(super) fn terminate_with_error( - mut result_channel: Option>>, - error: E, - ctx: &ComponentContext, -) where - C: Component, - E: ChromaError, -{ - let result_channel = result_channel - .take() - .expect("Invariant violation. Result channel is not set."); - match result_channel.send(Err(error)) { - Ok(_) => (), - Err(_) => { - tracing::error!("Result channel dropped before sending error"); - } - } - // Cancel the orchestrator so it stops processing - ctx.cancellation_token.cancel(); -} diff --git a/rust/worker/src/execution/orchestration/compact.rs b/rust/worker/src/execution/orchestration/compact.rs index fa189c1008b..fb241c5f2fc 100644 --- a/rust/worker/src/execution/orchestration/compact.rs +++ b/rust/worker/src/execution/orchestration/compact.rs @@ -1,6 +1,9 @@ use super::super::operator::wrap; +use super::orchestrator::Orchestrator; use crate::compactor::CompactionJob; use crate::execution::dispatcher::Dispatcher; +use crate::execution::operator::TaskError; +use crate::execution::operator::TaskMessage; use crate::execution::operator::TaskResult; use crate::execution::operators::fetch_log::FetchLogError; use crate::execution::operators::fetch_log::FetchLogOperator; @@ -20,7 +23,6 @@ use crate::execution::operators::write_segments::WriteSegmentsInput; use crate::execution::operators::write_segments::WriteSegmentsOperator; use crate::execution::operators::write_segments::WriteSegmentsOperatorError; use crate::execution::operators::write_segments::WriteSegmentsOutput; -use crate::execution::orchestration::common::terminate_with_error; use crate::log::log::Log; use crate::segment::distributed_hnsw_segment::DistributedHNSWSegmentWriter; use crate::segment::metadata_segment::MetadataSegmentWriter; @@ -29,11 +31,10 @@ use crate::segment::record_segment::RecordSegmentWriter; use crate::sysdb::sysdb::GetCollectionsError; use crate::sysdb::sysdb::GetSegmentsError; use crate::sysdb::sysdb::SysDb; -use crate::system::Component; +use crate::system::ChannelError; +use crate::system::ComponentContext; use crate::system::ComponentHandle; use crate::system::Handler; -use crate::system::ReceiverForMessage; -use crate::system::System; use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::ChromaError; @@ -46,7 +47,8 @@ use std::sync::atomic; use std::sync::atomic::AtomicU32; use std::sync::Arc; use thiserror::Error; -use tracing::Span; +use tokio::sync::oneshot::error::RecvError; +use tokio::sync::oneshot::Sender; use uuid::Uuid; /** The state of the orchestrator. @@ -67,20 +69,25 @@ understand. We can always add more abstraction later if we need it. 
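Both CompactionError (introduced just below) and CountError (in count.rs further down) gain a blanket From<TaskError<E>> conversion in this patch, so a handler can collapse either an operator failure or a worker panic into the orchestrator's own error type with a single .into(); that is what the ok_or_terminate and terminate_with_result calls rely on. The following is a self-contained sketch of that conversion pattern using simplified stand-ins; the crate's real TaskError and error enums have more variants and carry more context.

use thiserror::Error;

// Stand-in for crate::execution::operator::TaskError, which is generic over the
// operator's error type and carries an optional panic message.
#[derive(Debug)]
enum TaskError<Err> {
    Panic(Option<String>),
    TaskFailed(Err),
}

// Stand-in operator error.
#[derive(Debug, Error)]
enum FetchLogError {
    #[error("log service unavailable")]
    Unavailable,
}

// Simplified orchestrator-level error, mirroring the shape of CountError below.
#[derive(Debug, Error)]
enum CountError {
    #[error("Error running Fetch Log Operator: {0}")]
    FetchLog(#[from] FetchLogError),
    #[error("Panic running task: {0}")]
    Panic(String),
}

// Any operator error that maps into CountError also maps when wrapped in a
// TaskError; a worker panic becomes CountError::Panic with the captured message.
impl<E> From<TaskError<E>> for CountError
where
    E: Into<CountError>,
{
    fn from(value: TaskError<E>) -> Self {
        match value {
            TaskError::Panic(msg) => CountError::Panic(msg.unwrap_or_default()),
            TaskError::TaskFailed(e) => e.into(),
        }
    }
}

fn main() {
    let failed: TaskError<FetchLogError> = TaskError::TaskFailed(FetchLogError::Unavailable);
    let panicked: TaskError<FetchLogError> = TaskError::Panic(Some("worker panicked".into()));
    let a: CountError = failed.into();
    let b: CountError = panicked.into();
    println!("{a}\n{b}");
}

With this conversion in place, the rewritten handlers in this diff can write message.into_inner().map_err(|e| e.into()) or hand the result straight to ok_or_terminate instead of matching on every error case before terminating.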
#[derive(Debug)] enum ExecutionState { Pending, - PullLogs, Partition, Write, Flush, Register, } +#[derive(Clone, Debug)] +struct CompactWriters { + metadata: MetadataSegmentWriter<'static>, + record: RecordSegmentWriter, + vector: Box, +} + #[derive(Debug)] pub struct CompactOrchestrator { id: Uuid, compaction_job: CompactionJob, state: ExecutionState, // Component Execution - system: System, collection_id: CollectionUuid, // Dependencies log: Box, @@ -93,16 +100,11 @@ pub struct CompactOrchestrator { // Dispatcher dispatcher: ComponentHandle, // Shared writers - writers: Option<( - RecordSegmentWriter, - Box, - MetadataSegmentWriter<'static>, - )>, + writers: Option, // number of write segments tasks num_write_tasks: i32, // Result Channel - result_channel: - Option>>>, + result_channel: Option>>, // Next offset id next_offset_id: Arc, max_compaction_size: usize, @@ -140,11 +142,35 @@ impl ChromaError for GetSegmentWritersError { } #[derive(Error, Debug)] -enum CompactionError { - #[error("Task dispatch failed")] - DispatchFailure, - #[error("Result channel dropped")] - ResultChannelDropped, +pub enum CompactionError { + #[error("Panic running task: {0}")] + Panic(String), + #[error("FetchLog error: {0}")] + FetchLog(#[from] FetchLogError), + #[error("Partition error: {0}")] + Partition(#[from] PartitionError), + #[error("WriteSegments error: {0}")] + WriteSegments(#[from] WriteSegmentsOperatorError), + #[error("Regester error: {0}")] + Register(#[from] RegisterError), + #[error("Error sending message through channel: {0}")] + Channel(#[from] ChannelError), + #[error("Error receiving final result: {0}")] + Result(#[from] RecvError), + #[error("{0}")] + Generic(#[from] Box), +} + +impl From> for CompactionError +where + E: Into, +{ + fn from(value: TaskError) -> Self { + match value { + TaskError::Panic(e) => CompactionError::Panic(e.unwrap_or_default()), + TaskError::TaskFailed(e) => e.into(), + } + } } impl ChromaError for CompactionError { @@ -167,16 +193,13 @@ impl CompactOrchestrator { #[allow(clippy::too_many_arguments)] pub fn new( compaction_job: CompactionJob, - system: System, collection_id: CollectionUuid, log: Box, sysdb: Box, blockfile_provider: BlockfileProvider, hnsw_index_provider: HnswIndexProvider, dispatcher: ComponentHandle, - result_channel: Option< - tokio::sync::oneshot::Sender>>, - >, + result_channel: Option>>, record_segment: Option, next_offset_id: Arc, max_compaction_size: usize, @@ -186,7 +209,6 @@ impl CompactOrchestrator { id: Uuid::new_v4(), compaction_job, state: ExecutionState::Pending, - system, collection_id, log, sysdb, @@ -204,76 +226,38 @@ impl CompactOrchestrator { } } - async fn fetch_log( - &mut self, - self_address: Box>>, - ctx: &crate::system::ComponentContext, - ) { - self.state = ExecutionState::PullLogs; - let operator = FetchLogOperator { - log_client: self.log.clone(), - batch_size: 100, - // Here we do not need to be inclusive since the compaction job - // offset is the one after the last compaction offset - start_log_offset_id: self.compaction_job.offset as u32, - maximum_fetch_count: Some(self.max_compaction_size as u32), - collection_uuid: self.collection_id, - }; - let task = wrap(Box::new(operator), (), self_address); - match self.dispatcher.send(task, Some(Span::current())).await { - Ok(_) => (), - Err(e) => { - tracing::error!("Error dispatching pull logs for compaction {:?}", e); - terminate_with_error( - self.result_channel.take(), - Box::new(CompactionError::DispatchFailure), - ctx, - ); - } - } - } - async fn partition( &mut 
self, records: Chunk, - self_address: Box>>, + ctx: &crate::system::ComponentContext, ) { self.state = ExecutionState::Partition; let operator = PartitionOperator::new(); tracing::info!("Sending N Records: {:?}", records.len()); println!("Sending N Records: {:?}", records.len()); let input = PartitionInput::new(records, self.max_partition_size); - let task = wrap(operator, input, self_address); - match self.dispatcher.send(task, Some(Span::current())).await { - Ok(_) => (), - Err(e) => { - tracing::error!("Error dispatching partition for compaction {:?}", e); - panic!( - "Invariant violation. Somehow the dispatcher receiver is dropped. Error: {:?}", - e - ) - } - } + let task = wrap(operator, input, ctx.receiver()); + self.send(task, ctx).await; } async fn write( &mut self, partitions: Vec>, - self_address: Box< - dyn ReceiverForMessage>, - >, ctx: &crate::system::ComponentContext, ) { self.state = ExecutionState::Write; - if let Err(e) = self.init_segment_writers().await { - tracing::error!("Error creating writers for compaction {:?}", e); - terminate_with_error(self.result_channel.take(), e, ctx); + let init_res = self.init_segment_writers().await; + if self.ok_or_terminate(init_res, ctx).is_none() { return; } let (record_segment_writer, hnsw_segment_writer, metadata_segment_writer) = match self.writers.clone() { - Some((rec, hnsw, mt)) => (Some(rec), Some(hnsw), Some(mt)), + Some(writers) => ( + Some(writers.record), + Some(writers.vector), + Some(writers.metadata), + ), None => (None, None, None), }; @@ -292,16 +276,8 @@ impl CompactOrchestrator { .clone(), self.next_offset_id.clone(), ); - let task = wrap(operator, input, self_address.clone()); - match self.dispatcher.send(task, Some(Span::current())).await { - Ok(_) => (), - Err(e) => { - tracing::error!("Error dispatching writers for compaction {:?}", e); - panic!( - "Invariant violation. Somehow the dispatcher receiver is dropped. Error: {:?}", - e) - } - } + let task = wrap(operator, input, ctx.receiver()); + self.send(task, ctx).await; } } @@ -310,7 +286,7 @@ impl CompactOrchestrator { record_segment_writer: RecordSegmentWriter, hnsw_segment_writer: Box, metadata_segment_writer: MetadataSegmentWriter<'static>, - self_address: Box>>>, + ctx: &crate::system::ComponentContext, ) { self.state = ExecutionState::Flush; @@ -321,24 +297,15 @@ impl CompactOrchestrator { metadata_segment_writer, ); - let task = wrap(operator, input, self_address); - match self.dispatcher.send(task, Some(Span::current())).await { - Ok(_) => (), - Err(e) => { - tracing::error!("Error dispatching flush to S3 for compaction {:?}", e); - panic!( - "Invariant violation. Somehow the dispatcher receiver is dropped. Error: {:?}", - e - ); - } - } + let task = wrap(operator, input, ctx.receiver()); + self.send(task, ctx).await; } async fn register( &mut self, log_position: i64, segment_flush_info: Arc<[SegmentFlushInfo]>, - self_address: Box>>, + ctx: &crate::system::ComponentContext, ) { self.state = ExecutionState::Register; let operator = RegisterOperator::new(); @@ -352,17 +319,8 @@ impl CompactOrchestrator { self.log.clone(), ); - let task = wrap(operator, input, self_address); - match self.dispatcher.send(task, Some(Span::current())).await { - Ok(_) => (), - Err(e) => { - tracing::error!("Error dispatching register for compaction {:?}", e); - panic!( - "Invariant violation. Somehow the dispatcher receiver is dropped. 
Error: {:?}", - e - ); - } - } + let task = wrap(operator, input, ctx.receiver()); + self.send(task, ctx).await; } async fn init_segment_writers(&mut self) -> Result<(), Box> { @@ -490,42 +448,52 @@ impl CompactOrchestrator { return Err(Box::new(GetSegmentWritersError::HnswSegmentWriterError)); } }; - self.writers = Some(( - record_segment_writer, - hnsw_segment_writer, - mt_segment_writer, - )) + self.writers = Some(CompactWriters { + metadata: mt_segment_writer, + record: record_segment_writer, + vector: hnsw_segment_writer, + }) } Ok(()) } - - pub(crate) async fn run(mut self) -> Result> { - println!("Running compaction job: {:?}", self.compaction_job); - let (tx, rx) = tokio::sync::oneshot::channel(); - self.result_channel = Some(tx); - let mut handle = self.system.clone().start_component(self); - let result = rx.await; - handle.stop(); - result - .map_err(|_| Box::new(CompactionError::ResultChannelDropped) as Box)? - } } // ============== Component Implementation ============== #[async_trait] -impl Component for CompactOrchestrator { - fn get_name() -> &'static str { - "Compaction orchestrator" +impl Orchestrator for CompactOrchestrator { + type Output = CompactionResponse; + type Error = CompactionError; + + fn dispatcher(&self) -> ComponentHandle { + self.dispatcher.clone() } - fn queue_size(&self) -> usize { - 1000 // TODO: make configurable + fn initial_tasks(&self, ctx: &ComponentContext) -> Vec { + vec![wrap( + Box::new(FetchLogOperator { + log_client: self.log.clone(), + batch_size: 100, + // Here we do not need to be inclusive since the compaction job + // offset is the one after the last compaction offset + start_log_offset_id: self.compaction_job.offset as u32, + maximum_fetch_count: Some(self.max_compaction_size as u32), + collection_uuid: self.collection_id, + }), + (), + ctx.receiver(), + )] } - async fn on_start(&mut self, ctx: &crate::system::ComponentContext) -> () { - self.fetch_log(ctx.receiver(), ctx).await; + fn set_result_channel(&mut self, sender: Sender>) { + self.result_channel = Some(sender) + } + + fn take_result_channel(&mut self) -> Sender> { + self.result_channel + .take() + .expect("The result channel should be set before take") } } @@ -539,13 +507,9 @@ impl Handler> for CompactOrchestrator message: TaskResult, ctx: &crate::system::ComponentContext, ) { - let message = message.into_inner(); - let records = match message { - Ok(result) => result, - Err(e) => { - terminate_with_error(self.result_channel.take(), Box::new(e), ctx); - return; - } + let records = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(recs) => recs, + None => todo!(), }; tracing::info!("Pulled Records: {:?}", records.len()); let final_record_pulled = records.get(records.len() - 1); @@ -553,7 +517,7 @@ impl Handler> for CompactOrchestrator Some(record) => { self.pulled_log_offset = Some(record.log_offset); tracing::info!("Pulled Logs Up To Offset: {:?}", self.pulled_log_offset); - self.partition(records, ctx.receiver()).await; + self.partition(records, ctx).await; } None => { tracing::error!( @@ -574,16 +538,11 @@ impl Handler> for CompactOrchestrato message: TaskResult, ctx: &crate::system::ComponentContext, ) { - let message = message.into_inner(); - let records = match message { - Ok(result) => result.records, - Err(e) => { - tracing::error!("Error partitioning records: {:?}", e); - terminate_with_error(self.result_channel.take(), Box::new(e), ctx); - return; - } + let records = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(recs) => recs.records, + 
None => todo!(), }; - self.write(records, ctx.receiver(), ctx).await; + self.write(records, ctx).await; } } @@ -596,33 +555,22 @@ impl Handler> for Co message: TaskResult, ctx: &crate::system::ComponentContext, ) { - let message = message.into_inner(); - let output = match message { - Ok(output) => { - self.num_write_tasks -= 1; - output - } - Err(e) => { - tracing::error!("Error writing segments: {:?}", e); - terminate_with_error(self.result_channel.take(), Box::new(e), ctx); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; + self.num_write_tasks -= 1; if self.num_write_tasks == 0 { if let (Some(rec), Some(hnsw), Some(mt)) = ( output.record_segment_writer, output.hnsw_segment_writer, output.metadata_segment_writer, ) { - self.flush_s3(rec, hnsw, mt, ctx.receiver()).await; + self.flush_s3(rec, hnsw, mt, ctx).await; } else { // There is nothing to flush, proceed to register - self.register( - self.pulled_log_offset.unwrap(), - Arc::new([]), - ctx.receiver(), - ) - .await; + self.register(self.pulled_log_offset.unwrap(), Arc::new([]), ctx) + .await; } } } @@ -637,22 +585,16 @@ impl Handler>> for CompactOrchest message: TaskResult>, ctx: &crate::system::ComponentContext, ) { - let message = message.into_inner(); - match message { - Ok(msg) => { - // Unwrap should be safe here as we are guaranteed to have a value by construction - self.register( - self.pulled_log_offset.unwrap(), - msg.segment_flush_info, - ctx.receiver(), - ) - .await; - } - Err(e) => { - tracing::error!("Error flushing to S3: {:?}", e); - terminate_with_error(self.result_channel.take(), e.boxed(), ctx); - } - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, + }; + self.register( + self.pulled_log_offset.unwrap(), + output.segment_flush_info, + ctx, + ) + .await; } } @@ -665,26 +607,16 @@ impl Handler> for CompactOrchestrator message: TaskResult, ctx: &crate::system::ComponentContext, ) { - let message = message.into_inner(); - // Return execution state to the compaction manager - let result_channel = self - .result_channel - .take() - .expect("Invariant violation. 
Result channel is not set."); - - match message { - Ok(_) => { - let response = CompactionResponse { + self.terminate_with_result( + message + .into_inner() + .map_err(|e| e.into()) + .map(|_| CompactionResponse { id: self.id, compaction_job: self.compaction_job.clone(), message: "Compaction Complete".to_string(), - }; - let _ = result_channel.send(Ok(response)); - } - Err(e) => { - tracing::error!("Error registering compaction: {:?}", e); - terminate_with_error(Some(result_channel), Box::new(e), ctx); - } - } + }), + ctx, + ); } } diff --git a/rust/worker/src/execution/orchestration/count.rs b/rust/worker/src/execution/orchestration/count.rs index 54cbde8690b..8d91359b83c 100644 --- a/rust/worker/src/execution/orchestration/count.rs +++ b/rust/worker/src/execution/orchestration/count.rs @@ -1,332 +1,133 @@ -use crate::execution::dispatcher::Dispatcher; -use crate::execution::operator::{wrap, TaskResult}; -use crate::execution::operators::count_records::{ - CountRecordsError, CountRecordsInput, CountRecordsOperator, CountRecordsOutput, -}; -use crate::execution::operators::fetch_log::{FetchLogError, FetchLogOperator, FetchLogOutput}; -use crate::execution::orchestration::common::terminate_with_error; -use crate::sysdb::sysdb::{GetCollectionsError, GetSegmentsError}; -use crate::system::{Component, ComponentContext, ComponentHandle, Handler}; -use crate::{log::log::Log, sysdb::sysdb::SysDb, system::System}; use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::{ChromaError, ErrorCodes}; -use chroma_types::{Collection, CollectionUuid, Segment, SegmentType, SegmentUuid}; +use chroma_types::CollectionAndSegments; use thiserror::Error; -use tracing::Span; -use uuid::Uuid; +use tokio::sync::oneshot::{error::RecvError, Sender}; + +use crate::{ + execution::{ + dispatcher::Dispatcher, + operator::{wrap, TaskError, TaskMessage, TaskResult}, + operators::{ + count_records::{ + CountRecordsError, CountRecordsInput, CountRecordsOperator, CountRecordsOutput, + }, + fetch_log::{FetchLogError, FetchLogOperator, FetchLogOutput}, + }, + }, + system::{ChannelError, ComponentContext, ComponentHandle, Handler}, +}; -#[derive(Debug)] -pub(crate) struct CountQueryOrchestrator { - // Component Execution - system: System, - // Query state - metadata_segment_id: Uuid, - collection_id: CollectionUuid, - // State fetched or created for query execution - record_segment: Option, - collection: Option, - // Services - log: Box, - sysdb: Box, - dispatcher: ComponentHandle, - blockfile_provider: BlockfileProvider, - // Result channel - result_channel: Option>>>, - // Request version context - collection_version: u32, - log_position: u64, -} +use super::orchestrator::Orchestrator; #[derive(Error, Debug)] -enum CountQueryOrchestratorError { - #[error("Blockfile metadata segment with id: {0} not found")] - BlockfileMetadataSegmentNotFound(Uuid), - #[error("Get segments error: {0}")] - GetSegmentsError(#[from] GetSegmentsError), - #[error("Record segment not found for collection: {0}")] - RecordSegmentNotFound(CollectionUuid), - #[error("System Time Error")] - SystemTimeError(#[from] std::time::SystemTimeError), - #[error("Collection not found for id: {0}")] - CollectionNotFound(CollectionUuid), - #[error("Get collection error: {0}")] - GetCollectionError(#[from] GetCollectionsError), - #[error("Collection version mismatch")] - CollectionVersionMismatch, - #[error("Task dispatch failed")] - DispatchFailure, +pub enum CountError { + #[error("Error sending message through channel: 
{0}")] + Channel(#[from] ChannelError), + #[error("Error running Fetch Log Operator: {0}")] + FetchLog(#[from] FetchLogError), + #[error("Error running Count Record Operator: {0}")] + CountRecord(#[from] CountRecordsError), + #[error("Panic running task: {0}")] + Panic(String), + #[error("Error receiving final result: {0}")] + Result(#[from] RecvError), } -impl ChromaError for CountQueryOrchestratorError { +impl ChromaError for CountError { fn code(&self) -> ErrorCodes { match self { - CountQueryOrchestratorError::BlockfileMetadataSegmentNotFound(_) => { - ErrorCodes::NotFound - } - CountQueryOrchestratorError::GetSegmentsError(e) => e.code(), - CountQueryOrchestratorError::RecordSegmentNotFound(_) => ErrorCodes::NotFound, - CountQueryOrchestratorError::SystemTimeError(_) => ErrorCodes::Internal, - CountQueryOrchestratorError::CollectionNotFound(_) => ErrorCodes::NotFound, - CountQueryOrchestratorError::GetCollectionError(e) => e.code(), - CountQueryOrchestratorError::CollectionVersionMismatch => ErrorCodes::VersionMismatch, - CountQueryOrchestratorError::DispatchFailure => ErrorCodes::Internal, + CountError::Channel(e) => e.code(), + CountError::FetchLog(e) => e.code(), + CountError::CountRecord(e) => e.code(), + CountError::Panic(_) => ErrorCodes::Aborted, + CountError::Result(_) => ErrorCodes::Internal, } } } -impl CountQueryOrchestrator { - #[allow(clippy::too_many_arguments)] - pub(crate) fn new( - system: System, - metadata_segment_id: &Uuid, - collection_id: &CollectionUuid, - log: Box, - sysdb: Box, - dispatcher: ComponentHandle, - blockfile_provider: BlockfileProvider, - collection_version: u32, - log_position: u64, - ) -> Self { - Self { - system, - metadata_segment_id: *metadata_segment_id, - collection_id: *collection_id, - record_segment: None, - collection: None, - log, - sysdb, - dispatcher, - blockfile_provider, - result_channel: None, - collection_version, - log_position, - } - } - - async fn start(&mut self, ctx: &ComponentContext) { - println!("Starting Count Query Orchestrator"); - // Populate the orchestrator with the initial state - The Record Segment and the Collection - let metdata_segment = self - .get_metadata_segment_from_id( - self.sysdb.clone(), - &self.metadata_segment_id, - &self.collection_id, - ) - .await; - - let metadata_segment = match metdata_segment { - Ok(segment) => segment, - Err(e) => { - tracing::error!("Error getting metadata segment: {:?}", e); - terminate_with_error(self.result_channel.take(), e, ctx); - return; - } - }; - - let collection_id = metadata_segment.collection; - - let record_segment = self - .get_record_segment_from_collection_id(self.sysdb.clone(), &collection_id) - .await; - - let record_segment = match record_segment { - Ok(segment) => segment, - Err(e) => { - tracing::error!("Error getting record segment: {:?}", e); - terminate_with_error(self.result_channel.take(), e, ctx); - return; - } - }; - - let collection = match self - .get_collection_from_id(self.sysdb.clone(), &collection_id, ctx) - .await - { - Ok(collection) => collection, - Err(e) => { - tracing::error!("Error getting collection: {:?}", e); - terminate_with_error(self.result_channel.take(), e, ctx); - return; - } - }; - - // If the collection version does not match the request version then we terminate with an error - if collection.version as u32 != self.collection_version { - terminate_with_error( - self.result_channel.take(), - Box::new(CountQueryOrchestratorError::CollectionVersionMismatch), - ctx, - ); - return; +impl From> for CountError +where + E: Into, +{ + 
fn from(value: TaskError) -> Self { + match value { + TaskError::Panic(e) => CountError::Panic(e.unwrap_or_default()), + TaskError::TaskFailed(e) => e.into(), } - - self.record_segment = Some(record_segment); - self.collection = Some(collection); - self.fetch_log(ctx).await; } +} - // shared - async fn fetch_log(&mut self, ctx: &ComponentContext) { - println!("Count query orchestrator pulling logs"); - - let collection = self - .collection - .as_ref() - .expect("Invariant violation. Collection is not set before pull logs state."); - - let operator = FetchLogOperator { - log_client: self.log.clone(), - batch_size: 100, - // The collection log position is inclusive, and we want to start from the next log. - // Note that we query using the incoming log position this is critical for correctness. - start_log_offset_id: self.log_position as u32 + 1, - maximum_fetch_count: None, - collection_uuid: collection.collection_id, - }; +type CountOutput = usize; +type CountResult = Result; - let task = wrap(Box::new(operator), (), ctx.receiver()); - match self.dispatcher.send(task, Some(Span::current())).await { - Ok(_) => (), - Err(e) => { - // Log an error - this implies the dispatcher was dropped somehow - // and is likely fatal - println!("Error sending Count Query task: {:?}", e); - terminate_with_error( - self.result_channel.take(), - Box::new(CountQueryOrchestratorError::DispatchFailure), - ctx, - ); - } - } - } +#[derive(Debug)] +pub struct CountOrchestrator { + // Orchestrator parameters + blockfile_provider: BlockfileProvider, + dispatcher: ComponentHandle, + queue: usize, - // shared - async fn get_metadata_segment_from_id( - &self, - mut sysdb: Box, - metadata_segment_id: &Uuid, - collection_id: &CollectionUuid, - ) -> Result> { - let segments = sysdb - .get_segments( - Some(SegmentUuid(*metadata_segment_id)), - None, - None, - *collection_id, - ) - .await; - let segment = match segments { - Ok(segments) => { - if segments.is_empty() { - return Err(Box::new( - CountQueryOrchestratorError::BlockfileMetadataSegmentNotFound( - *metadata_segment_id, - ), - )); - } - segments[0].clone() - } - Err(e) => { - return Err(Box::new(CountQueryOrchestratorError::GetSegmentsError(e))); - } - }; + // Collection and segments + collection_and_segments: CollectionAndSegments, - if segment.r#type != SegmentType::BlockfileMetadata { - return Err(Box::new( - CountQueryOrchestratorError::BlockfileMetadataSegmentNotFound(*metadata_segment_id), - )); - } - Ok(segment) - } + // Fetch logs + fetch_log: FetchLogOperator, - // shared - async fn get_record_segment_from_collection_id( - &self, - mut sysdb: Box, - collection_id: &CollectionUuid, - ) -> Result> { - let segments = sysdb - .get_segments( - None, - Some(SegmentType::BlockfileRecord.into()), - None, - *collection_id, - ) - .await; + // Result channel + result_channel: Option>>, +} - match segments { - Ok(segments) => { - if segments.is_empty() { - return Err(Box::new( - CountQueryOrchestratorError::RecordSegmentNotFound(*collection_id), - )); - } - // Unwrap is safe as we know at least one segment exists from - // the check above - Ok(segments.into_iter().next().unwrap()) - } - Err(e) => Err(Box::new(CountQueryOrchestratorError::GetSegmentsError(e))), +impl CountOrchestrator { + pub(crate) fn new( + blockfile_provider: BlockfileProvider, + dispatcher: ComponentHandle, + queue: usize, + collection_and_segments: CollectionAndSegments, + fetch_log: FetchLogOperator, + ) -> Self { + Self { + blockfile_provider, + dispatcher, + collection_and_segments, + queue, + 
fetch_log, + result_channel: None, } } +} - // shared - async fn get_collection_from_id( - &self, - mut sysdb: Box, - collection_id: &CollectionUuid, - _ctx: &ComponentContext, - ) -> Result> { - let collections = sysdb - .get_collections(Some(*collection_id), None, None, None) - .await; +#[async_trait] +impl Orchestrator for CountOrchestrator { + type Output = CountOutput; + type Error = CountError; - match collections { - Ok(collections) => { - if collections.is_empty() { - return Err(Box::new(CountQueryOrchestratorError::CollectionNotFound( - *collection_id, - ))); - } - // Unwrap is safe as we know at least one collection exists from - // the check above - Ok(collections.into_iter().next().unwrap()) - } - Err(e) => Err(Box::new(CountQueryOrchestratorError::GetCollectionError(e))), - } + fn dispatcher(&self) -> ComponentHandle { + self.dispatcher.clone() } - /// Run the orchestrator and return the result. - /// # Note - /// Use this over spawning the component directly. This method will start the component and - /// wait for it to finish before returning the result. - pub(crate) async fn run(mut self) -> Result> { - let (tx, rx) = tokio::sync::oneshot::channel(); - self.result_channel = Some(tx); - let mut handle = self.system.clone().start_component(self); - let result = rx.await; - handle.stop(); - result.unwrap() + fn initial_tasks(&self, ctx: &ComponentContext) -> Vec { + vec![wrap(Box::new(self.fetch_log.clone()), (), ctx.receiver())] } -} -#[async_trait] -impl Component for CountQueryOrchestrator { - fn get_name() -> &'static str { - "Count Query Orchestrator" + fn queue_size(&self) -> usize { + self.queue } - fn queue_size(&self) -> usize { - 1000 // TODO: make this configurable + fn set_result_channel(&mut self, sender: Sender) { + self.result_channel = Some(sender) } - async fn on_start(&mut self, ctx: &crate::system::ComponentContext) -> () { - self.start(ctx).await; + fn take_result_channel(&mut self) -> Sender { + self.result_channel + .take() + .expect("The result channel should be set before take") } } #[async_trait] -impl Handler> for CountQueryOrchestrator { +impl Handler> for CountOrchestrator { type Result = (); async fn handle( @@ -334,37 +135,25 @@ impl Handler> for CountQueryOrchestrat message: TaskResult, ctx: &ComponentContext, ) { - let message = message.into_inner(); - match message { - Ok(logs) => { - let operator = CountRecordsOperator::new(); - let input = CountRecordsInput::new( - self.record_segment - .as_ref() - .expect("Expect segment") - .clone(), - self.blockfile_provider.clone(), - logs, - ); - let msg = wrap(operator, input, ctx.receiver()); - match self.dispatcher.send(msg, None).await { - Ok(_) => (), - Err(e) => { - // Log an error - this implies the dispatcher was dropped somehow - // and is likely fatal - println!("Error sending Count Query task: {:?}", e); - } - } - } - Err(e) => { - terminate_with_error(self.result_channel.take(), Box::new(e), ctx); - } - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, + }; + let task = wrap( + CountRecordsOperator::new(), + CountRecordsInput::new( + self.collection_and_segments.record_segment.clone(), + self.blockfile_provider.clone(), + output, + ), + ctx.receiver(), + ); + self.send(task, ctx).await; } } #[async_trait] -impl Handler> for CountQueryOrchestrator { +impl Handler> for CountOrchestrator { type Result = (); async fn handle( @@ -372,23 +161,12 @@ impl Handler> for CountQueryOr message: TaskResult, ctx: &ComponentContext, ) { - let 
message = message.into_inner(); - let msg = match message { - Ok(m) => m, - Err(e) => { - return terminate_with_error(self.result_channel.take(), Box::new(e), ctx); - } - }; - let channel = self - .result_channel - .take() - .expect("Expect channel to be present"); - match channel.send(Ok(msg.count)) { - Ok(_) => (), - Err(_) => { - // Log an error - this implied the listener was dropped - println!("[CountQueryOrchestrator] Result channel dropped before sending result"); - } - } + self.terminate_with_result( + message + .into_inner() + .map_err(|e| e.into()) + .map(|output| output.count), + ctx, + ); } } diff --git a/rust/worker/src/execution/orchestration/get.rs b/rust/worker/src/execution/orchestration/get.rs index 443b5be1b59..2b955e3fc78 100644 --- a/rust/worker/src/execution/orchestration/get.rs +++ b/rust/worker/src/execution/orchestration/get.rs @@ -1,38 +1,33 @@ +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::{ChromaError, ErrorCodes}; +use chroma_types::CollectionAndSegments; use thiserror::Error; -use tokio::sync::oneshot::{self, error::RecvError, Sender}; -use tonic::async_trait; -use tracing::Span; +use tokio::sync::oneshot::{error::RecvError, Sender}; use crate::{ execution::{ dispatcher::Dispatcher, - operator::{wrap, TaskError, TaskResult}, + operator::{wrap, TaskError, TaskMessage, TaskResult}, operators::{ fetch_log::{FetchLogError, FetchLogOperator, FetchLogOutput}, - fetch_segment::{FetchSegmentError, FetchSegmentOperator, FetchSegmentOutput}, filter::{FilterError, FilterInput, FilterOperator, FilterOutput}, limit::{LimitError, LimitInput, LimitOperator, LimitOutput}, - prefetch_record::{ - PrefetchRecordError, PrefetchRecordInput, PrefetchRecordOperator, - PrefetchRecordOutput, - }, + prefetch_record::{PrefetchRecordError, PrefetchRecordOperator, PrefetchRecordOutput}, projection::{ProjectionError, ProjectionInput, ProjectionOperator, ProjectionOutput}, }, - orchestration::common::terminate_with_error, }, - system::{ChannelError, Component, ComponentContext, ComponentHandle, Handler, System}, + system::{ChannelError, ComponentContext, ComponentHandle, Handler}, }; +use super::orchestrator::Orchestrator; + #[derive(Error, Debug)] pub enum GetError { #[error("Error sending message through channel: {0}")] Channel(#[from] ChannelError), #[error("Error running Fetch Log Operator: {0}")] FetchLog(#[from] FetchLogError), - #[error("Error running Fetch Segment Operator: {0}")] - FetchSegment(#[from] FetchSegmentError), #[error("Error running Filter Operator: {0}")] Filter(#[from] FilterError), #[error("Error running Limit Operator: {0}")] @@ -50,7 +45,6 @@ impl ChromaError for GetError { match self { GetError::Channel(e) => e.code(), GetError::FetchLog(e) => e.code(), - GetError::FetchSegment(e) => e.code(), GetError::Filter(e) => e.code(), GetError::Limit(e) => e.code(), GetError::Panic(_) => ErrorCodes::Aborted, @@ -81,62 +75,47 @@ type GetResult = Result; /// /// # Pipeline /// ```text -/// ┌────────────┐ -/// │ │ -/// ┌───────────┤ on_start ├────────────────┐ -/// │ │ │ │ -/// │ └────────────┘ │ -/// │ │ -/// ▼ ▼ -/// ┌────────────────────┐ ┌────────────────────────┐ -/// │ │ │ │ -/// │ FetchLogOperator │ │ FetchSegmentOperator │ -/// │ │ │ │ -/// └────────┬───────────┘ └────────────────┬───────┘ -/// │ │ -/// │ │ -/// │ ┌─────────────────────────────┐ │ -/// │ │ │ │ -/// └────►│ try_start_filter_operator │◄────┘ -/// │ │ -/// └────────────┬────────────────┘ -/// │ -/// ▼ -/// ┌───────────────────┐ -/// │ │ -/// │ 
FilterOperator │ -/// │ │ -/// └─────────┬─────────┘ -/// │ -/// ▼ -/// ┌─────────────────┐ -/// │ │ -/// │ LimitOperator │ -/// │ │ -/// └────────┬────────┘ -/// │ -/// ▼ -/// ┌──────────────────────┐ -/// │ │ -/// │ ProjectionOperator │ -/// │ │ -/// └──────────┬───────────┘ -/// │ -/// ▼ -/// ┌──────────────────┐ -/// │ │ -/// │ result_channel │ -/// │ │ -/// └──────────────────┘ +/// ┌────────────┐ +/// │ │ +/// │ on_start │ +/// │ │ +/// └──────┬─────┘ +/// │ +/// ▼ +/// ┌────────────────────┐ +/// │ │ +/// │ FetchLogOperator │ +/// │ │ +/// └─────────┬──────────┘ +/// │ +/// ▼ +/// ┌───────────────────┐ +/// │ │ +/// │ FilterOperator │ +/// │ │ +/// └─────────┬─────────┘ +/// │ +/// ▼ +/// ┌─────────────────┐ +/// │ │ +/// │ LimitOperator │ +/// │ │ +/// └────────┬────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ │ +/// │ ProjectionOperator │ +/// │ │ +/// └──────────┬───────────┘ +/// │ +/// ▼ +/// ┌──────────────────┐ +/// │ │ +/// │ result_channel │ +/// │ │ +/// └──────────────────┘ /// ``` -/// -/// # State tracking -/// As suggested by the pipeline diagram above, the orchestrator only need to -/// keep track of the outputs from `FetchLogOperator` and `FetchSegmentOperator`. -/// The orchestrator invokes `try_start_filter_operator` when it receives output -/// from either operators, and if both outputs are present it composes the input -/// for `FilterOperator` and proceeds with execution. The outputs of other -/// operators are directly forwarded without being tracked by the orchestrator. #[derive(Debug)] pub struct GetOrchestrator { // Orchestrator parameters @@ -144,13 +123,14 @@ pub struct GetOrchestrator { dispatcher: ComponentHandle, queue: usize, - // Fetch logs and segments + // Collection segments + collection_and_segments: CollectionAndSegments, + + // Fetch logs fetch_log: FetchLogOperator, - fetch_segment: FetchSegmentOperator, - // Fetch output - fetch_log_output: Option, - fetch_segment_output: Option, + // Fetched logs + fetched_logs: Option, // Pipelined operators filter: FilterOperator, @@ -167,8 +147,8 @@ impl GetOrchestrator { blockfile_provider: BlockfileProvider, dispatcher: ComponentHandle, queue: usize, + collection_and_segments: CollectionAndSegments, fetch_log: FetchLogOperator, - fetch_segment: FetchSegmentOperator, filter: FilterOperator, limit: LimitOperator, projection: ProjectionOperator, @@ -177,82 +157,42 @@ impl GetOrchestrator { blockfile_provider, dispatcher, queue, + collection_and_segments, fetch_log, - fetch_segment, - fetch_log_output: None, - fetch_segment_output: None, + fetched_logs: None, filter, limit, projection, result_channel: None, } } +} - pub async fn run(mut self, system: System) -> GetResult { - let (tx, rx) = oneshot::channel(); - self.result_channel = Some(tx); - let mut handle = system.start_component(self); - let result = rx.await; - handle.stop(); - result? 
- } - - fn terminate_with_error(&mut self, ctx: &ComponentContext, err: E) - where - E: Into, - { - let get_err = err.into(); - tracing::error!("Error running orchestrator: {}", &get_err); - terminate_with_error(self.result_channel.take(), get_err, ctx); - } +#[async_trait] +impl Orchestrator for GetOrchestrator { + type Output = GetOutput; + type Error = GetError; - /// Try to start the filter operator once both `FetchLogOperator` and `FetchSegmentOperator` completes - async fn try_start_filter_operator(&mut self, ctx: &ComponentContext) { - if let (Some(logs), Some(segments)) = ( - self.fetch_log_output.as_ref(), - self.fetch_segment_output.as_ref(), - ) { - let task = wrap( - Box::new(self.filter.clone()), - FilterInput { - logs: logs.clone(), - blockfile_provider: self.blockfile_provider.clone(), - metadata_segment: segments.metadata_segment.clone(), - record_segment: segments.record_segment.clone(), - }, - ctx.receiver(), - ); - if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - } - } + fn dispatcher(&self) -> ComponentHandle { + self.dispatcher.clone() } -} -#[async_trait] -impl Component for GetOrchestrator { - fn get_name() -> &'static str { - "Get Orchestrator" + fn initial_tasks(&self, ctx: &ComponentContext) -> Vec { + vec![wrap(Box::new(self.fetch_log.clone()), (), ctx.receiver())] } fn queue_size(&self) -> usize { self.queue } - async fn on_start(&mut self, ctx: &ComponentContext) { - let log_task = wrap(Box::new(self.fetch_log.clone()), (), ctx.receiver()); - let segment_task = wrap(Box::new(self.fetch_segment.clone()), (), ctx.receiver()); - if let Err(err) = self.dispatcher.send(log_task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - return; - } else if let Err(err) = self - .dispatcher - .send(segment_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - return; - } + fn set_result_channel(&mut self, sender: Sender) { + self.result_channel = Some(sender) + } + + fn take_result_channel(&mut self) -> Sender { + self.result_channel + .take() + .expect("The result channel should be set before take") } } @@ -265,36 +205,24 @@ impl Handler> for GetOrchestrator { message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; - self.fetch_log_output = Some(output); - self.try_start_filter_operator(ctx).await; - } -} -#[async_trait] -impl Handler> for GetOrchestrator { - type Result = (); + self.fetched_logs = Some(output.clone()); - async fn handle( - &mut self, - message: TaskResult, - ctx: &ComponentContext, - ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } - }; - self.fetch_segment_output = Some(output); - self.try_start_filter_operator(ctx).await; + let task = wrap( + Box::new(self.filter.clone()), + FilterInput { + logs: output, + blockfile_provider: self.blockfile_provider.clone(), + metadata_segment: self.collection_and_segments.metadata_segment.clone(), + record_segment: self.collection_and_segments.record_segment.clone(), + }, + ctx.receiver(), + ); + self.send(task, ctx).await; } } @@ -307,36 +235,26 @@ impl Handler> for GetOrchestrator { message: TaskResult, ctx: &ComponentContext, ) { - let output = match 
message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; let task = wrap( Box::new(self.limit.clone()), LimitInput { logs: self - .fetch_log_output + .fetched_logs .as_ref() .expect("FetchLogOperator should have finished already") .clone(), blockfile_provider: self.blockfile_provider.clone(), - record_segment: self - .fetch_segment_output - .as_ref() - .expect("FetchSegmentOperator should have finished already") - .record_segment - .clone(), + record_segment: self.collection_and_segments.record_segment.clone(), log_offset_ids: output.log_offset_ids, compact_offset_ids: output.compact_offset_ids, }, ctx.receiver(), ); - if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - } + self.send(task, ctx).await; } } @@ -349,64 +267,35 @@ impl Handler> for GetOrchestrator { message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, + }; + + let input = ProjectionInput { + logs: self + .fetched_logs + .as_ref() + .expect("FetchLogOperator should have finished already") + .clone(), + blockfile_provider: self.blockfile_provider.clone(), + record_segment: self.collection_and_segments.record_segment.clone(), + offset_ids: output.offset_ids.iter().collect(), }; // Prefetch records before projection let prefetch_task = wrap( Box::new(PrefetchRecordOperator {}), - PrefetchRecordInput { - logs: self - .fetch_log_output - .as_ref() - .expect("FetchLogOperator should have finished already") - .clone(), - blockfile_provider: self.blockfile_provider.clone(), - record_segment: self - .fetch_segment_output - .as_ref() - .expect("FetchSegmentOperator should have finished already") - .record_segment - .clone(), - offset_ids: output.offset_ids.iter().collect(), - }, + input.clone(), ctx.receiver(), ); - if let Err(err) = self - .dispatcher - .send(prefetch_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - } - let task = wrap( - Box::new(self.projection.clone()), - ProjectionInput { - logs: self - .fetch_log_output - .as_ref() - .expect("FetchLogOperator should have finished already") - .clone(), - blockfile_provider: self.blockfile_provider.clone(), - record_segment: self - .fetch_segment_output - .as_ref() - .expect("FetchSegmentOperator should have finished already") - .record_segment - .clone(), - offset_ids: output.offset_ids.into_iter().collect(), - }, - ctx.receiver(), - ); - if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); + if !self.send(prefetch_task, ctx).await { + return; } + + let task = wrap(Box::new(self.projection.clone()), input, ctx.receiver()); + self.send(task, ctx).await; } } @@ -432,17 +321,6 @@ impl Handler> for GetOrchestrator message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } - }; - if let Some(chan) = self.result_channel.take() { - if chan.send(Ok(output)).is_err() { - tracing::error!("Error sending final result"); - }; - } + self.terminate_with_result(message.into_inner().map_err(|e| 
e.into()), ctx); } } diff --git a/rust/worker/src/execution/orchestration/knn.rs b/rust/worker/src/execution/orchestration/knn.rs index dfc533c6690..2ffcd26ee1a 100644 --- a/rust/worker/src/execution/orchestration/knn.rs +++ b/rust/worker/src/execution/orchestration/knn.rs @@ -1,12 +1,11 @@ +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; -use tokio::sync::oneshot::{self, Sender}; -use tonic::async_trait; -use tracing::Span; +use tokio::sync::oneshot::Sender; use crate::{ execution::{ dispatcher::Dispatcher, - operator::{wrap, TaskResult}, + operator::{wrap, TaskMessage, TaskResult}, operators::{ knn::{KnnOperator, RecordDistance}, knn_hnsw::{KnnHnswError, KnnHnswInput, KnnHnswOutput}, @@ -20,15 +19,17 @@ use crate::{ PrefetchRecordOutput, }, }, - orchestration::common::terminate_with_error, }, - system::{Component, ComponentContext, ComponentHandle, Handler, System}, + system::{ComponentContext, ComponentHandle, Handler}, }; -use super::knn_filter::{KnnError, KnnFilterOutput, KnnResult}; +use super::{ + knn_filter::{KnnError, KnnFilterOutput, KnnOutput, KnnResult}, + orchestrator::Orchestrator, +}; -/// The `knn` module contains two orchestrator: `KnnFilterOrchestrator` and `KnnOrchestrator`. -/// When used together, they carry out the evaluation of a `.query(...)` query +/// The `KnnOrchestrator` finds the nearest neighbor of a target embedding given the search domain. +/// When used together with `KnnFilterOrchestrator`, they evaluate a `.query(...)` query /// for the user. We breakdown the evaluation into two parts because a `.query(...)` /// is inherently multiple queries sharing the same filter criteria. Thus we first evaluate /// the filter criteria with `KnnFilterOrchestrator`. Then we spawn a `KnnOrchestrator` for each @@ -38,53 +39,23 @@ use super::knn_filter::{KnnError, KnnFilterOutput, KnnResult}; /// /// # Pipeline /// ```text -/// │ -/// │ -/// │ -/// ┌──────────────────────────── │ ───────────────────────────────┐ -/// │ ▼ │ -/// │ ┌────────────┐ KnnFilterOrchestrator │ -/// │ │ │ │ -/// │ ┌───────────┤ on_start ├────────────────┐ │ -/// │ │ │ │ │ │ -/// │ │ └────────────┘ │ │ -/// │ │ │ │ -/// │ ▼ ▼ │ -/// │ ┌────────────────────┐ ┌────────────────────────┐ │ -/// │ │ │ │ │ │ -/// │ │ FetchLogOperator │ │ FetchSegmentOperator │ │ -/// │ │ │ │ │ │ -/// │ └────────┬───────────┘ └────────────────┬───────┘ │ -/// │ │ │ │ -/// │ │ │ │ -/// │ │ ┌─────────────────────────────┐ │ │ -/// │ │ │ │ │ │ -/// │ └────►│ try_start_filter_operator │◄────┘ │ -/// │ │ │ │ -/// │ └────────────┬────────────────┘ │ -/// │ │ │ -/// │ ▼ │ -/// │ ┌───────────────────┐ │ -/// │ │ │ │ -/// │ │ FilterOperator │ │ -/// │ │ │ │ -/// │ └─────────┬─────────┘ │ -/// │ │ │ -/// │ ▼ │ -/// │ ┌──────────────────┐ │ -/// │ │ │ │ -/// │ │ result_channel │ │ -/// │ │ │ │ -/// │ └────────┬─────────┘ │ -/// │ │ │ -/// └──────────────────────────── │ ───────────────────────────────┘ -/// │ -/// │ -/// │ -/// ┌──────────────────────────────────┴─────────────────────────────────────┐ -/// │ │ -/// │ ... One branch per embedding ... │ -/// │ │ +/// │ +/// │ +/// │ +/// │ +/// ▼ +/// ┌───────────────────────┐ +/// │ │ +/// │ KnnFilterOrchestrator │ +/// │ │ +/// └───────────┬───────────┘ +/// │ +/// │ +/// │ +/// ┌──────────────────────────────────┴─────────────────────────────────────┐ +/// │ │ +/// │ ... One branch per embedding ... 
│ +/// │ │ /// ┌──────────────────── │ ─────────────────────┐ ┌──────────────────── │ ─────────────────────┐ /// │ ▼ │ │ ▼ │ /// │ ┌────────────┐ KnnOrchestrator │ │ ┌────────────┐ KnnOrchestrator │ @@ -129,27 +100,18 @@ use super::knn_filter::{KnnError, KnnFilterOutput, KnnResult}; /// │ └────────┬─────────┘ │ │ └────────┬─────────┘ │ /// │ │ │ │ │ │ /// └──────────────────── │ ─────────────────────┘ └──────────────────── │ ─────────────────────┘ -/// │ │ -/// │ │ -/// │ │ -/// │ ┌────────────────┐ │ -/// │ │ │ │ -/// └──────────────────────────►│ try_join_all │◄──────────────────────────┘ -/// │ │ -/// └───────┬────────┘ -/// │ -/// │ -/// ▼ +/// │ │ +/// │ │ +/// │ │ +/// │ ┌────────────────┐ │ +/// │ │ │ │ +/// └──────────────────────────►│ try_join_all │◄──────────────────────────┘ +/// │ │ +/// └───────┬────────┘ +/// │ +/// │ +/// ▼ /// ``` -/// -/// # State tracking -/// Similar to the `GetOrchestrator`, the `KnnFilterOrchestrator` need to keep track of the outputs from -/// `FetchLogOperator` and `FetchSegmentOperator`. For `KnnOrchestrator`, it needs to track the outputs from -/// `KnnLogOperator` and `KnnHnswOperator`. It invokes `try_start_knn_merge_operator` when it receives outputs -/// from either operators, and if both outputs are present it composes the input for `KnnMergeOperator` and -/// proceeds with execution. The outputs of other operators are directly forwarded without being tracked -/// by the orchestrator. - #[derive(Debug)] pub struct KnnOrchestrator { // Orchestrator parameters @@ -185,6 +147,11 @@ impl KnnOrchestrator { knn_projection: KnnProjectionOperator, ) -> Self { let fetch = knn.fetch; + let knn_segment_distances = if knn_filter_output.hnsw_reader.is_none() { + Some(Vec::new()) + } else { + None + }; Self { blockfile_provider, dispatcher, @@ -192,31 +159,13 @@ impl KnnOrchestrator { knn_filter_output, knn, knn_log_distances: None, - knn_segment_distances: None, + knn_segment_distances, merge: KnnMergeOperator { fetch }, knn_projection, result_channel: None, } } - pub async fn run(mut self, system: System) -> KnnResult { - let (tx, rx) = oneshot::channel(); - self.result_channel = Some(tx); - let mut handle = system.start_component(self); - let result = rx.await; - handle.stop(); - result? 
- } - - fn terminate_with_error(&mut self, ctx: &ComponentContext, err: E) - where - E: Into, - { - let knn_err = err.into(); - tracing::error!("Error running orchestrator: {}", &knn_err); - terminate_with_error(self.result_channel.take(), knn_err, ctx); - } - async fn try_start_knn_merge_operator(&mut self, ctx: &ComponentContext) { if let (Some(log_distances), Some(segment_distances)) = ( self.knn_log_distances.as_ref(), @@ -230,24 +179,23 @@ impl KnnOrchestrator { }, ctx.receiver(), ); - if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - } + self.send(task, ctx).await; } } } #[async_trait] -impl Component for KnnOrchestrator { - fn get_name() -> &'static str { - "Knn Orchestrator" - } +impl Orchestrator for KnnOrchestrator { + type Output = KnnOutput; + type Error = KnnError; - fn queue_size(&self) -> usize { - self.queue + fn dispatcher(&self) -> ComponentHandle { + self.dispatcher.clone() } - async fn on_start(&mut self, ctx: &ComponentContext) { + fn initial_tasks(&self, ctx: &ComponentContext) -> Vec { + let mut tasks = Vec::new(); + let knn_log_task = wrap( Box::new(self.knn.clone()), KnnLogInput { @@ -259,14 +207,7 @@ impl Component for KnnOrchestrator { }, ctx.receiver(), ); - if let Err(err) = self - .dispatcher - .send(knn_log_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - return; - } + tasks.push(knn_log_task); if let Some(hnsw_reader) = self.knn_filter_output.hnsw_reader.as_ref().cloned() { let knn_segment_task = wrap( @@ -282,17 +223,24 @@ impl Component for KnnOrchestrator { }, ctx.receiver(), ); - - if let Err(err) = self - .dispatcher - .send(knn_segment_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - } - } else { - self.knn_segment_distances = Some(Vec::new()) + tasks.push(knn_segment_task); } + + tasks + } + + fn queue_size(&self) -> usize { + self.queue + } + + fn set_result_channel(&mut self, sender: Sender) { + self.result_channel = Some(sender) + } + + fn take_result_channel(&mut self) -> Sender { + self.result_channel + .take() + .expect("The result channel should be set before take") } } @@ -305,12 +253,9 @@ impl Handler> for KnnOrchestrator { message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; self.knn_log_distances = Some(output.record_distances); self.try_start_knn_merge_operator(ctx).await; @@ -326,12 +271,9 @@ impl Handler> for KnnOrchestrator { message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; self.knn_segment_distances = Some(output.record_distances); self.try_start_knn_merge_operator(ctx).await; @@ -366,13 +308,7 @@ impl Handler> for KnnOrchestrator { }, ctx.receiver(), ); - if let Err(err) = self - .dispatcher - .send(prefetch_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - } + self.send(prefetch_task, ctx).await; let projection_task = wrap( Box::new(self.knn_projection.clone()), @@ -384,13 +320,7 @@ impl Handler> for KnnOrchestrator { }, ctx.receiver(), ); - if let Err(err) = self - 
.dispatcher - .send(projection_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - } + self.send(projection_task, ctx).await; } } @@ -416,17 +346,6 @@ impl Handler> for KnnOrchest message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } - }; - if let Some(chan) = self.result_channel.take() { - if chan.send(Ok(output)).is_err() { - tracing::error!("Error sending final result"); - }; - } + self.terminate_with_result(message.into_inner().map_err(|e| e.into()), ctx); } } diff --git a/rust/worker/src/execution/orchestration/knn_filter.rs b/rust/worker/src/execution/orchestration/knn_filter.rs index f634e11c7a5..54632c3755b 100644 --- a/rust/worker/src/execution/orchestration/knn_filter.rs +++ b/rust/worker/src/execution/orchestration/knn_filter.rs @@ -1,20 +1,18 @@ +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_distance::DistanceFunction; use chroma_error::{ChromaError, ErrorCodes}; use chroma_index::hnsw_provider::HnswIndexProvider; -use chroma_types::Segment; +use chroma_types::{CollectionAndSegments, Segment}; use thiserror::Error; -use tokio::sync::oneshot::{self, error::RecvError, Sender}; -use tonic::async_trait; -use tracing::Span; +use tokio::sync::oneshot::{error::RecvError, Sender}; use crate::{ execution::{ dispatcher::Dispatcher, - operator::{wrap, TaskError, TaskResult}, + operator::{wrap, TaskError, TaskMessage, TaskResult}, operators::{ fetch_log::{FetchLogError, FetchLogOperator, FetchLogOutput}, - fetch_segment::{FetchSegmentError, FetchSegmentOperator, FetchSegmentOutput}, filter::{FilterError, FilterInput, FilterOperator, FilterOutput}, knn_hnsw::KnnHnswError, knn_log::KnnLogError, @@ -23,7 +21,6 @@ use crate::{ spann_centers_search::SpannCentersSearchError, spann_fetch_pl::SpannFetchPlError, }, - orchestration::common::terminate_with_error, }, segment::{ distributed_hnsw_segment::{ @@ -31,19 +28,17 @@ use crate::{ }, utils::distance_function_from_segment, }, - system::{ChannelError, Component, ComponentContext, ComponentHandle, Handler, System}, + system::{ChannelError, ComponentContext, ComponentHandle, Handler}, }; +use super::orchestrator::Orchestrator; + #[derive(Error, Debug)] pub enum KnnError { #[error("Error sending message through channel: {0}")] Channel(#[from] ChannelError), - #[error("Empty collection")] - EmptyCollection, #[error("Error running Fetch Log Operator: {0}")] FetchLog(#[from] FetchLogError), - #[error("Error running Fetch Segment Operator: {0}")] - FetchSegment(#[from] FetchSegmentError), #[error("Error running Filter Operator: {0}")] Filter(#[from] FilterError), #[error("Error creating hnsw segment reader: {0}")] @@ -74,9 +69,7 @@ impl ChromaError for KnnError { fn code(&self) -> ErrorCodes { match self { KnnError::Channel(e) => e.code(), - KnnError::EmptyCollection => ErrorCodes::Internal, KnnError::FetchLog(e) => e.code(), - KnnError::FetchSegment(e) => e.code(), KnnError::Filter(e) => e.code(), KnnError::HnswReader(e) => e.code(), KnnError::KnnLog(e) => e.code(), @@ -118,6 +111,38 @@ pub struct KnnFilterOutput { type KnnFilterResult = Result; +/// The `KnnFilterOrchestrator` chains a sequence of operators in sequence to evaluate +/// the first half of a `.query(...)` query from the user +/// +/// # Pipeline +/// ```text +/// ┌────────────┐ +/// │ │ +/// │ on_start │ +/// │ │ +/// └──────┬─────┘ +/// │ +/// ▼ +/// ┌────────────────────┐ +/// │ 
│ +/// │ FetchLogOperator │ +/// │ │ +/// └─────────┬──────────┘ +/// │ +/// ▼ +/// ┌───────────────────┐ +/// │ │ +/// │ FilterOperator │ +/// │ │ +/// └─────────┬─────────┘ +/// │ +/// ▼ +/// ┌──────────────────┐ +/// │ │ +/// │ result_channel │ +/// │ │ +/// └──────────────────┘ +/// ``` #[derive(Debug)] pub struct KnnFilterOrchestrator { // Orchestrator parameters @@ -126,13 +151,14 @@ pub struct KnnFilterOrchestrator { hnsw_provider: HnswIndexProvider, queue: usize, - // Fetch logs and segments + // Collection segments + collection_and_segments: CollectionAndSegments, + + // Fetch logs fetch_log: FetchLogOperator, - fetch_segment: FetchSegmentOperator, - // Fetch output - fetch_log_output: Option, - fetch_segment_output: Option, + // Fetched logs + fetched_logs: Option, // Pipelined operators filter: FilterOperator, @@ -147,8 +173,8 @@ impl KnnFilterOrchestrator { dispatcher: ComponentHandle, hnsw_provider: HnswIndexProvider, queue: usize, + collection_and_segments: CollectionAndSegments, fetch_log: FetchLogOperator, - fetch_segment: FetchSegmentOperator, filter: FilterOperator, ) -> Self { Self { @@ -156,127 +182,70 @@ impl KnnFilterOrchestrator { dispatcher, hnsw_provider, queue, + collection_and_segments, fetch_log, - fetch_segment, - fetch_log_output: None, - fetch_segment_output: None, + fetched_logs: None, filter, result_channel: None, } } +} - pub async fn run(mut self, system: System) -> KnnFilterResult { - let (tx, rx) = oneshot::channel(); - self.result_channel = Some(tx); - let mut handle = system.start_component(self); - let result = rx.await; - handle.stop(); - result? - } - - fn terminate_with_error(&mut self, ctx: &ComponentContext, err: E) - where - E: Into, - { - let knn_err = err.into(); - tracing::error!("Error running orchestrator: {}", &knn_err); - terminate_with_error(self.result_channel.take(), knn_err, ctx); - } +#[async_trait] +impl Orchestrator for KnnFilterOrchestrator { + type Output = KnnFilterOutput; + type Error = KnnError; - async fn try_start_filter_operator(&mut self, ctx: &ComponentContext) { - if let (Some(logs), Some(segments)) = ( - self.fetch_log_output.as_ref(), - self.fetch_segment_output.as_ref(), - ) { - let task = wrap( - Box::new(self.filter.clone()), - FilterInput { - logs: logs.clone(), - blockfile_provider: self.blockfile_provider.clone(), - metadata_segment: segments.metadata_segment.clone(), - record_segment: segments.record_segment.clone(), - }, - ctx.receiver(), - ); - if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - } - } + fn dispatcher(&self) -> ComponentHandle { + self.dispatcher.clone() } -} -#[async_trait] -impl Component for KnnFilterOrchestrator { - fn get_name() -> &'static str { - "Knn Filter Orchestrator" + fn initial_tasks(&self, ctx: &ComponentContext) -> Vec { + vec![wrap(Box::new(self.fetch_log.clone()), (), ctx.receiver())] } fn queue_size(&self) -> usize { self.queue } - async fn on_start(&mut self, ctx: &ComponentContext) { - let log_task = wrap(Box::new(self.fetch_log.clone()), (), ctx.receiver()); - let segment_task = wrap(Box::new(self.fetch_segment.clone()), (), ctx.receiver()); - if let Err(err) = self.dispatcher.send(log_task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - } else if let Err(err) = self - .dispatcher - .send(segment_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - } + fn set_result_channel(&mut self, sender: Sender) { + self.result_channel = Some(sender) } -} - 
-#[async_trait] -impl Handler> for KnnFilterOrchestrator { - type Result = (); - async fn handle( - &mut self, - message: TaskResult, - ctx: &ComponentContext, - ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } - }; - self.fetch_log_output = Some(output); - self.try_start_filter_operator(ctx).await; + fn take_result_channel(&mut self) -> Sender { + self.result_channel + .take() + .expect("The result channel should be set before take") } } #[async_trait] -impl Handler> for KnnFilterOrchestrator { +impl Handler> for KnnFilterOrchestrator { type Result = (); async fn handle( &mut self, - message: TaskResult, + message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; - // If dimension is not set and segment is uninitialized, we assume - // this is a query on empty collection, so we return early here - if output.collection.dimension.is_none() && output.vector_segment.file_path.is_empty() { - self.terminate_with_error(ctx, KnnError::EmptyCollection); - return; - } + self.fetched_logs = Some(output.clone()); - self.fetch_segment_output = Some(output); - self.try_start_filter_operator(ctx).await; + let task = wrap( + Box::new(self.filter.clone()), + FilterInput { + logs: output, + blockfile_provider: self.blockfile_provider.clone(), + metadata_segment: self.collection_and_segments.metadata_segment.clone(), + record_segment: self.collection_and_segments.record_segment.clone(), + }, + ctx.receiver(), + ); + self.send(task, ctx).await; } } @@ -289,33 +258,30 @@ impl Handler> for KnnFilterOrchestrator { message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; - let segments = self - .fetch_segment_output - .take() - .expect("FetchSegmentOperator should have finished already"); - let collection_dimension = match segments.collection.dimension { - Some(dimension) => dimension as u32, - None => { - self.terminate_with_error(ctx, KnnError::NoCollectionDimension); - return; - } + let collection_dimension = match self.ok_or_terminate( + self.collection_and_segments + .collection + .dimension + .ok_or(KnnError::NoCollectionDimension), + ctx, + ) { + Some(dim) => dim as u32, + None => return, }; - let distance_function = match distance_function_from_segment(&segments.vector_segment) { - Ok(distance_function) => distance_function, - Err(_) => { - self.terminate_with_error(ctx, KnnError::InvalidDistanceFunction); - return; - } + let distance_function = match self.ok_or_terminate( + distance_function_from_segment(&self.collection_and_segments.vector_segment) + .map_err(|_| KnnError::InvalidDistanceFunction), + ctx, + ) { + Some(distance_function) => distance_function, + None => return, }; let hnsw_reader = match DistributedHNSWSegmentReader::from_segment( - &segments.vector_segment, + &self.collection_and_segments.vector_segment, collection_dimension as usize, self.hnsw_provider.clone(), ) @@ -327,29 +293,23 @@ impl Handler> for KnnFilterOrchestrator { } Err(err) => { - self.terminate_with_error(ctx, *err); + 
self.terminate_with_result(Err((*err).into()), ctx); return; } }; - if let Some(chan) = self.result_channel.take() { - if chan - .send(Ok(KnnFilterOutput { - logs: self - .fetch_log_output - .take() - .expect("FetchLogOperator should have finished already"), - distance_function, - filter_output: output, - hnsw_reader, - record_segment: segments.record_segment, - vector_segment: segments.vector_segment, - dimension: collection_dimension as usize, - })) - .is_err() - { - tracing::error!("Error sending final result"); - }; - } + let output = KnnFilterOutput { + logs: self + .fetched_logs + .take() + .expect("FetchLogOperator should have finished already"), + distance_function, + filter_output: output, + hnsw_reader, + record_segment: self.collection_and_segments.record_segment.clone(), + vector_segment: self.collection_and_segments.vector_segment.clone(), + dimension: collection_dimension as usize, + }; + self.terminate_with_result(Ok(output), ctx); } } diff --git a/rust/worker/src/execution/orchestration/mod.rs b/rust/worker/src/execution/orchestration/mod.rs index d9b83d6e48a..58a9c0eb942 100644 --- a/rust/worker/src/execution/orchestration/mod.rs +++ b/rust/worker/src/execution/orchestration/mod.rs @@ -1,4 +1,3 @@ -mod common; mod compact; mod count; mod spann_knn; @@ -8,3 +7,4 @@ pub(crate) use count::*; pub mod get; pub mod knn; pub mod knn_filter; +pub mod orchestrator; diff --git a/rust/worker/src/execution/orchestration/orchestrator.rs b/rust/worker/src/execution/orchestration/orchestrator.rs new file mode 100644 index 00000000000..ddc96f93c26 --- /dev/null +++ b/rust/worker/src/execution/orchestration/orchestrator.rs @@ -0,0 +1,111 @@ +use core::fmt::Debug; +use std::any::type_name; + +use async_trait::async_trait; +use chroma_error::ChromaError; +use tokio::sync::oneshot::{self, error::RecvError, Sender}; +use tracing::Span; + +use crate::{ + execution::{dispatcher::Dispatcher, operator::TaskMessage}, + system::{ChannelError, Component, ComponentContext, ComponentHandle, System}, +}; + +#[async_trait] +pub trait Orchestrator: Debug + Send + Sized + 'static { + type Output: Send; + type Error: ChromaError + From + From; + + /// Returns the handle of the dispatcher + fn dispatcher(&self) -> ComponentHandle; + + /// Returns a vector of starting tasks that should be run in sequence + fn initial_tasks(&self, ctx: &ComponentContext) -> Vec; + + fn name() -> &'static str { + type_name::() + } + + fn queue_size(&self) -> usize { + 1000 + } + + /// Runs the orchestrator in a system and returns the result + async fn run(mut self, system: System) -> Result { + let (tx, rx) = oneshot::channel(); + self.set_result_channel(tx); + let mut handle = system.start_component(self); + let res = rx.await; + handle.stop(); + res? + } + + /// Sends a task to the dispatcher and return whether the task is successfully sent + async fn send(&mut self, task: TaskMessage, ctx: &ComponentContext) -> bool { + let res = self.dispatcher().send(task, Some(Span::current())).await; + self.ok_or_terminate(res, ctx).is_some() + } + + /// Sets the result channel of the orchestrator + fn set_result_channel(&mut self, sender: Sender>); + + /// Takes the result channel of the orchestrator. 
The channel should have been set when this is invoked + fn take_result_channel(&mut self) -> Sender>; + + /// Terminate the orchestrator with a result + fn terminate_with_result( + &mut self, + res: Result, + ctx: &ComponentContext, + ) { + let cancel = if let Err(err) = &res { + tracing::error!("Error running {}: {}", Self::name(), err); + true + } else { + false + }; + + let channel = self.take_result_channel(); + if channel.send(res).is_err() { + tracing::error!("Error sending result for {}", Self::name()); + }; + + if cancel { + ctx.cancellation_token.cancel(); + } + } + + /// Terminate the orchestrator if the result is an error. Returns the output if any. + fn ok_or_terminate>( + &mut self, + res: Result, + ctx: &ComponentContext, + ) -> Option { + match res { + Ok(output) => Some(output), + Err(error) => { + self.terminate_with_result(Err(error.into()), ctx); + None + } + } + } +} + +#[async_trait] +impl Component for O { + fn get_name() -> &'static str { + Self::name() + } + + fn queue_size(&self) -> usize { + self.queue_size() + } + + async fn start(&mut self, ctx: &ComponentContext) { + for task in self.initial_tasks(ctx) { + if !self.send(task, ctx).await { + break; + } + } + } +} diff --git a/rust/worker/src/execution/orchestration/spann_knn.rs b/rust/worker/src/execution/orchestration/spann_knn.rs index da4422be14b..a190055b800 100644 --- a/rust/worker/src/execution/orchestration/spann_knn.rs +++ b/rust/worker/src/execution/orchestration/spann_knn.rs @@ -1,14 +1,13 @@ +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_distance::{normalize, DistanceFunction}; use chroma_index::hnsw_provider::HnswIndexProvider; -use tokio::sync::oneshot::{self, Sender}; -use tonic::async_trait; -use tracing::Span; +use tokio::sync::oneshot::Sender; use crate::{ execution::{ dispatcher::Dispatcher, - operator::{wrap, TaskResult}, + operator::{wrap, TaskMessage, TaskResult}, operators::{ knn::{KnnOperator, RecordDistance}, knn_log::{KnnLogError, KnnLogInput, KnnLogOutput}, @@ -31,13 +30,15 @@ use crate::{ SpannKnnMergeError, SpannKnnMergeInput, SpannKnnMergeOperator, SpannKnnMergeOutput, }, }, - orchestration::common::terminate_with_error, }, segment::spann_segment::SpannSegmentReaderContext, - system::{Component, ComponentContext, ComponentHandle, Handler, System}, + system::{ComponentContext, ComponentHandle, Handler}, }; -use super::knn_filter::{KnnError, KnnFilterOutput, KnnResult}; +use super::{ + knn_filter::{KnnError, KnnFilterOutput, KnnOutput, KnnResult}, + orchestrator::Orchestrator, +}; // TODO(Sanket): Make these configurable. const RNG_FACTOR: f32 = 1.0; @@ -127,24 +128,6 @@ impl SpannKnnOrchestrator { } } - pub async fn run(mut self, system: System) -> KnnResult { - let (tx, rx) = oneshot::channel(); - self.result_channel = Some(tx); - let mut handle = system.start_component(self); - let result = rx.await; - handle.stop(); - result? - } - - fn terminate_with_error(&mut self, ctx: &ComponentContext, err: E) - where - E: Into, - { - let knn_err = err.into(); - tracing::error!("Error running orchestrator: {}", &knn_err); - terminate_with_error(self.result_channel.take(), knn_err, ctx); - } - async fn try_start_knn_merge_operator(&mut self, ctx: &ComponentContext) { if self.heads_searched && self.num_outstanding_bf_pl == 0 { // This is safe because self.records is only used once and that is during merge. 
@@ -155,24 +138,23 @@ impl SpannKnnOrchestrator { SpannKnnMergeInput { records }, ctx.receiver(), ); - if let Err(err) = self.dispatcher.send(task, Some(Span::current())).await { - self.terminate_with_error(ctx, err); - } + self.send(task, ctx).await; } } } #[async_trait] -impl Component for SpannKnnOrchestrator { - fn get_name() -> &'static str { - "Spann Knn Orchestrator" - } +impl Orchestrator for SpannKnnOrchestrator { + type Output = KnnOutput; + type Error = KnnError; - fn queue_size(&self) -> usize { - self.queue + fn dispatcher(&self) -> ComponentHandle { + self.dispatcher.clone() } - async fn on_start(&mut self, ctx: &ComponentContext) { + fn initial_tasks(&self, ctx: &ComponentContext) -> Vec { + let mut tasks = Vec::new(); + let knn_log_task = wrap( Box::new(self.log_knn.clone()), KnnLogInput { @@ -184,16 +166,8 @@ impl Component for SpannKnnOrchestrator { }, ctx.receiver(), ); - if let Err(err) = self - .dispatcher - .send(knn_log_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - return; - } + tasks.push(knn_log_task); - // Invoke Head search operator. let reader_context = SpannSegmentReaderContext { segment: self.knn_filter_output.vector_segment.clone(), blockfile_provider: self.blockfile_provider.clone(), @@ -212,14 +186,23 @@ impl Component for SpannKnnOrchestrator { }, ctx.receiver(), ); + tasks.push(head_search_task); - if let Err(err) = self - .dispatcher - .send(head_search_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - } + tasks + } + + fn queue_size(&self) -> usize { + self.queue + } + + fn set_result_channel(&mut self, sender: Sender) { + self.result_channel = Some(sender) + } + + fn take_result_channel(&mut self) -> Sender { + self.result_channel + .take() + .expect("The result channel should be set before take") } } @@ -232,12 +215,9 @@ impl Handler> for SpannKnnOrchestrator { message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; self.records.push(output.record_distances); self.try_start_knn_merge_operator(ctx).await; @@ -255,12 +235,9 @@ impl Handler> message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; // Set state that is used for tracking when we are ready for merging. self.heads_searched = true; @@ -283,13 +260,7 @@ impl Handler> ctx.receiver(), ); - if let Err(err) = self - .dispatcher - .send(fetch_pl_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - } + self.send(fetch_pl_task, ctx).await; } } } @@ -303,12 +274,9 @@ impl Handler> for SpannKnnOrch message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; // Spawn brute force posting list task. 
let bf_pl_task = wrap( @@ -327,13 +295,7 @@ impl Handler> for SpannKnnOrch ctx.receiver(), ); - if let Err(err) = self - .dispatcher - .send(bf_pl_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - } + self.send(bf_pl_task, ctx).await; } } @@ -346,12 +308,9 @@ impl Handler> for SpannKnnOrchestrat message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } + let output = match self.ok_or_terminate(message.into_inner(), ctx) { + Some(output) => output, + None => return, }; // Update state tracking for merging. self.num_outstanding_bf_pl -= 1; @@ -389,13 +348,7 @@ impl Handler> for SpannKnnOr }, ctx.receiver(), ); - if let Err(err) = self - .dispatcher - .send(prefetch_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - } + self.send(prefetch_task, ctx).await; let projection_task = wrap( Box::new(self.knn_projection.clone()), @@ -407,13 +360,7 @@ impl Handler> for SpannKnnOr }, ctx.receiver(), ); - if let Err(err) = self - .dispatcher - .send(projection_task, Some(Span::current())) - .await - { - self.terminate_with_error(ctx, err); - } + self.send(projection_task, ctx).await; } } @@ -439,17 +386,6 @@ impl Handler> for SpannKnnOr message: TaskResult, ctx: &ComponentContext, ) { - let output = match message.into_inner() { - Ok(output) => output, - Err(err) => { - self.terminate_with_error(ctx, err); - return; - } - }; - if let Some(chan) = self.result_channel.take() { - if chan.send(Ok(output)).is_err() { - tracing::error!("Error sending final result"); - }; - } + self.terminate_with_result(message.into_inner().map_err(|e| e.into()), ctx); } } diff --git a/rust/worker/src/execution/worker_thread.rs b/rust/worker/src/execution/worker_thread.rs index 9a968980247..da5c54d59e8 100644 --- a/rust/worker/src/execution/worker_thread.rs +++ b/rust/worker/src/execution/worker_thread.rs @@ -45,7 +45,7 @@ impl Component for WorkerThread { ComponentRuntime::Dedicated } - async fn on_start(&mut self, ctx: &ComponentContext) { + async fn start(&mut self, ctx: &ComponentContext) { let req = TaskRequestMessage::new(ctx.receiver()); let _req = self.dispatcher.send(req, None).await; // TODO: what to do with resp? 
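For reference, the refactor above replaces each orchestrator's hand-rolled `run`/`terminate_with_error` plumbing with the shared `Orchestrator` trait: a concrete orchestrator now only supplies its dispatcher handle, its initial tasks, and its result-channel accessors, while `run`, `send`, `ok_or_terminate`, and `terminate_with_result` come from the trait's default methods, and the blanket `Component` impl dispatches the initial tasks on `start`. The sketch below is a hypothetical, simplified illustration of that shape (it is not part of this change): the `Echo*` operator, output, and error types are made up and their definitions omitted, and the generic parameters on `ComponentHandle`, `ComponentContext`, `Sender`, and `TaskResult` are assumed from context rather than copied from the codebase.

// Hypothetical sketch only: a minimal orchestrator built on the new `Orchestrator` trait.
// `EchoOperator`, `EchoOutput`, `EchoError`, and `EchoOperatorError` are placeholder types
// whose definitions (and trait bounds such as `ChromaError`) are omitted here.
use async_trait::async_trait;
use tokio::sync::oneshot::Sender;

use crate::{
    execution::{
        dispatcher::Dispatcher,
        operator::{wrap, TaskMessage, TaskResult},
        orchestration::orchestrator::Orchestrator,
    },
    system::{ComponentContext, ComponentHandle, Handler},
};

#[derive(Debug)]
struct EchoOrchestrator {
    dispatcher: ComponentHandle<Dispatcher>,
    operator: EchoOperator, // hypothetical operator producing `EchoOutput`
    result_channel: Option<Sender<Result<EchoOutput, EchoError>>>,
}

#[async_trait]
impl Orchestrator for EchoOrchestrator {
    type Output = EchoOutput; // hypothetical output type
    type Error = EchoError;   // hypothetical error type satisfying the trait's bounds

    fn dispatcher(&self) -> ComponentHandle<Dispatcher> {
        self.dispatcher.clone()
    }

    // The blanket `Component` impl sends these tasks from `start`.
    fn initial_tasks(&self, ctx: &ComponentContext<Self>) -> Vec<TaskMessage> {
        vec![wrap(Box::new(self.operator.clone()), (), ctx.receiver())]
    }

    fn set_result_channel(&mut self, sender: Sender<Result<EchoOutput, EchoError>>) {
        self.result_channel = Some(sender)
    }

    fn take_result_channel(&mut self) -> Sender<Result<EchoOutput, EchoError>> {
        self.result_channel
            .take()
            .expect("The result channel should be set before take")
    }
}

#[async_trait]
impl Handler<TaskResult<EchoOutput, EchoOperatorError>> for EchoOrchestrator {
    type Result = ();

    async fn handle(
        &mut self,
        message: TaskResult<EchoOutput, EchoOperatorError>,
        ctx: &ComponentContext<Self>,
    ) {
        // Terminal operator: forward success or failure straight to the result channel,
        // mirroring the final handlers in the orchestrators changed above.
        self.terminate_with_result(message.into_inner().map_err(|e| e.into()), ctx);
    }
}

// Callers then drive it the same way the patched server code drives its orchestrators:
//     let output = echo_orchestrator.run(system).await?;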
diff --git a/rust/worker/src/lib.rs b/rust/worker/src/lib.rs index 17d5b1249d7..83e9c813d1d 100644 --- a/rust/worker/src/lib.rs +++ b/rust/worker/src/lib.rs @@ -1,10 +1,8 @@ mod assignment; mod compactor; -mod config; mod memberlist; mod server; mod sysdb; -mod system; mod tracing; mod utils; @@ -15,9 +13,11 @@ use tokio::select; use tokio::signal::unix::{signal, SignalKind}; // Required for benchmark +pub mod config; pub mod execution; pub mod log; pub mod segment; +pub mod system; const CONFIG_PATH_ENV_VAR: &str = "CONFIG_PATH"; diff --git a/rust/worker/src/log/log.rs b/rust/worker/src/log/log.rs index 9681a99b7ca..c4e5a9d5532 100644 --- a/rust/worker/src/log/log.rs +++ b/rust/worker/src/log/log.rs @@ -40,7 +40,7 @@ pub(crate) struct CollectionRecord { } #[derive(Clone, Debug)] -pub(crate) enum Log { +pub enum Log { Grpc(GrpcLog), #[allow(dead_code)] InMemory(InMemoryLog), @@ -95,7 +95,7 @@ impl Log { } #[derive(Clone, Debug)] -pub(crate) struct GrpcLog { +pub struct GrpcLog { #[allow(clippy::type_complexity)] client: LogServiceClient< interceptor::InterceptedService< @@ -329,7 +329,7 @@ impl ChromaError for UpdateCollectionLogOffsetError { // This is used for testing only, it represents a log record that is stored in memory // internal to a mock log implementation #[derive(Clone)] -pub(crate) struct InternalLogRecord { +pub struct InternalLogRecord { pub(crate) collection_id: CollectionUuid, pub(crate) log_offset: i64, pub(crate) log_ts: i64, @@ -349,13 +349,12 @@ impl Debug for InternalLogRecord { // This is used for testing only #[derive(Clone, Debug)] -pub(crate) struct InMemoryLog { +pub struct InMemoryLog { collection_to_log: HashMap>, offsets: HashMap, } impl InMemoryLog { - #[cfg(test)] pub fn new() -> InMemoryLog { InMemoryLog { collection_to_log: HashMap::new(), @@ -450,3 +449,9 @@ impl InMemoryLog { Ok(()) } } + +impl Default for InMemoryLog { + fn default() -> Self { + Self::new() + } +} diff --git a/rust/worker/src/log/mod.rs b/rust/worker/src/log/mod.rs index 3af94d9aed6..7a9c1047ce5 100644 --- a/rust/worker/src/log/mod.rs +++ b/rust/worker/src/log/mod.rs @@ -1,6 +1,6 @@ pub(crate) mod config; #[allow(clippy::module_inception)] -pub(crate) mod log; +pub mod log; #[allow(dead_code)] pub mod test; diff --git a/rust/worker/src/log/test.rs b/rust/worker/src/log/test.rs index ff7c12f8c0f..5c098cb27b0 100644 --- a/rust/worker/src/log/test.rs +++ b/rust/worker/src/log/test.rs @@ -8,35 +8,33 @@ use rand::{ pub const TEST_EMBEDDING_DIMENSION: usize = 6; -pub struct LogGenerator -where - G: Fn(usize) -> OperationRecord, -{ - pub generator: G, +pub trait LogGenerator { + fn generate_vec(&self, offsets: O) -> Vec + where + O: Iterator; + fn generate_chunk(&self, offsets: O) -> Chunk + where + O: Iterator, + { + Chunk::new(self.generate_vec(offsets).into()) + } } -impl LogGenerator +impl LogGenerator for G where G: Fn(usize) -> OperationRecord, { - pub fn generate_vec(&self, offsets: O) -> Vec + fn generate_vec(&self, offsets: O) -> Vec where O: Iterator, { offsets .map(|log_offset| LogRecord { log_offset: log_offset as i64, - record: (self.generator)(log_offset), + record: self(log_offset), }) .collect() } - - pub fn generate_chunk(&self, offsets: O) -> Chunk - where - O: Iterator, - { - Chunk::new(self.generate_vec(offsets).into()) - } } pub fn int_as_id(value: usize) -> String { diff --git a/rust/worker/src/memberlist/memberlist_provider.rs b/rust/worker/src/memberlist/memberlist_provider.rs index c8462a6cc82..0857ebc8f38 100644 --- a/rust/worker/src/memberlist/memberlist_provider.rs 
+++ b/rust/worker/src/memberlist/memberlist_provider.rs @@ -183,7 +183,7 @@ impl Component for CustomResourceMemberlistProvider { self.queue_size } - async fn on_start(&mut self, ctx: &ComponentContext) { + async fn start(&mut self, ctx: &ComponentContext) { self.connect_to_kube_stream(ctx); } } diff --git a/rust/worker/src/segment/distributed_hnsw_segment.rs b/rust/worker/src/segment/distributed_hnsw_segment.rs index e22807897f9..53e51e481d5 100644 --- a/rust/worker/src/segment/distributed_hnsw_segment.rs +++ b/rust/worker/src/segment/distributed_hnsw_segment.rs @@ -27,7 +27,7 @@ pub struct HnswIndexParamsFromSegment { } #[derive(Clone)] -pub(crate) struct DistributedHNSWSegmentWriter { +pub struct DistributedHNSWSegmentWriter { index: HnswIndexRef, hnsw_index_provider: HnswIndexProvider, pub(crate) id: SegmentUuid, @@ -86,7 +86,7 @@ impl DistributedHNSWSegmentWriter { } } - pub(crate) async fn from_segment( + pub async fn from_segment( segment: &Segment, dimensionality: usize, hnsw_index_provider: HnswIndexProvider, @@ -96,7 +96,6 @@ impl DistributedHNSWSegmentWriter { // ideally, an explicit state would be better. When we implement distributed HNSW segments, // we can introduce a state in the segment metadata for this if !segment.file_path.is_empty() { - println!("Loading HNSW index from files"); // Check if its in the providers cache, if not load the index from the files let index_id = match &segment.file_path.get(HNSW_INDEX) { None => { @@ -272,9 +271,9 @@ impl SegmentFlusher for DistributedHNSWSegmentWriter { } #[derive(Clone)] -pub(crate) struct DistributedHNSWSegmentReader { +pub struct DistributedHNSWSegmentReader { index: HnswIndexRef, - pub(crate) id: SegmentUuid, + pub id: SegmentUuid, } impl Debug for DistributedHNSWSegmentReader { @@ -300,7 +299,6 @@ impl DistributedHNSWSegmentReader { // ideally, an explicit state would be better. 
When we implement distributed HNSW segments, // we can introduce a state in the segment metadata for this if !segment.file_path.is_empty() { - println!("Loading HNSW index from files"); // Check if its in the providers cache, if not load the index from the files let index_id = match &segment.file_path.get(HNSW_INDEX) { None => { diff --git a/rust/worker/src/segment/mod.rs b/rust/worker/src/segment/mod.rs index 14920256194..a3d8b648e26 100644 --- a/rust/worker/src/segment/mod.rs +++ b/rust/worker/src/segment/mod.rs @@ -1,11 +1,11 @@ pub(crate) mod config; -pub(crate) mod distributed_hnsw_segment; pub mod test; pub(crate) mod utils; pub(crate) use types::*; // Required for benchmark +pub mod distributed_hnsw_segment; pub mod metadata_segment; pub mod record_segment; pub mod spann_segment; diff --git a/rust/worker/src/segment/record_segment.rs b/rust/worker/src/segment/record_segment.rs index f69ad5e7f04..3580b6eb374 100644 --- a/rust/worker/src/segment/record_segment.rs +++ b/rust/worker/src/segment/record_segment.rs @@ -11,9 +11,10 @@ use chroma_index::fulltext::types::FullTextIndexError; use chroma_types::{ Chunk, DataRecord, MaterializedLogOperation, Segment, SegmentType, SegmentUuid, }; -use std::cmp::Ordering; +use futures::{Stream, StreamExt}; use std::collections::HashMap; use std::fmt::{self, Debug, Formatter}; +use std::ops::RangeBounds; use std::sync::atomic::{self, AtomicU32}; use std::sync::Arc; use thiserror::Error; @@ -827,83 +828,33 @@ impl RecordSegmentReader<'_> { self.id_to_data.contains("", offset_id).await } - /// Returns all data in the record segment, sorted by - /// embedding id + /// Returns all data in the record segment, sorted by their offset ids #[allow(dead_code)] pub(crate) async fn get_all_data(&self) -> Result, Box> { - let mut data = Vec::new(); - let max_size = self.user_id_to_id.count().await?; - for i in 0..max_size { - let res = self.user_id_to_id.get_at_index(i).await; - match res { - Ok((_, _, offset_id)) => { - if let Some(data_record) = self.id_to_data.get("", offset_id).await? { - data.push(data_record); - } else { - return Err( - Box::new(RecordSegmentReaderCreationError::DataRecordNotFound( - offset_id, - )) as _, - ); - } - } - Err(e) => { - tracing::error!( - "[GetAllData] Error getting data record for index {:?}: {:?}", - i, - e - ); - return Err(e); - } - } - } - Ok(data) + self.id_to_data + .get_range(""..="", ..) 
+ .await + .map(|vec| vec.into_iter().map(|(_, data)| data).collect()) } - pub(crate) async fn get_offset_id_at_index( - &self, - index: usize, - ) -> Result> { - match self.id_to_user_id.get_at_index(index).await { - Ok((_, oid, _)) => Ok(oid), - Err(e) => { - tracing::error!( - "[GetAllData] Error getting offset id for index {}: {}", - index, - e - ); - Err(e) - } - } + /// Get a stream of offset ids from the smallest to the largest in the given range + pub(crate) fn get_offset_stream<'me>( + &'me self, + offset_range: impl RangeBounds + Clone + Send + 'me, + ) -> impl Stream>> + 'me { + self.id_to_user_id + .get_range_stream(""..="", offset_range) + .map(|res| res.map(|(offset_id, _)| offset_id)) } - // Find the rank of the given offset id in the record segment - // The implemention is based on std binary search + /// Find the rank of the given offset id in the record segment + /// The rank of an offset id is the number of offset ids strictly smaller than it + /// In other words, it is the position where the given offset id can be inserted without breaking the order pub(crate) async fn get_offset_id_rank( &self, target_oid: u32, ) -> Result> { - let mut size = self.count().await?; - if size == 0 { - return Ok(0); - } - let mut base = 0; - while size > 1 { - let half = size / 2; - let mid = base + half; - - let cmp = self.get_offset_id_at_index(mid).await?.cmp(&target_oid); - base = if cmp == Ordering::Greater { base } else { mid }; - size -= half; - } - - Ok( - match self.get_offset_id_at_index(base).await?.cmp(&target_oid) { - Ordering::Equal => base, - Ordering::Less => base + 1, - Ordering::Greater => base, - }, - ) + self.id_to_user_id.rank("", target_oid).await } pub(crate) async fn count(&self) -> Result> { @@ -948,17 +899,18 @@ mod tests { // The same record segment writer should be able to run concurrently on different threads without conflict #[test] fn test_max_offset_id_shuttle() { + let test_segment = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .expect("Runtime creation should not fail") + .block_on(async { TestSegment::default() }); shuttle::check_random( - || { + move || { let log_partition_size = 100; let stack_size = 1 << 22; let thread_count = 4; - let log_generator = LogGenerator { - generator: upsert_generator, - }; let max_log_offset = thread_count * log_partition_size; - let logs = log_generator.generate_vec(1..=max_log_offset); - let test_segment = TestSegment::default(); + let logs = upsert_generator.generate_vec(1..=max_log_offset); let batches = logs .chunks(log_partition_size) diff --git a/rust/worker/src/segment/spann_segment.rs b/rust/worker/src/segment/spann_segment.rs index be5e7cbf2a7..c4d207096e7 100644 --- a/rust/worker/src/segment/spann_segment.rs +++ b/rust/worker/src/segment/spann_segment.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_distance::DistanceFunctionError; use chroma_error::{ChromaError, ErrorCodes}; @@ -11,7 +12,6 @@ use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWri use chroma_types::SegmentUuid; use chroma_types::{MaterializedLogOperation, Segment, SegmentScope, SegmentType}; use thiserror::Error; -use tonic::async_trait; use uuid::Uuid; use super::{ diff --git a/rust/worker/src/segment/test.rs b/rust/worker/src/segment/test.rs index 25b9e8cbb90..6879058ae4e 100644 --- a/rust/worker/src/segment/test.rs +++ b/rust/worker/src/segment/test.rs @@ -1,20 +1,24 @@ use 
std::sync::atomic::AtomicU32; use chroma_blockstore::{provider::BlockfileProvider, test_arrow_blockfile_provider}; +use chroma_index::{hnsw_provider::HnswIndexProvider, test_hnsw_index_provider}; use chroma_types::{ - test_segment, Chunk, Collection, CollectionUuid, LogRecord, OperationRecord, Segment, + test_segment, Chunk, Collection, CollectionAndSegments, CollectionUuid, LogRecord, Segment, SegmentScope, }; use crate::log::test::{LogGenerator, TEST_EMBEDDING_DIMENSION}; use super::{ - materialize_logs, metadata_segment::MetadataSegmentWriter, record_segment::RecordSegmentWriter, - SegmentFlusher, SegmentWriter, + distributed_hnsw_segment::DistributedHNSWSegmentWriter, materialize_logs, + metadata_segment::MetadataSegmentWriter, record_segment::RecordSegmentWriter, SegmentFlusher, + SegmentWriter, }; +#[derive(Clone)] pub struct TestSegment { pub blockfile_provider: BlockfileProvider, + pub hnsw_provider: HnswIndexProvider, pub collection: Collection, pub metadata_segment: Segment, pub record_segment: Segment, @@ -22,8 +26,30 @@ pub struct TestSegment { } impl TestSegment { + pub fn new_with_dimension(dimension: usize) -> Self { + let collection_uuid = CollectionUuid::new(); + let collection = Collection { + collection_id: collection_uuid, + name: "Test Collection".to_string(), + metadata: None, + dimension: Some(dimension as i32), + tenant: "Test Tenant".to_string(), + database: String::new(), + log_position: 0, + version: 0, + }; + Self { + blockfile_provider: test_arrow_blockfile_provider(2 << 22), + hnsw_provider: test_hnsw_index_provider(), + collection, + metadata_segment: test_segment(collection_uuid, SegmentScope::METADATA), + record_segment: test_segment(collection_uuid, SegmentScope::RECORD), + vector_segment: test_segment(collection_uuid, SegmentScope::VECTOR), + } + } + // WARN: The size of the log chunk should not be too large - async fn compact_log(&mut self, logs: Chunk, next_offset: usize) { + pub async fn compact_log(&mut self, logs: Chunk, next_offset: usize) { let materialized_logs = materialize_logs( &None, &logs, @@ -57,22 +83,45 @@ impl TestSegment { .await .expect("Should be able to initiaize record writer."); record_writer - .apply_materialized_log_chunk(materialized_logs) + .apply_materialized_log_chunk(materialized_logs.clone()) .await .expect("Should be able to apply materialized log."); self.record_segment.file_path = record_writer .commit() .await - .expect("Should be able to commit metadata.") + .expect("Should be able to commit record.") .flush() .await - .expect("Should be able to flush metadata."); + .expect("Should be able to flush record."); + + let vector_writer = DistributedHNSWSegmentWriter::from_segment( + &self.vector_segment, + self.collection + .dimension + .expect("Collection dimension should be set") as usize, + self.hnsw_provider.clone(), + ) + .await + .expect("Should be able to initialize vector writer"); + + vector_writer + .apply_materialized_log_chunk(materialized_logs) + .await + .expect("Should be able to apply materialized log."); + + self.vector_segment.file_path = vector_writer + .commit() + .await + .expect("Should be able to commit vector.") + .flush() + .await + .expect("Should be able to flush vector."); } - pub async fn populate_with_generator(&mut self, size: usize, generator: &LogGenerator) + pub async fn populate_with_generator(&mut self, size: usize, generator: G) where - G: Fn(usize) -> OperationRecord, + G: LogGenerator, { let ids: Vec<_> = (1..=size).collect(); for chunk in ids.chunks(100) { @@ -90,23 +139,17 @@ impl 
TestSegment { impl Default for TestSegment { fn default() -> Self { - let collection_uuid = CollectionUuid::new(); - let collection = Collection { - collection_id: collection_uuid, - name: "Test Collection".to_string(), - metadata: None, - dimension: Some(TEST_EMBEDDING_DIMENSION as i32), - tenant: "Test Tenant".to_string(), - database: String::new(), - log_position: 0, - version: 0, - }; + Self::new_with_dimension(TEST_EMBEDDING_DIMENSION) + } +} + +impl From for CollectionAndSegments { + fn from(value: TestSegment) -> Self { Self { - blockfile_provider: test_arrow_blockfile_provider(2 << 22), - collection, - metadata_segment: test_segment(collection_uuid, SegmentScope::METADATA), - record_segment: test_segment(collection_uuid, SegmentScope::RECORD), - vector_segment: test_segment(collection_uuid, SegmentScope::VECTOR), + collection: value.collection, + metadata_segment: value.metadata_segment, + record_segment: value.record_segment, + vector_segment: value.vector_segment, } } } diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index bda9641b300..b58dec8b018 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -1,5 +1,6 @@ -use std::{iter::once, str::FromStr}; +use std::iter::once; +use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_config::Configurable; use chroma_error::ChromaError; @@ -9,23 +10,22 @@ use chroma_types::{ self, query_executor_server::QueryExecutor, CountPlan, CountResult, GetPlan, GetResult, KnnBatchResult, KnnPlan, }, - CollectionUuid, SegmentUuid, + CollectionAndSegments, }; use futures::{stream, StreamExt, TryStreamExt}; use tokio::signal::unix::{signal, SignalKind}; use tonic::{transport::Server, Request, Response, Status}; use tracing::{trace_span, Instrument}; -use uuid::Uuid; use crate::{ config::QueryServiceConfig, execution::{ dispatcher::Dispatcher, - operators::{ - fetch_log::FetchLogOperator, fetch_segment::FetchSegmentOperator, - knn_projection::KnnProjectionOperator, + operators::{fetch_log::FetchLogOperator, knn_projection::KnnProjectionOperator}, + orchestration::{ + get::GetOrchestrator, knn::KnnOrchestrator, knn_filter::KnnFilterOrchestrator, + orchestrator::Orchestrator, CountOrchestrator, }, - orchestration::{get::GetOrchestrator, knn::KnnOrchestrator, knn_filter::{KnnError, KnnFilterOrchestrator}, CountQueryOrchestrator}, }, log::log::Log, sysdb::sysdb::SysDb, @@ -42,13 +42,13 @@ pub struct WorkerServer { dispatcher: Option>, // Service dependencies log: Box, - sysdb: Box, + _sysdb: Box, hnsw_index_provider: HnswIndexProvider, blockfile_provider: BlockfileProvider, port: u16, } -#[async_trait::async_trait] +#[async_trait] impl Configurable for WorkerServer { async fn try_from_config(config: &QueryServiceConfig) -> Result> { let sysdb_config = &config.sysdb; @@ -86,7 +86,7 @@ impl Configurable for WorkerServer { Ok(WorkerServer { dispatcher: None, system: None, - sysdb, + _sysdb: sysdb, log, hnsw_index_provider, blockfile_provider, @@ -132,46 +132,17 @@ impl WorkerServer { self.system = Some(system); } - fn decompose_proto_scan( - &self, - scan: chroma_proto::ScanOperator, - ) -> Result<(FetchLogOperator, FetchSegmentOperator), Status> { - let collection = scan - .collection - .ok_or(Status::invalid_argument("Invalid Collection"))?; - - let collection_uuid = CollectionUuid::from_str(&collection.id) - .map_err(|_| Status::invalid_argument("Invalid Collection UUID"))?; - - let vector_uuid = SegmentUuid::from_str(&scan.knn_id) - .map_err(|_| Status::invalid_argument("Invalid 
diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs
index bda9641b300..b58dec8b018 100644
--- a/rust/worker/src/server.rs
+++ b/rust/worker/src/server.rs
@@ -1,5 +1,6 @@
-use std::{iter::once, str::FromStr};
+use std::iter::once;
 
+use async_trait::async_trait;
 use chroma_blockstore::provider::BlockfileProvider;
 use chroma_config::Configurable;
 use chroma_error::ChromaError;
@@ -9,23 +10,22 @@ use chroma_types::{
         self, query_executor_server::QueryExecutor, CountPlan, CountResult, GetPlan, GetResult,
         KnnBatchResult, KnnPlan,
     },
-    CollectionUuid, SegmentUuid,
+    CollectionAndSegments,
 };
 use futures::{stream, StreamExt, TryStreamExt};
 use tokio::signal::unix::{signal, SignalKind};
 use tonic::{transport::Server, Request, Response, Status};
 use tracing::{trace_span, Instrument};
-use uuid::Uuid;
 
 use crate::{
     config::QueryServiceConfig,
     execution::{
         dispatcher::Dispatcher,
-        operators::{
-            fetch_log::FetchLogOperator, fetch_segment::FetchSegmentOperator,
-            knn_projection::KnnProjectionOperator,
+        operators::{fetch_log::FetchLogOperator, knn_projection::KnnProjectionOperator},
+        orchestration::{
+            get::GetOrchestrator, knn::KnnOrchestrator, knn_filter::KnnFilterOrchestrator,
+            orchestrator::Orchestrator, CountOrchestrator,
         },
-        orchestration::{get::GetOrchestrator, knn::KnnOrchestrator, knn_filter::{KnnError, KnnFilterOrchestrator}, CountQueryOrchestrator},
     },
     log::log::Log,
     sysdb::sysdb::SysDb,
@@ -42,13 +42,13 @@ pub struct WorkerServer {
     dispatcher: Option<ComponentHandle<Dispatcher>>,
     // Service dependencies
     log: Box<Log>,
-    sysdb: Box<SysDb>,
+    _sysdb: Box<SysDb>,
     hnsw_index_provider: HnswIndexProvider,
     blockfile_provider: BlockfileProvider,
     port: u16,
 }
 
-#[async_trait::async_trait]
+#[async_trait]
 impl Configurable<QueryServiceConfig> for WorkerServer {
     async fn try_from_config(config: &QueryServiceConfig) -> Result<Self, Box<dyn ChromaError>> {
         let sysdb_config = &config.sysdb;
@@ -86,7 +86,7 @@ impl Configurable<QueryServiceConfig> for WorkerServer {
         Ok(WorkerServer {
             dispatcher: None,
             system: None,
-            sysdb,
+            _sysdb: sysdb,
             log,
             hnsw_index_provider,
             blockfile_provider,
@@ -132,46 +132,17 @@ impl WorkerServer {
         self.system = Some(system);
     }
 
-    fn decompose_proto_scan(
-        &self,
-        scan: chroma_proto::ScanOperator,
-    ) -> Result<(FetchLogOperator, FetchSegmentOperator), Status> {
-        let collection = scan
-            .collection
-            .ok_or(Status::invalid_argument("Invalid Collection"))?;
-
-        let collection_uuid = CollectionUuid::from_str(&collection.id)
-            .map_err(|_| Status::invalid_argument("Invalid Collection UUID"))?;
-
-        let vector_uuid = SegmentUuid::from_str(&scan.knn_id)
-            .map_err(|_| Status::invalid_argument("Invalid UUID for Vector segment"))?;
-
-        let metadata_uuid = SegmentUuid::from_str(&scan.metadata_id)
-            .map_err(|_| Status::invalid_argument("Invalid UUID for Metadata segment"))?;
-
-        let record_uuid = SegmentUuid::from_str(&scan.record_id)
-            .map_err(|_| Status::invalid_argument("Invalid UUID for Record segment"))?;
-
-        Ok((
-            FetchLogOperator {
-                log_client: self.log.clone(),
-                // TODO: Make this configurable
-                batch_size: 100,
-                // The collection log position is inclusive, and we want to start from the next log
-                // Note that we query using the incoming log position this is critical for correctness
-                start_log_offset_id: collection.log_position as u32 + 1,
-                maximum_fetch_count: None,
-                collection_uuid,
-            },
-            FetchSegmentOperator {
-                sysdb: self.sysdb.clone(),
-                collection_uuid,
-                collection_version: collection.version as u32,
-                metadata_uuid,
-                record_uuid,
-                vector_uuid,
-            },
-        ))
+    fn fetch_log(&self, collection_and_segments: &CollectionAndSegments) -> FetchLogOperator {
+        FetchLogOperator {
+            log_client: self.log.clone(),
+            // TODO: Make this configurable
+            batch_size: 100,
+            // The collection log position is inclusive, and we want to start from the next log
+            // Note that we query using the incoming log position this is critical for correctness
+            start_log_offset_id: collection_and_segments.collection.log_position as u32 + 1,
+            maximum_fetch_count: None,
+            collection_uuid: collection_and_segments.collection.collection_id,
+        }
     }
 
     async fn orchestrate_count(
@@ -183,25 +154,19 @@ impl WorkerServer {
             .scan
             .ok_or(Status::invalid_argument("Invalid Scan Operator"))?;
 
-        let collection = &scan
-            .collection
-            .ok_or(Status::invalid_argument("Invalid collection"))?;
-
-        let count_orchestrator = CountQueryOrchestrator::new(
-            self.clone_system()?,
-            &Uuid::parse_str(&scan.metadata_id)
-                .map_err(|e| Status::invalid_argument(e.to_string()))?,
-            &CollectionUuid::from_str(&collection.id)
-                .map_err(|e| Status::invalid_argument(e.to_string()))?,
-            self.log.clone(),
-            self.sysdb.clone(),
-            self.clone_dispatcher()?,
+        let collection_and_segments = scan.try_into()?;
+        let fetch_log = self.fetch_log(&collection_and_segments);
+
+        let count_orchestrator = CountOrchestrator::new(
             self.blockfile_provider.clone(),
-            collection.version as u32,
-            collection.log_position as u64,
+            self.clone_dispatcher()?,
+            // TODO: Make this configurable
+            1000,
+            collection_and_segments,
+            fetch_log,
         );
 
-        match count_orchestrator.run().await {
+        match count_orchestrator.run(self.clone_system()?).await {
             Ok(count) => Ok(Response::new(CountResult {
                 count: count as u32,
             })),
@@ -215,7 +180,8 @@ impl WorkerServer {
             .scan
             .ok_or(Status::invalid_argument("Invalid Scan Operator"))?;
 
-        let (fetch_log_operator, fetch_segment_operator) = self.decompose_proto_scan(scan)?;
+        let collection_and_segments = scan.try_into()?;
+        let fetch_log = self.fetch_log(&collection_and_segments);
 
         let filter = get_inner
             .filter
@@ -234,8 +200,8 @@ impl WorkerServer {
             self.clone_dispatcher()?,
             // TODO: Make this configurable
             1000,
-            fetch_log_operator,
-            fetch_segment_operator,
+            collection_and_segments,
+            fetch_log,
             filter.try_into()?,
             limit.into(),
             projection.into(),
@@ -260,7 +226,9 @@ impl WorkerServer {
             .scan
             .ok_or(Status::invalid_argument("Invalid Scan Operator"))?;
 
-        let (fetch_log_operator, fetch_segment_operator) = self.decompose_proto_scan(scan)?;
+        let collection_and_segments = scan.try_into()?;
+
+        let fetch_log = self.fetch_log(&collection_and_segments);
 
         let filter = knn_inner
             .filter
@@ -280,27 +248,32 @@ impl WorkerServer {
             return Ok(Response::new(to_proto_knn_batch_result(Vec::new())?));
         }
 
+        // If dimension is not set and segment is uninitialized, we assume
+        // this is a query on empty collection, so we return early here
+        if collection_and_segments.collection.dimension.is_none()
+            && collection_and_segments.vector_segment.file_path.is_empty()
+        {
+            return Ok(Response::new(to_proto_knn_batch_result(
+                once(Default::default())
+                    .cycle()
+                    .take(knn.embeddings.len())
+                    .collect(),
+            )?));
+        }
+
         let knn_filter_orchestrator = KnnFilterOrchestrator::new(
             self.blockfile_provider.clone(),
             dispatcher.clone(),
             self.hnsw_index_provider.clone(),
             // TODO: Make this configurable
             1000,
-            fetch_log_operator,
-            fetch_segment_operator,
+            collection_and_segments,
+            fetch_log,
             filter.try_into()?,
         );
 
         let matching_records = match knn_filter_orchestrator.run(system.clone()).await {
             Ok(output) => output,
-            Err(KnnError::EmptyCollection) => {
-                return Ok(Response::new(to_proto_knn_batch_result(
-                    once(Default::default())
-                        .cycle()
-                        .take(knn.embeddings.len())
-                        .collect(),
-                )?));
-            }
             Err(e) => {
                 return Err(Status::new(e.code().into(), e.to_string()));
             }
@@ -346,7 +319,7 @@ impl WorkerServer {
     }
 }
 
-#[tonic::async_trait]
+#[async_trait]
 impl QueryExecutor for WorkerServer {
     async fn count(&self, count: Request<CountPlan>) -> Result<Response<CountResult>, Status> {
         // Note: We cannot write a middleware that instruments every service rpc
@@ -389,7 +362,7 @@ impl QueryExecutor for WorkerServer {
 }
 
 #[cfg(debug_assertions)]
-#[tonic::async_trait]
+#[async_trait]
 impl chroma_proto::debug_server::Debug for WorkerServer {
     async fn get_info(
         &self,
@@ -422,6 +395,8 @@ impl chroma_proto::debug_server::Debug for WorkerServer {
 
 #[cfg(test)]
 mod tests {
+    use std::collections::HashMap;
+
     use super::*;
     use crate::execution::dispatcher;
     use crate::log::log::InMemoryLog;
@@ -443,7 +418,7 @@ mod tests {
         let mut server = WorkerServer {
             dispatcher: None,
             system: None,
-            sysdb: Box::new(SysDb::Test(sysdb)),
+            _sysdb: Box::new(SysDb::Test(sysdb)),
             log: Box::new(Log::InMemory(log)),
             hnsw_index_provider: test_hnsw_index_provider(),
             blockfile_provider: segments.blockfile_provider,
@@ -465,21 +440,43 @@ mod tests {
     }
 
     fn scan() -> chroma_proto::ScanOperator {
+        let collection_id = Uuid::new_v4().to_string();
         chroma_proto::ScanOperator {
             collection: Some(chroma_proto::Collection {
-                id: Uuid::new_v4().to_string(),
-                name: "Test-Collection".to_string(),
+                id: collection_id.clone(),
+                name: "test-collection".to_string(),
                 configuration_json_str: String::new(),
                 metadata: None,
                 dimension: None,
-                tenant: "Test-Tenant".to_string(),
-                database: "Test-Database".to_string(),
+                tenant: "test-tenant".to_string(),
+                database: "test-database".to_string(),
                 log_position: 0,
                 version: 0,
             }),
-            knn_id: Uuid::new_v4().to_string(),
-            metadata_id: Uuid::new_v4().to_string(),
-            record_id: Uuid::new_v4().to_string(),
+            knn: Some(chroma_proto::Segment {
+                id: Uuid::new_v4().to_string(),
+                r#type: "urn:chroma:segment/vector/hnsw-distributed".to_string(),
+                scope: 0,
+                collection: collection_id.clone(),
+                metadata: None,
+                file_paths: HashMap::new(),
+            }),
+            metadata: Some(chroma_proto::Segment {
+                id: Uuid::new_v4().to_string(),
+                r#type: "urn:chroma:segment/metadata/blockfile".to_string(),
+                scope: 1,
+                collection: collection_id.clone(),
+                metadata: None,
+                file_paths: HashMap::new(),
+            }),
+            record: Some(chroma_proto::Segment {
+                id: Uuid::new_v4().to_string(),
+                r#type: "urn:chroma:segment/record/blockfile".to_string(),
+                scope: 2,
+                collection: collection_id.clone(),
+                metadata: None,
+                file_paths: HashMap::new(),
+            }),
         }
     }
@@ -501,15 +498,19 @@ mod tests {
     async fn validate_count_plan() {
         let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap();
         let mut scan_operator = scan();
-        let request = chroma_proto::CountPlan {
-            scan: Some(scan_operator.clone()),
-        };
-
-        // segment or collection not found
-        let response = executor.count(request).await;
-        assert_eq!(response.unwrap_err().code(), tonic::Code::NotFound);
-
-        scan_operator.metadata_id = "invalid_segment_id".to_string();
+        scan_operator.metadata = Some(chroma_proto::Segment {
+            id: "invalid-metadata-segment-id".to_string(),
+            r#type: "urn:chroma:segment/metadata/blockfile".to_string(),
+            scope: 1,
+            collection: scan_operator
+                .collection
+                .as_ref()
+                .expect("The collection should exist")
+                .id
+                .clone(),
+            metadata: None,
+            file_paths: HashMap::new(),
+        });
         let request = chroma_proto::CountPlan {
             scan: Some(scan_operator.clone()),
         };
@@ -524,29 +525,6 @@ mod tests {
     async fn validate_get_plan() {
         let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap();
         let mut scan_operator = scan();
-        let request = chroma_proto::GetPlan {
-            scan: Some(scan_operator.clone()),
-            filter: Some(chroma_proto::FilterOperator {
-                ids: None,
-                r#where: None,
-                where_document: None,
-            }),
-            limit: Some(chroma_proto::LimitOperator {
-                skip: 0,
-                fetch: None,
-            }),
-            projection: Some(chroma_proto::ProjectionOperator {
-                document: false,
-                embedding: false,
-                metadata: false,
-            }),
-        };
-
-        // segment or collection not found
-        let response = executor.get(request.clone()).await;
-        assert!(response.is_err());
-        assert_eq!(response.unwrap_err().code(), tonic::Code::NotFound);
-
         let request = chroma_proto::GetPlan {
             scan: Some(scan_operator.clone()),
             filter: None,
@@ -567,13 +545,13 @@ mod tests {
         assert_eq!(response.unwrap_err().code(), tonic::Code::InvalidArgument);
 
         scan_operator.collection = Some(chroma_proto::Collection {
-            id: "Invalid-Collection-ID".to_string(),
-            name: "Broken-Collection".to_string(),
+            id: "invalid-collection-iD".to_string(),
+            name: "broken-collection".to_string(),
             configuration_json_str: String::new(),
             metadata: None,
             dimension: None,
-            tenant: "Test-Tenant".to_string(),
-            database: "Test-Database".to_string(),
+            tenant: "test-tenant".to_string(),
+            database: "test-database".to_string(),
             log_position: 0,
             version: 0,
         });
@@ -601,7 +579,9 @@ mod tests {
         assert_eq!(response.unwrap_err().code(), tonic::Code::InvalidArgument);
     }
 
-    fn gen_knn_request(mut scan_operator: Option<chroma_proto::ScanOperator>) -> chroma_proto::KnnPlan {
+    fn gen_knn_request(
+        mut scan_operator: Option<chroma_proto::ScanOperator>,
+    ) -> chroma_proto::KnnPlan {
         if scan_operator.is_none() {
             scan_operator = Some(scan());
         }
@@ -716,16 +696,11 @@ mod tests {
     async fn validate_knn_plan_scan_collection() {
         let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap();
         let mut scan = scan();
-        scan.collection.as_mut().unwrap().id = "Invalid-Collection-ID".to_string();
+        scan.collection.as_mut().unwrap().id = "invalid-collection-id".to_string();
         let response = executor.knn(gen_knn_request(Some(scan))).await;
         assert!(response.is_err());
         let err = response.unwrap_err();
         assert_eq!(err.code(), tonic::Code::InvalidArgument);
-        assert!(
-            err.message().to_lowercase().contains("collection uuid"),
-            "{}",
-            err.message()
-        );
     }
 
     #[tokio::test]
@@ -733,47 +708,68 @@ mod tests {
         let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap();
         // invalid vector uuid
        let mut scan_operator = scan();
-        scan_operator.knn_id = "invalid_segment_id".to_string();
+        scan_operator.knn = Some(chroma_proto::Segment {
+            id: "invalid-knn-segment-id".to_string(),
+            r#type: "urn:chroma:segment/vector/hnsw-distributed".to_string(),
+            scope: 0,
+            collection: scan_operator
+                .collection
+                .as_ref()
+                .expect("The collection should exist")
+                .id
+                .clone(),
+            metadata: None,
+            file_paths: HashMap::new(),
+        });
         let response = executor.knn(gen_knn_request(Some(scan_operator))).await;
         assert!(response.is_err());
         let err = response.unwrap_err();
         assert_eq!(err.code(), tonic::Code::InvalidArgument);
-        assert!(
-            err.message().to_lowercase().contains("vector"),
-            "{}",
-            err.message()
-        );
     }
 
     #[tokio::test]
     async fn validate_knn_plan_scan_record() {
         let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap();
         let mut scan_operator = scan();
-        scan_operator.record_id = "invalid_record_id".to_string();
+        scan_operator.record = Some(chroma_proto::Segment {
+            id: "invalid-record-segment-id".to_string(),
+            r#type: "urn:chroma:segment/record/blockfile".to_string(),
+            scope: 2,
+            collection: scan_operator
+                .collection
+                .as_ref()
+                .expect("The collection should exist")
+                .id
+                .clone(),
+            metadata: None,
+            file_paths: HashMap::new(),
+        });
         let response = executor.knn(gen_knn_request(Some(scan_operator))).await;
         assert!(response.is_err());
         let err = response.unwrap_err();
         assert_eq!(err.code(), tonic::Code::InvalidArgument);
-        assert!(
-            err.message().to_lowercase().contains("record"),
-            "{}",
-            err.message()
-        );
     }
 
     #[tokio::test]
     async fn validate_knn_plan_scan_metadata() {
         let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap();
         let mut scan_operator = scan();
-        scan_operator.metadata_id = "invalid_metadata_id".to_string();
+        scan_operator.metadata = Some(chroma_proto::Segment {
+            id: "invalid-metadata-segment-id".to_string(),
+            r#type: "urn:chroma:segment/metadata/blockfile".to_string(),
+            scope: 1,
+            collection: scan_operator
+                .collection
+                .as_ref()
+                .expect("The collection should exist")
+                .id
+                .clone(),
+            metadata: None,
+            file_paths: HashMap::new(),
+        });
         let response = executor.knn(gen_knn_request(Some(scan_operator))).await;
         assert!(response.is_err());
         let err = response.unwrap_err();
         assert_eq!(err.code(), tonic::Code::InvalidArgument);
-        assert!(
-            err.message().to_lowercase().contains("metadata"),
-            "{}",
-            err.message()
-        );
     }
 }
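// Illustrative sketch (not part of the patch): the server code above relies on
// `scan.try_into()?` turning a `chroma_proto::ScanOperator` into a `CollectionAndSegments`,
// with failures surfacing as `Status::invalid_argument` (that is what the tests above
// assert). The real conversion lives in chroma_types and is not shown in this diff; the
// shape below is an assumption for readers, not the author's code, and it presumes the
// existing proto-to-internal conversions for `Collection` and `Segment`.
use chroma_types::{chroma_proto, CollectionAndSegments};
use tonic::Status;

impl TryFrom<chroma_proto::ScanOperator> for CollectionAndSegments {
    type Error = Status; // assumed; the real error type only needs to convert into Status

    fn try_from(scan: chroma_proto::ScanOperator) -> Result<Self, Self::Error> {
        let invalid = |what: &str| Status::invalid_argument(format!("Invalid {what}"));
        Ok(Self {
            collection: scan
                .collection
                .ok_or_else(|| invalid("collection"))?
                .try_into()
                .map_err(|_| invalid("collection"))?,
            metadata_segment: scan
                .metadata
                .ok_or_else(|| invalid("metadata segment"))?
                .try_into()
                .map_err(|_| invalid("metadata segment"))?,
            record_segment: scan
                .record
                .ok_or_else(|| invalid("record segment"))?
                .try_into()
                .map_err(|_| invalid("record segment"))?,
            vector_segment: scan
                .knn
                .ok_or_else(|| invalid("vector segment"))?
                .try_into()
                .map_err(|_| invalid("vector segment"))?,
        })
    }
}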
diff --git a/rust/worker/src/system/executor.rs b/rust/worker/src/system/executor.rs
index 2f064a8067e..09779e45340 100644
--- a/rust/worker/src/system/executor.rs
+++ b/rust/worker/src/system/executor.rs
@@ -53,7 +53,7 @@ where
         mut channel: tokio::sync::mpsc::Receiver<WrappedMessage<C>>,
     ) {
         self.handler
-            .on_start(&ComponentContext {
+            .start(&ComponentContext {
                 system: self.inner.system.clone(),
                 sender: self.inner.sender.clone(),
                 cancellation_token: self.inner.cancellation_token.clone(),
diff --git a/rust/worker/src/system/mod.rs b/rust/worker/src/system/mod.rs
index 9c9b7117faa..7c92abb7e26 100644
--- a/rust/worker/src/system/mod.rs
+++ b/rust/worker/src/system/mod.rs
@@ -8,6 +8,6 @@ mod wrapped_message;
 
 // Re-export types
 pub(crate) use receiver::*;
-pub(crate) use system::*;
-pub(crate) use types::*;
+pub use system::*;
+pub use types::*;
 pub(crate) use wrapped_message::*;
diff --git a/rust/worker/src/system/scheduler.rs b/rust/worker/src/system/scheduler.rs
index 29428da0844..725dca0c671 100644
--- a/rust/worker/src/system/scheduler.rs
+++ b/rust/worker/src/system/scheduler.rs
@@ -204,7 +204,7 @@ mod tests {
             self.queue_size
         }
 
-        async fn on_start(&mut self, ctx: &ComponentContext<Self>) -> () {
+        async fn start(&mut self, ctx: &ComponentContext<Self>) -> () {
             let duration = Duration::from_millis(100);
             ctx.scheduler
                 .schedule(ScheduleMessage {}, duration, ctx, || None);
diff --git a/rust/worker/src/system/system.rs b/rust/worker/src/system/system.rs
index b78b43b526e..bdd3f955b07 100644
--- a/rust/worker/src/system/system.rs
+++ b/rust/worker/src/system/system.rs
@@ -14,7 +14,7 @@ use tokio::{pin, select};
 use tracing::{trace_span, Instrument, Span};
 
 #[derive(Clone, Debug)]
-pub(crate) struct System {
+pub struct System {
     inner: Arc<Inner>,
 }
 
@@ -32,7 +32,7 @@ impl System {
         }
     }
 
-    pub(crate) fn start_component<C>(&self, component: C) -> ComponentHandle<C>
+    pub fn start_component<C>(&self, component: C) -> ComponentHandle<C>
     where
         C: Component + Send + 'static,
     {
@@ -96,6 +96,12 @@ impl System {
     }
 }
 
+impl Default for System {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
 async fn stream_loop<C, M, S>(stream: S, ctx: &ComponentContext<C>)
 where
     C: StreamHandler<M> + Handler<M>,
diff --git a/rust/worker/src/system/types.rs b/rust/worker/src/system/types.rs
index 339515d8a11..1ff3e9365f7 100644
--- a/rust/worker/src/system/types.rs
+++ b/rust/worker/src/system/types.rs
@@ -22,7 +22,7 @@ pub(crate) enum ComponentState {
 }
 
 #[derive(Debug, PartialEq, Clone, Copy)]
-pub(crate) enum ComponentRuntime {
+pub enum ComponentRuntime {
     Inherit,
     Dedicated,
 }
@@ -37,13 +37,13 @@ pub(crate) enum ComponentRuntime {
 /// - queue_size: The size of the queue to use for the component before it starts dropping messages
 /// - on_start: Called when the component is started
 #[async_trait]
-pub(crate) trait Component: Send + Sized + Debug + 'static {
+pub trait Component: Send + Sized + Debug + 'static {
     fn get_name() -> &'static str;
     fn queue_size(&self) -> usize;
     fn runtime() -> ComponentRuntime {
         ComponentRuntime::Inherit
     }
-    async fn on_start(&mut self, _ctx: &ComponentContext<Self>) -> () {}
+    async fn start(&mut self, _ctx: &ComponentContext<Self>) -> () {}
 }
 
 /// A handler is a component that can process messages of a given type.
@@ -180,7 +180,7 @@ impl<C: Component> Clone for ComponentSender<C> {
 /// - join_handle: The join handle for the component, used to join on the component
 /// - sender: A channel to send messages to the component
 #[derive(Debug)]
-pub(crate) struct ComponentHandle<C: Component> {
+pub struct ComponentHandle<C: Component> {
     cancellation_token: tokio_util::sync::CancellationToken,
     state: Arc<Mutex<ComponentState>>,
     join_handle: Option,
@@ -271,7 +271,7 @@ impl<C: Component> ComponentHandle<C> {
     }
 
 /// The component context is passed to all Component Handler methods
-pub(crate) struct ComponentContext<C>
+pub struct ComponentContext<C>
 where
     C: Component + 'static,
 {
@@ -346,7 +346,7 @@ mod tests {
             self.queue_size
         }
 
-        async fn on_start(&mut self, ctx: &ComponentContext<Self>) -> () {
+        async fn start(&mut self, ctx: &ComponentContext<Self>) -> () {
             let test_stream = stream::iter(vec![1, 2, 3]);
             self.register_stream(test_stream, ctx);
         }
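// Illustrative sketch (not part of the patch): the lifecycle hook is now `start`
// (previously `on_start`), and `System`, `Component`, `ComponentContext`, and
// `ComponentHandle` are exported as `pub`, with `System` also implementing `Default`.
// `EchoComponent` is hypothetical; the trait methods and `System::default()` come from
// the hunks above, and the `use` paths assume this snippet lives inside the worker crate.
use async_trait::async_trait;

use crate::system::{Component, ComponentContext, System};

#[derive(Debug)]
struct EchoComponent;

#[async_trait]
impl Component for EchoComponent {
    fn get_name() -> &'static str {
        "EchoComponent"
    }

    fn queue_size(&self) -> usize {
        100
    }

    // Renamed hook: runs once when the System starts the component.
    async fn start(&mut self, _ctx: &ComponentContext<Self>) -> () {}
}

#[tokio::main]
async fn main() {
    // `System` now implements `Default`, so external callers can construct one directly.
    let system = System::default();
    let _handle = system.start_component(EchoComponent);
}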
diff --git a/rust/worker/src/tracing/opentelemetry_config.rs b/rust/worker/src/tracing/opentelemetry_config.rs
index 4a15d5efda7..fa18296e0ff 100644
--- a/rust/worker/src/tracing/opentelemetry_config.rs
+++ b/rust/worker/src/tracing/opentelemetry_config.rs
@@ -129,7 +129,7 @@ pub(crate) fn init_otel_tracing(service_name: &String, otel_endpoint: &String) {
     // global filter layer. Don't filter anything at above trace at the global layer for chroma.
     // And enable errors for every other library.
     let global_layer = EnvFilter::new(std::env::var("RUST_LOG").unwrap_or_else(|_| {
-        "error,".to_string()
+        "info,".to_string()
             + &vec![
                 "chroma",
                 "chroma-blockstore",
@@ -149,7 +149,7 @@ pub(crate) fn init_otel_tracing(service_name: &String, otel_endpoint: &String) {
                 "worker",
             ]
             .into_iter()
-            .map(|s| s.to_string() + "=trace")
+            .map(|s| s.to_string() + "=debug")
            .collect::<Vec<_>>()
             .join(",")
     }));
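// Illustrative sketch (not part of the patch): the change above only adjusts the
// fallback filter (global "info", "=debug" for the chroma crates); setting RUST_LOG
// still overrides it entirely, e.g. `RUST_LOG="error,worker=trace"` on the query
// service process (an assumed operational example, not from this diff). The helper
// below mirrors the fallback logic with a shortened crate list for clarity.
fn resolved_filter_directives() -> String {
    std::env::var("RUST_LOG").unwrap_or_else(|_| {
        // Fallback used when RUST_LOG is unset: info globally, debug for chroma crates.
        "info,".to_string()
            + &["chroma", "chroma-blockstore", "worker"]
                .into_iter()
                .map(|s| s.to_string() + "=debug")
                .collect::<Vec<_>>()
                .join(",")
    })
}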