Skip to content

Commit

Permalink
Init embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw committed Nov 26, 2024
1 parent 4027da1 commit 3a75d7a
Show file tree
Hide file tree
Showing 4 changed files with 323 additions and 2 deletions.
224 changes: 222 additions & 2 deletions libs/langchain/langchain/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,224 @@
import functools
from importlib import util
from typing import Any, List, Optional, Tuple, Union

from langchain_core.embeddings import Embeddings
from langchain_core.runnables import Runnable

# Maps each supported provider key (as accepted in "provider:model-name"
# strings and the model_provider argument) to the langchain integration
# package that must be installed to use it.
_SUPPORTED_PROVIDERS = {
    "openai": "langchain_openai",
    "azure_openai": "langchain_openai",
    "google_vertexai": "langchain_google_vertexai",
    "bedrock": "langchain_aws",
    "cohere": "langchain_cohere",
    "mistralai": "langchain_mistralai",
    "huggingface": "langchain_huggingface",
}


def _get_provider_list() -> str:
    """Render the supported providers as one bullet line per provider.

    Package names are shown in their hyphenated (pip distribution) form.
    """
    lines = []
    for provider in sorted(_SUPPORTED_PROVIDERS):
        distribution = _SUPPORTED_PROVIDERS[provider].replace("_", "-")
        lines.append(f" - {provider}: {distribution}")
    return "\n".join(lines)


def _parse_model_string(model_name: str) -> Tuple[str, str]:
    """Split a model string of the form 'provider:model-name' in two.

    Args:
        model_name: A model string in the format 'provider:model-name',
            where provider is one of the supported providers.

    Returns:
        A ``(provider, model_name)`` tuple.

    Raises:
        ValueError: If the string lacks a ':' separator, names an
            unsupported provider, or has an empty model name.

    Examples:
        >>> _parse_model_string("openai:text-embedding-3-small")
        ('openai', 'text-embedding-3-small')
        >>> _parse_model_string("bedrock:amazon.titan-embed-text-v1")
        ('bedrock', 'amazon.titan-embed-text-v1')
    """
    if ":" not in model_name:
        supported = sorted(_SUPPORTED_PROVIDERS)
        raise ValueError(
            f"Invalid model format '{model_name}'.\n"
            "Model name must be in format 'provider:model-name'\n"
            "Example valid model strings:\n"
            " - openai:text-embedding-3-small\n"
            " - bedrock:amazon.titan-embed-text-v1\n"
            " - cohere:embed-english-v3.0\n"
            f"Supported providers: {supported}"
        )

    # Split only on the first ':' so model names may themselves contain one
    # (e.g. "huggingface:BAAI/bge-base-en:v1.5").
    raw_provider, raw_model = model_name.split(":", 1)
    provider = raw_provider.strip().lower()
    model = raw_model.strip()

    if provider not in _SUPPORTED_PROVIDERS:
        raise ValueError(
            f"Provider '{provider}' is not supported.\n"
            "Supported providers and their required packages:\n"
            f"{_get_provider_list()}"
        )
    if not model:
        raise ValueError("Model name cannot be empty")
    return provider, model


def _infer_model_and_provider(
    model: str, *, provider: Optional[str] = None
) -> Tuple[str, str]:
    """Determine the (provider, model name) pair for ``model``.

    If ``provider`` is not given and ``model`` contains a ':', the provider
    is parsed out of the 'provider:model-name' string; otherwise ``model``
    is used verbatim as the model name.

    Args:
        model: A model name, optionally prefixed with 'provider:'.
        provider: Optional explicit provider key; takes precedence over
            any prefix in ``model``.

    Returns:
        A ``(provider, model_name)`` tuple.

    Raises:
        ValueError: If no provider can be determined, or the provider is
            not supported.
    """
    if provider is None and ":" in model:
        provider, model_name = _parse_model_string(model)
    else:
        # Removed a redundant `provider = provider` self-assignment here.
        model_name = model

    if not provider:
        providers = sorted(_SUPPORTED_PROVIDERS)
        raise ValueError(
            "Must specify either:\n"
            "1. A model string in format 'provider:model-name'\n"
            " Example: 'openai:text-embedding-3-small'\n"
            "2. Or explicitly set model_provider from: "
            f"{providers}"
        )

    if provider not in _SUPPORTED_PROVIDERS:
        raise ValueError(
            f"Provider '{provider}' is not supported.\n"
            f"Supported providers and their required packages:\n"
            f"{_get_provider_list()}"
        )
    return provider, model_name


