Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: update API key handling to use get_secret_value method #1142

Merged
merged 2 commits into the base branch
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions pkgs/base/swarmauri_base/embeddings/EmbeddingBase.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import Optional, Literal
from pydantic import Field
from typing import Literal, Optional

from swarmauri_core.embeddings.IVectorize import IVectorize
from pydantic import Field
from swarmauri_core.ComponentBase import ComponentBase, ResourceTypes
from swarmauri_core.embeddings.IFeature import IFeature
from swarmauri_core.embeddings.ISaveModel import ISaveModel
from swarmauri_core.ComponentBase import ComponentBase, ResourceTypes
from swarmauri_core.embeddings.IVectorize import IVectorize


@ComponentBase.register_model()
class EmbeddingBase(IVectorize, IFeature, ISaveModel, ComponentBase):
resource: Optional[str] = Field(default=ResourceTypes.EMBEDDING.value, frozen=True)
type: Literal['EmbeddingBase'] = 'EmbeddingBase'

resource: Optional[str] = Field(default=ResourceTypes.EMBEDDING.value, frozen=True)
type: Literal["EmbeddingBase"] = "EmbeddingBase"
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional
from pydantic import Field, BaseModel

from pydantic import BaseModel, Field, SecretStr
from swarmauri_core.vector_stores.ICloudVectorStore import ICloudVectorStore


Expand All @@ -9,7 +9,9 @@ class VectorStoreCloudMixin(ICloudVectorStore, BaseModel):
Mixin class for cloud-based vector stores.
"""

api_key: str
api_key: Optional[SecretStr] = Field(
None, description="API key for the cloud-based store"
)
collection_name: str
url: Optional[str] = Field(
None, description="URL of the cloud-based store to connect to"
Expand All @@ -20,4 +22,4 @@ class VectorStoreCloudMixin(ICloudVectorStore, BaseModel):
)
client: Optional[object] = Field(
None, description="Client object for interacting with the cloud-based store"
)
)
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
from typing import List, Union, Literal, Optional
import numpy as np
from annoy import AnnoyIndex
import os
from typing import List, Literal, Optional, Union

from swarmauri_core.ComponentBase import ComponentBase
from swarmauri_standard.documents.Document import Document
from swarmauri_embedding_doc2vec.Doc2VecEmbedding import Doc2VecEmbedding
from swarmauri_standard.distances.CosineDistance import CosineDistance

import numpy as np
from annoy import AnnoyIndex
from swarmauri_base.vector_stores.VectorStoreBase import VectorStoreBase
from swarmauri_base.vector_stores.VectorStoreCloudMixin import VectorStoreCloudMixin
from swarmauri_base.vector_stores.VectorStoreRetrieveMixin import (
VectorStoreRetrieveMixin,
)
from swarmauri_base.vector_stores.VectorStoreCloudMixin import VectorStoreCloudMixin
from swarmauri_base.vector_stores.VectorStoreSaveLoadMixin import (
VectorStoreSaveLoadMixin,
)
from swarmauri_core.ComponentBase import ComponentBase
from swarmauri_embedding_doc2vec.Doc2VecEmbedding import Doc2VecEmbedding
from swarmauri_standard.distances.CosineDistance import CosineDistance
from swarmauri_standard.documents.Document import Document


@ComponentBase.register_type(VectorStoreBase, "AnnoyVectorStore")
Expand All @@ -34,27 +33,19 @@ class AnnoyVectorStore(
"""

type: Literal["AnnoyVectorStore"] = "AnnoyVectorStore"
api_key: str = (
"not_required" # Annoy doesn't need an API key, but base class requires it
)

def __init__(self, **kwargs):
"""
Initialize the AnnoyVectorStore.
Args:
**kwargs: Additional keyword arguments.
"""
# Set default api_key if not provided
if "api_key" not in kwargs:
kwargs["api_key"] = "not_required"

