|
| 1 | +from typing import Optional, Union, List, Awaitable |
| 2 | + |
| 3 | +from tqdm.asyncio import tqdm |
| 4 | +from asyncio import Semaphore |
| 5 | + |
| 6 | +from .vector_factory_grpc import VectorFactoryGRPC |
| 7 | + |
| 8 | +from pinecone.core.grpc.protos.vector_service_pb2 import ( |
| 9 | + Vector as GRPCVector, |
| 10 | + QueryVector as GRPCQueryVector, |
| 11 | + UpsertRequest, |
| 12 | + UpsertResponse, |
| 13 | + SparseValues as GRPCSparseValues, |
| 14 | +) |
| 15 | +from .base import GRPCIndexBase |
| 16 | +from pinecone import Vector as NonGRPCVector |
| 17 | +from pinecone.core.grpc.protos.vector_service_pb2_grpc import VectorServiceStub |
| 18 | +from pinecone.utils import parse_non_empty_args |
| 19 | + |
| 20 | +from .config import GRPCClientConfig |
| 21 | +from pinecone.config import Config |
| 22 | +from grpc._channel import Channel |
| 23 | + |
# Public API of this module: the asyncio index client plus the gRPC value
# types re-exported for caller convenience.
__all__ = ["GRPCIndexAsyncio", "GRPCVector", "GRPCQueryVector", "GRPCSparseValues"]
| 25 | + |
| 26 | + |
class GRPCIndexAsyncio(GRPCIndexBase):
    """A client for interacting with a Pinecone index over GRPC with asyncio."""

    def __init__(
        self,
        index_name: str,
        config: Config,
        channel: Optional[Channel] = None,
        grpc_config: Optional[GRPCClientConfig] = None,
        _endpoint_override: Optional[str] = None,
    ):
        """Initialize the asyncio index client.

        Args:
            index_name: Name of the Pinecone index to connect to.
            config: Client configuration (API key, host, etc.).
            channel: Optional pre-built gRPC channel to reuse.
            grpc_config: Optional gRPC-specific client settings.
            _endpoint_override: Optional endpoint override for testing.
        """
        super().__init__(
            index_name=index_name,
            config=config,
            channel=channel,
            grpc_config=grpc_config,
            _endpoint_override=_endpoint_override,
            # This subclass always runs on the asyncio runner.
            use_asyncio=True,
        )

    @property
    def stub_class(self):
        # Generated gRPC stub class used by the base class to issue
        # vector-service RPCs over the channel.
        return VectorServiceStub

    async def upsert(
        self,
        vectors: Union[List[GRPCVector], List[NonGRPCVector], List[tuple], List[dict]],
        namespace: Optional[str] = None,
        batch_size: Optional[int] = None,
        show_progress: bool = True,
        max_concurrent_requests: int = 25,
        **kwargs,
    ) -> Union[UpsertResponse, List[UpsertResponse]]:
        """Upsert vectors into the index.

        Args:
            vectors: Vectors in any of the supported shapes; each item is
                normalized through ``VectorFactoryGRPC.build``.
            namespace: Optional namespace to upsert into.
            batch_size: When given, vectors are split into batches of this
                size and upserted concurrently; when ``None``, a single
                request carries all vectors.
            show_progress: Show a tqdm progress bar over batches.
            max_concurrent_requests: Cap on in-flight batch requests
                (only relevant when ``batch_size`` is set).
            **kwargs: Extra request options; ``timeout`` (seconds) is
                honored in both the batched and unbatched paths.

        Returns:
            A single ``UpsertResponse`` when unbatched, otherwise the list
            of per-batch ``UpsertResponse`` objects.

        Raises:
            ValueError: If ``batch_size`` is not a positive integer.
        """
        timeout = kwargs.pop("timeout", None)
        vectors = list(map(VectorFactoryGRPC.build, vectors))

        if batch_size is None:
            # Single request: no concurrency control needed.
            return await self._upsert_batch(vectors, namespace, timeout=timeout, **kwargs)

        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")

        # Bound concurrent in-flight requests so a huge vector list doesn't
        # open an unbounded number of simultaneous RPCs.
        semaphore = Semaphore(max_concurrent_requests)
        vector_batches = [
            vectors[i : i + batch_size] for i in range(0, len(vectors), batch_size)
        ]
        tasks = [
            # Forward the caller's timeout and extra options to every batch
            # (previously a hard-coded timeout=100 silently overrode them).
            self._upsert_batch(
                vectors=batch,
                namespace=namespace,
                timeout=timeout,
                semaphore=semaphore,
                **kwargs,
            )
            for batch in vector_batches
        ]

        return await tqdm.gather(*tasks, disable=not show_progress, desc="Upserted batches")

    async def _upsert_batch(
        self,
        vectors: List[GRPCVector],
        namespace: Optional[str],
        timeout: Optional[int] = None,
        semaphore: Optional[Semaphore] = None,
        **kwargs,
    ) -> UpsertResponse:
        """Issue one Upsert RPC, optionally gated by a concurrency semaphore.

        Args:
            vectors: Already-normalized GRPC vectors for this batch.
            namespace: Optional namespace; omitted from the request if None.
            timeout: Per-request timeout in seconds, or None for the default.
            semaphore: When provided, the RPC is held until a slot is free.
            **kwargs: Extra options forwarded to the runner.
        """
        args_dict = parse_non_empty_args([("namespace", namespace)])
        request = UpsertRequest(vectors=vectors, **args_dict)
        if semaphore is not None:
            async with semaphore:
                return await self.runner.run_asyncio(
                    self.stub.Upsert, request, timeout=timeout, **kwargs
                )
        return await self.runner.run_asyncio(
            self.stub.Upsert, request, timeout=timeout, **kwargs
        )
0 commit comments