def _init_embedding_model_helper(
    model: str, *, provider: Optional[str] = None, **kwargs: Any
) -> Embeddings:
    """Resolve the provider and construct the matching Embeddings instance.

    Internal helper that performs the actual model construction; use
    init_embedding_model() instead of calling this directly.
    """
    provider, model_name = _infer_model_and_provider(model, provider=provider)
    _check_pkg(_SUPPORTED_PROVIDERS[provider])

    # Import lazily so only the selected integration package is required.
    # Each branch records the class and the keyword its constructor uses
    # for the model name; construction happens once at the end.
    if provider == "openai":
        from langchain_openai import OpenAIEmbeddings as embedding_cls

        name_kwarg = "model"
    elif provider == "azure_openai":
        from langchain_openai import AzureOpenAIEmbeddings as embedding_cls

        name_kwarg = "model"
    elif provider == "google_vertexai":
        from langchain_google_vertexai import VertexAIEmbeddings as embedding_cls

        name_kwarg = "model"
    elif provider == "bedrock":
        from langchain_aws import BedrockEmbeddings as embedding_cls

        name_kwarg = "model_id"
    elif provider == "cohere":
        from langchain_cohere import CohereEmbeddings as embedding_cls

        name_kwarg = "model"
    elif provider == "mistralai":
        from langchain_mistralai import MistralAIEmbeddings as embedding_cls

        name_kwarg = "model"
    elif provider == "huggingface":
        from langchain_huggingface import HuggingFaceEmbeddings as embedding_cls

        name_kwarg = "model_name"
    else:
        # Defensive: _infer_model_and_provider has already validated provider.
        raise ValueError(
            f"Provider '{provider}' is not supported.\n"
            f"Supported providers and their required packages:\n"
            f"{_get_provider_list()}"
        )
    return embedding_cls(**{name_kwarg: model_name}, **kwargs)


@functools.lru_cache
def _check_pkg(pkg: str) -> None:
"""Check if a package is installed."""
if not util.find_spec(pkg):
raise ImportError(
f"Could not import {pkg} python package. "
f"Please install it with `pip install {pkg}`"
)


def embedding_model(
    model: str,
    *,
    model_provider: Optional[str] = None,
    **kwargs: Any,
) -> Union[Embeddings, Runnable[Any, List[float]]]:
    """Initialize an embeddings model from a model name and optional provider.

    This function creates an embeddings model instance from either:
    1. A model string in the format 'provider:model-name'
    2. A model name and explicit provider

    Args:
        model: Name of the model to use. Can be either:
            - A model string like "openai:text-embedding-3-small"
            - Just the model name if model_provider is specified
        model_provider: Optional explicit provider name. If not specified,
            will attempt to parse from the model string. Supported
            providers: openai, azure_openai, google_vertexai, bedrock,
            cohere, mistralai, huggingface (each requires its
            corresponding langchain integration package).
        **kwargs: Additional model-specific parameters passed to the
            embedding model. These vary by provider; see the
            provider-specific documentation for details.

    Returns:
        An Embeddings instance that can generate embeddings for text.

    Raises:
        ValueError: If the model provider is not supported or cannot be
            determined, or if ``model`` is empty.
        ImportError: If the required provider package is not installed.

    Examples:
        >>> # Using a model string
        >>> model = embedding_model("openai:text-embedding-3-small")
        >>> model.embed_query("Hello, world!")

        >>> # Using explicit provider
        >>> model = embedding_model(
        ...     model="text-embedding-3-small",
        ...     model_provider="openai",
        ... )

        >>> # With additional parameters
        >>> model = embedding_model(
        ...     "openai:text-embedding-3-small",
        ...     api_key="sk-...",
        ... )
    """
    # NOTE: the docstring above is intentionally a plain string literal.
    # The previous f-string version was evaluated and discarded at function
    # definition time and never attached as __doc__.
    if not model:
        providers = sorted(_SUPPORTED_PROVIDERS)
        raise ValueError(
            "Must specify model name. "
            f"Supported providers are: {', '.join(providers)}"
        )
    # BUG FIX: the helper's keyword parameter is `provider`, not
    # `model_provider`. Passing `model_provider=` fell into **kwargs and
    # was forwarded to the embedding class constructor, while the provider
    # itself stayed undetermined.
    return _init_embedding_model_helper(model, provider=model_provider, **kwargs)


