Skip to content

Commit

Permalink
Feedback
Browse files (browse the repository at this point in the history)
  • Loading branch information
hinthornw committed Nov 27, 2024
1 parent 37a104f commit 5fb09ef
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 39 deletions.
2 changes: 2 additions & 0 deletions libs/langchain/langchain/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import TYPE_CHECKING, Any

from langchain._api import create_importer
from langchain.embeddings.base import init_embeddings
from langchain.embeddings.cache import CacheBackedEmbeddings

if TYPE_CHECKING:
Expand Down Expand Up @@ -221,4 +222,5 @@ def __getattr__(name: str) -> Any:
"VertexAIEmbeddings",
"VoyageEmbeddings",
"XinferenceEmbeddings",
"init_embeddings",
]
80 changes: 45 additions & 35 deletions libs/langchain/langchain/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from importlib import util
from typing import Any, List, Optional, Tuple, Union

from langchain_core._api import beta
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import Runnable

Expand Down Expand Up @@ -35,11 +36,13 @@ def _parse_model_string(model_name: str) -> Tuple[str, str]:
Returns:
A tuple of (provider, model_name)
Examples:
>>> _parse_model_string("openai:text-embedding-3-small")
("openai", "text-embedding-3-small")
>>> _parse_model_string("bedrock:amazon.titan-embed-text-v1")
("bedrock", "amazon.titan-embed-text-v1")
.. code-block:: python
_parse_model_string("openai:text-embedding-3-small")
# Returns: ("openai", "text-embedding-3-small")
_parse_model_string("bedrock:amazon.titan-embed-text-v1")
# Returns: ("bedrock", "amazon.titan-embed-text-v1")
Raises:
ValueError: If the model string is not in the correct format or
Expand Down Expand Up @@ -89,7 +92,7 @@ def _infer_model_and_provider(
"Must specify either:\n"
"1. A model string in format 'provider:model-name'\n"
" Example: 'openai:text-embedding-3-small'\n"
"2. Or explicitly set model_provider from: "
"2. Or explicitly set provider from: "
f"{providers}"
)

Expand All @@ -102,13 +105,13 @@ def _infer_model_and_provider(
return provider, model_name


def _init_embedding_model_helper(
def _init_embeddings_helper(
model: str, *, provider: Optional[str] = None, **kwargs: Any
) -> Embeddings:
"""Initialize an Embeddings model from the model name and provider.
Internal helper function that handles the actual model initialization.
Use init_embedding_model() instead of calling this directly.
Use init_embeddings() instead of calling this directly.
"""
provider, model_name = _infer_model_and_provider(model, provider=provider)
pkg = _SUPPORTED_PROVIDERS[provider]
Expand Down Expand Up @@ -150,7 +153,7 @@ def _init_embedding_model_helper(
)


@functools.lru_cache
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
def _check_pkg(pkg: str) -> None:
"""Check if a package is installed."""
if not util.find_spec(pkg):
Expand All @@ -160,23 +163,23 @@ def _check_pkg(pkg: str) -> None:
)


def embedding_model(
@beta()
def init_embeddings(
model: str,
*,
model_provider: Optional[str] = None,
provider: Optional[str] = None,
**kwargs: Any,
) -> Union[Embeddings, Runnable[Any, List[float]]]:
f"""Initialize an embeddings model from a model name and optional provider.
"""Initialize an embeddings model from a model name and optional provider.
This function creates an embeddings model instance from either:
1. A model string in the format 'provider:model-name'
2. A model name and explicit provider
**Note:** Must have the integration package corresponding to the model provider
installed.
Args:
model: Name of the model to use. Can be either:
- A model string like "openai:text-embedding-3-small"
- Just the model name if model_provider is specified
model_provider: Optional explicit provider name. If not specified,
- Just the model name if provider is specified
provider: Optional explicit provider name. If not specified,
will attempt to parse from the model string. Supported providers
and their required packages:
Expand All @@ -192,22 +195,29 @@ def embedding_model(
ValueError: If the model provider is not supported or cannot be determined
ImportError: If the required provider package is not installed
Examples:
>>> # Using a model string
>>> model = init_embedding_model("openai:text-embedding-3-small")
>>> model.embed_query("Hello, world!")
>>> # Using explicit provider
>>> model = init_embedding_model(
... model="text-embedding-3-small",
... model_provider="openai"
... )
>>> # With additional parameters
>>> model = init_embedding_model(
... "openai:text-embedding-3-small",
... api_key="sk-..."
... )
.. dropdown:: Example Usage
:open:
.. code-block:: python
# Using a model string
model = init_embeddings("openai:text-embedding-3-small")
model.embed_query("Hello, world!")
# Using explicit provider
model = init_embeddings(
model="text-embedding-3-small",
provider="openai"
)
model.embed_documents(["Hello, world!", "Goodbye, world!"])
# With additional parameters
model = init_embeddings(
"openai:text-embedding-3-small",
api_key="sk-..."
)
.. versionadded:: 0.3.9
"""
if not model:
providers = _SUPPORTED_PROVIDERS.keys()
Expand All @@ -216,10 +226,10 @@ def embedding_model(
f"Supported providers are: {', '.join(providers)}"
)

return _init_embedding_model_helper(model, model_provider=model_provider, **kwargs)
return _init_embeddings_helper(model, provider=provider, **kwargs)


__all__ = [
"embedding_model",
"init_embeddings",
"Embeddings", # This one is for backwards compatibility
]
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from langchain_core.embeddings import Embeddings

from langchain.embeddings.base import _SUPPORTED_PROVIDERS, embedding_model
from langchain.embeddings.base import _SUPPORTED_PROVIDERS, init_embeddings


@pytest.mark.parametrize(
Expand All @@ -24,12 +24,12 @@ async def test_init_embedding_model(provider: str, model: str) -> None:
except ImportError:
pytest.skip(f"Package {package} is not installed")

model_colon = embedding_model(f"{provider}:{model}")
model_colon = init_embeddings(f"{provider}:{model}")
assert isinstance(model_colon, Embeddings)

model_explicit = embedding_model(
model_explicit = init_embeddings(
model=model,
model_provider=provider,
provider=provider,
)
assert isinstance(model_explicit, Embeddings)

Expand Down
1 change: 1 addition & 0 deletions libs/langchain/tests/unit_tests/embeddings/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"JohnSnowLabsEmbeddings",
"VoyageEmbeddings",
"BookendEmbeddings",
"init_embeddings",
]


Expand Down

0 comments on commit 5fb09ef

Please sign in to comment.