super().__init__(**kwargs)
self._embedder = Doc2VecEmbedding(vector_size=self.vector_size)
self._distance = CosineDistance()
self.client = None
self._documents = (
{}
) # Store documents in memory since Annoy only stores vectors
self._documents = {} # Store documents in memory since Annoy only stores vectors
self._current_index = 0 # Track the next available index
self._id_to_index = {} # Map document IDs to Annoy indices
self._index_to_id = {} # Map Annoy indices to document IDs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def connect(self, **kwargs):
if self._client is None:
self._client = weaviate.connect_to_weaviate_cloud(
cluster_url=self.url,
auth_credentials=Auth.api_key(self.api_key),
auth_credentials=Auth.api_key(self.api_key.get_secret_value()),
headers=kwargs.get("headers", {}),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def delete(self):

"""
try:
pc = Pinecone(api_key=self.api_key)
pc = Pinecone(api_key=self.api_key.get_secret_value())
pc.delete_index(self.collection_name)
self.client = None
except Exception as e:
Expand All @@ -73,7 +73,7 @@ def connect(self, metric: Optional[str] = "cosine", cloud: Optional[str] = "aws"

"""
try:
pc = Pinecone(api_key=self.api_key)
pc = Pinecone(api_key=self.api_key.get_secret_value())
if not pc.has_index(self.collection_name):
pc.create_index(
name=self.collection_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def connect(self) -> None:
"""
if self.client is None:
self.client = QdrantClient(
api_key=self.api_key,
api_key=self.api_key.get_secret_value(),
url=self.url,
)

Expand Down
42 changes: 21 additions & 21 deletions pkgs/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import httpx
import logging
from typing import List, Literal, Optional, Union
from pydantic import PrivateAttr
from swarmauri_standard.vectors.Vector import Vector

import httpx
from pydantic import PrivateAttr, SecretStr
from swarmauri_base.embeddings.EmbeddingBase import EmbeddingBase
from swarmauri_core.ComponentBase import ComponentBase

@ComponentBase.register_type(EmbeddingBase, 'CohereEmbedding')
from swarmauri_standard.vectors.Vector import Vector


@ComponentBase.register_type(EmbeddingBase, "CohereEmbedding")
class CohereEmbedding(EmbeddingBase):
"""
A class for generating embeddings using the Cohere REST API.
Expand All @@ -17,7 +21,7 @@ class CohereEmbedding(EmbeddingBase):
Attributes:
type (Literal["CohereEmbedding"]): The type identifier for this embedding class.
model (str): The Cohere embedding model to use.
api_key (str): The API key for accessing the Cohere API.
api_key (SecretStr): The API key for accessing the Cohere API.
allowed_task_types (List[str]): List of supported task types for embeddings

Link to Allowed Models: https://docs.cohere.com/reference/embed
Expand Down Expand Up @@ -52,21 +56,16 @@ class CohereEmbedding(EmbeddingBase):

# Public attributes
model: str = "embed-english-v3.0"
api_key: str = None
api_key: SecretStr = None
task_type: str = "search_document"
embedding_types: Optional[str] = "float"
truncate: Optional[str] = "END"

# Private configuration attributes
_task_type: str = PrivateAttr("search_document")
_embedding_types: Optional[str] = PrivateAttr("float")
_truncate: Optional[str] = PrivateAttr("END")
_client: httpx.Client = PrivateAttr()

def __init__(
self,
api_key: str = None,
model: str = "embed-english-v3.0",
task_type: Optional[str] = "search_document",
embedding_types: Optional[str] = "float",
truncate: Optional[str] = "END",
**kwargs,
):
"""
Expand All @@ -85,29 +84,23 @@ def __init__(
"""
super().__init__(**kwargs)

if model not in self.allowed_models:
if self.model not in self.allowed_models:
raise ValueError(
f"Invalid model '{model}'. Allowed models are: {', '.join(self.allowed_models)}"
f"Invalid model '{self.model}'. Allowed models are: {', '.join(self.allowed_models)}"
)

if task_type not in self.allowed_task_types:
if self.task_type not in self.allowed_task_types:
raise ValueError(
f"Invalid task_type '{task_type}'. Allowed task types are: {', '.join(self.allowed_task_types)}"
f"Invalid task_type '{self.task_type}'. Allowed task types are: {', '.join(self.allowed_task_types)}"
)
if embedding_types not in self._allowed_embedding_types:
if self.embedding_types not in self._allowed_embedding_types:
raise ValueError(
f"Invalid embedding_types '{embedding_types}'. Allowed embedding types are: {', '.join(self._allowed_embedding_types)}"
f"Invalid embedding_types '{self.embedding_types}'. Allowed embedding types are: {', '.join(self._allowed_embedding_types)}"
)
if truncate not in ["END", "START", "NONE"]:
if self.truncate not in ["END", "START", "NONE"]:
raise ValueError(
f"Invalid truncate '{truncate}'. Allowed truncate are: END, START, NONE"
f"Invalid truncate '{self.truncate}'. Allowed truncate are: END, START, NONE"
)

self.model = model
self.api_key = api_key
self._task_type = task_type
self._embedding_types = embedding_types
self._truncate = truncate
self._client = httpx.Client()

def _make_request(self, payload: dict) -> dict:
Expand All @@ -126,9 +119,8 @@ def _make_request(self, payload: dict) -> dict:
headers = {
"accept": "application/json",
"content-type": "application/json",
"Authorization": f"Bearer {self.api_key}",
"Authorization": f"Bearer {self.api_key.get_secret_value()}",
}

try:
response = self._client.post(
f"{self._BASE_URL}/embed", headers=headers, json=payload
Expand All @@ -155,22 +147,22 @@ def infer_vector(self, data: Union[List[str], List[str]]) -> List[Vector]:
# Prepare the payload based on input type
payload = {
"model": self.model,
"embedding_types": [self._embedding_types],
"embedding_types": [self.embedding_types],
}

if self._task_type == "image":
if self.task_type == "image":
payload["input_type"] = "image"
payload["images"] = data
else:
payload["input_type"] = self._task_type
payload["input_type"] = self.task_type
payload["texts"] = data
payload["truncate"] = self._truncate
payload["truncate"] = self.truncate

# Make the API request
response = self._make_request(payload)

# Extract embeddings from response
embeddings = response["embeddings"][self._embedding_types]
embeddings = response["embeddings"][self.embedding_types]
return [Vector(value=item) for item in embeddings]

except Exception as e:
Expand Down
Loading