# Removed a dead `__all__ = ["Embeddings"]` that was immediately overwritten.
__all__ = [
    "embedding_model",
    "Embeddings",  # Re-exported for backwards compatibility.
]
Empty file.
44 changes: 44 additions & 0 deletions libs/langchain/tests/integration_tests/embeddings/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Test embeddings base module."""

import importlib

import pytest
from langchain_core.embeddings import Embeddings

from langchain.embeddings.base import _SUPPORTED_PROVIDERS, embedding_model


@pytest.mark.parametrize(
    "provider, model",
    [
        ("openai", "text-embedding-3-large"),
        ("google_vertexai", "text-embedding-gecko@003"),
        ("bedrock", "amazon.titan-embed-text-v1"),
        ("cohere", "embed-english-v2.0"),
    ],
)
async def test_init_embedding_model(provider: str, model: str) -> None:
    """Smoke-test both initialization styles and an actual embedding call."""
    pkg = _SUPPORTED_PROVIDERS[provider]
    try:
        importlib.import_module(pkg)
    except ImportError:
        pytest.skip(f"Package {pkg} is not installed")

    # Initialize once via the "provider:model" string, once via the
    # explicit model_provider keyword; both must behave the same.
    instances = [
        embedding_model(f"{provider}:{model}"),
        embedding_model(model=model, model_provider=provider),
    ]
    for instance in instances:
        assert isinstance(instance, Embeddings)

    for instance in instances:
        vector = await instance.aembed_query("Hello world")
        assert isinstance(vector, list)
        assert all(isinstance(value, float) for value in vector)
57 changes: 57 additions & 0 deletions libs/langchain/tests/unit_tests/embeddings/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Test embeddings base module."""

import pytest

from langchain.embeddings.base import _SUPPORTED_PROVIDERS, _parse_model_string


def test_parse_model_string() -> None:
    """Test parsing model strings into provider and model components."""
    # Input string -> expected (provider, model) pair. The huggingface case
    # checks that only the first ':' is treated as the separator.
    expected_pairs = {
        "openai:text-embedding-3-small": ("openai", "text-embedding-3-small"),
        "bedrock:amazon.titan-embed-text-v1": (
            "bedrock",
            "amazon.titan-embed-text-v1",
        ),
        "huggingface:BAAI/bge-base-en:v1.5": (
            "huggingface",
            "BAAI/bge-base-en:v1.5",
        ),
    }
    for raw, expected in expected_pairs.items():
        assert _parse_model_string(raw) == expected


def test_parse_model_string_errors() -> None:
    """Test error cases for model string parsing."""
    # (invalid input, expected error-message pattern) pairs.
    error_cases = [
        ("just-a-model-name", "Model name must be"),
        ("", "Invalid model format "),
        (":model-name", "is not supported"),
        ("openai:", "Model name cannot be empty"),
        (
            "invalid-provider:model-name",
            "Provider 'invalid-provider' is not supported",
        ),
    ]
    for bad_input, pattern in error_cases:
        with pytest.raises(ValueError, match=pattern):
            _parse_model_string(bad_input)

    # The unsupported-provider error should mention every supported provider.
    for provider in _SUPPORTED_PROVIDERS:
        with pytest.raises(ValueError, match=f"{provider}"):
            _parse_model_string("invalid-provider:model-name")


@pytest.mark.parametrize(
    "provider",
    sorted(_SUPPORTED_PROVIDERS.keys()),
)
def test_supported_providers_package_names(provider: str) -> None:
    """Test that all supported providers have valid package names."""
    pkg = _SUPPORTED_PROVIDERS[provider]
    # Package names are importable module names: lowercase, underscored,
    # and always in the langchain_* namespace.
    assert pkg.startswith("langchain_")
    assert pkg.islower()
    assert "-" not in pkg

0 comments on commit 3a75d7a

Please sign in to comment.