diff --git a/libs/langchain/langchain/embeddings/__init__.py b/libs/langchain/langchain/embeddings/__init__.py index a2f95c71e3dc5..08bb679a552b3 100644 --- a/libs/langchain/langchain/embeddings/__init__.py +++ b/libs/langchain/langchain/embeddings/__init__.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any from langchain._api import create_importer +from langchain.embeddings.base import init_embeddings from langchain.embeddings.cache import CacheBackedEmbeddings if TYPE_CHECKING: @@ -221,4 +222,5 @@ def __getattr__(name: str) -> Any: "VertexAIEmbeddings", "VoyageEmbeddings", "XinferenceEmbeddings", + "init_embeddings", ] diff --git a/libs/langchain/langchain/embeddings/base.py b/libs/langchain/langchain/embeddings/base.py index 1ff758579073a..c757b3b7108fb 100644 --- a/libs/langchain/langchain/embeddings/base.py +++ b/libs/langchain/langchain/embeddings/base.py @@ -2,6 +2,7 @@ from importlib import util from typing import Any, List, Optional, Tuple, Union +from langchain_core._api import beta from langchain_core.embeddings import Embeddings from langchain_core.runnables import Runnable @@ -35,11 +36,13 @@ def _parse_model_string(model_name: str) -> Tuple[str, str]: Returns: A tuple of (provider, model_name) - Examples: - >>> _parse_model_string("openai:text-embedding-3-small") - ("openai", "text-embedding-3-small") - >>> _parse_model_string("bedrock:amazon.titan-embed-text-v1") - ("bedrock", "amazon.titan-embed-text-v1") + .. code-block:: python + + _parse_model_string("openai:text-embedding-3-small") + # Returns: ("openai", "text-embedding-3-small") + + _parse_model_string("bedrock:amazon.titan-embed-text-v1") + # Returns: ("bedrock", "amazon.titan-embed-text-v1") Raises: ValueError: If the model string is not in the correct format or @@ -89,7 +92,7 @@ def _infer_model_and_provider( "Must specify either:\n" "1. A model string in format 'provider:model-name'\n" " Example: 'openai:text-embedding-3-small'\n" - "2. Or explicitly set model_provider from: " + "2. Or explicitly set provider from: " f"{providers}" ) @@ -102,13 +105,13 @@ def _infer_model_and_provider( return provider, model_name -def _init_embedding_model_helper( +def _init_embeddings_helper( model: str, *, provider: Optional[str] = None, **kwargs: Any ) -> Embeddings: """Initialize an Embeddings model from the model name and provider. Internal helper function that handles the actual model initialization. - Use init_embedding_model() instead of calling this directly. + Use init_embeddings() instead of calling this directly. """ provider, model_name = _infer_model_and_provider(model, provider=provider) pkg = _SUPPORTED_PROVIDERS[provider] @@ -150,7 +153,7 @@ def _init_embedding_model_helper( ) -@functools.lru_cache +@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS)) def _check_pkg(pkg: str) -> None: """Check if a package is installed.""" if not util.find_spec(pkg): @@ -160,23 +163,23 @@ def _check_pkg(pkg: str) -> None: ) -def embedding_model( +@beta() +def init_embeddings( model: str, *, - model_provider: Optional[str] = None, + provider: Optional[str] = None, **kwargs: Any, ) -> Union[Embeddings, Runnable[Any, List[float]]]: - f"""Initialize an embeddings model from a model name and optional provider. + """Initialize an embeddings model from a model name and optional provider. - This function creates an embeddings model instance from either: - 1. A model string in the format 'provider:model-name' - 2. A model name and explicit provider + **Note:** Must have the integration package corresponding to the model provider + installed. Args: model: Name of the model to use. Can be either: - A model string like "openai:text-embedding-3-small" - - Just the model name if model_provider is specified - model_provider: Optional explicit provider name. If not specified, + - Just the model name if provider is specified + provider: Optional explicit provider name. If not specified, will attempt to parse from the model string. Supported providers and their required packages: @@ -192,22 +195,29 @@ def embedding_model( ValueError: If the model provider is not supported or cannot be determined ImportError: If the required provider package is not installed - Examples: - >>> # Using a model string - >>> model = init_embedding_model("openai:text-embedding-3-small") - >>> model.embed_query("Hello, world!") - - >>> # Using explicit provider - >>> model = init_embedding_model( - ... model="text-embedding-3-small", - ... model_provider="openai" - ... ) - - >>> # With additional parameters - >>> model = init_embedding_model( - ... "openai:text-embedding-3-small", - ... api_key="sk-..." - ... ) + .. dropdown:: Example Usage + :open: + + .. code-block:: python + + # Using a model string + model = init_embeddings("openai:text-embedding-3-small") + model.embed_query("Hello, world!") + + # Using explicit provider + model = init_embeddings( + model="text-embedding-3-small", + provider="openai" + ) + model.embed_documents(["Hello, world!", "Goodbye, world!"]) + + # With additional parameters + model = init_embeddings( + "openai:text-embedding-3-small", + api_key="sk-..." + ) + + .. versionadded:: 0.3.9 """ if not model: providers = _SUPPORTED_PROVIDERS.keys() @@ -216,10 +226,10 @@ def embedding_model( f"Supported providers are: {', '.join(providers)}" ) - return _init_embedding_model_helper(model, model_provider=model_provider, **kwargs) + return _init_embeddings_helper(model, provider=provider, **kwargs) __all__ = [ - "embedding_model", + "init_embeddings", "Embeddings", # This one is for backwards compatibility ] diff --git a/libs/langchain/tests/integration_tests/embeddings/test_base.py b/libs/langchain/tests/integration_tests/embeddings/test_base.py index fb7f2583234a7..204754642fdf5 100644 --- a/libs/langchain/tests/integration_tests/embeddings/test_base.py +++ b/libs/langchain/tests/integration_tests/embeddings/test_base.py @@ -5,7 +5,7 @@ import pytest from langchain_core.embeddings import Embeddings -from langchain.embeddings.base import _SUPPORTED_PROVIDERS, embedding_model +from langchain.embeddings.base import _SUPPORTED_PROVIDERS, init_embeddings @pytest.mark.parametrize( @@ -24,12 +24,12 @@ async def test_init_embedding_model(provider: str, model: str) -> None: except ImportError: pytest.skip(f"Package {package} is not installed") - model_colon = embedding_model(f"{provider}:{model}") + model_colon = init_embeddings(f"{provider}:{model}") assert isinstance(model_colon, Embeddings) - model_explicit = embedding_model( + model_explicit = init_embeddings( model=model, - model_provider=provider, + provider=provider, ) assert isinstance(model_explicit, Embeddings) diff --git a/libs/langchain/tests/unit_tests/embeddings/test_imports.py b/libs/langchain/tests/unit_tests/embeddings/test_imports.py index c6d7a8207d1c5..b44acf1a6032d 100644 --- a/libs/langchain/tests/unit_tests/embeddings/test_imports.py +++ b/libs/langchain/tests/unit_tests/embeddings/test_imports.py @@ -55,6 +55,7 @@ "JohnSnowLabsEmbeddings", "VoyageEmbeddings", "BookendEmbeddings", + "init_embeddings", ]