|
63 | 63 | AgentExecutor = None |
64 | 64 |
|
65 | 65 |
|
| 66 | +# Conditional imports for embeddings providers |
| 67 | +try: |
| 68 | + from langchain_openai import OpenAIEmbeddings # type: ignore[import-not-found] |
| 69 | +except ImportError: |
| 70 | + OpenAIEmbeddings = None |
| 71 | + |
| 72 | +try: |
| 73 | + from langchain_openai import AzureOpenAIEmbeddings |
| 74 | +except ImportError: |
| 75 | + AzureOpenAIEmbeddings = None |
| 76 | + |
| 77 | +try: |
| 78 | + from langchain_google_vertexai import VertexAIEmbeddings # type: ignore[import-not-found] |
| 79 | +except ImportError: |
| 80 | + VertexAIEmbeddings = None |
| 81 | + |
| 82 | +try: |
| 83 | + from langchain_aws import BedrockEmbeddings # type: ignore[import-not-found] |
| 84 | +except ImportError: |
| 85 | + BedrockEmbeddings = None |
| 86 | + |
| 87 | +try: |
| 88 | + from langchain_cohere import CohereEmbeddings # type: ignore[import-not-found] |
| 89 | +except ImportError: |
| 90 | + CohereEmbeddings = None |
| 91 | + |
| 92 | +try: |
| 93 | + from langchain_mistralai import MistralAIEmbeddings # type: ignore[import-not-found] |
| 94 | +except ImportError: |
| 95 | + MistralAIEmbeddings = None |
| 96 | + |
| 97 | +try: |
| 98 | + from langchain_huggingface import HuggingFaceEmbeddings # type: ignore[import-not-found] |
| 99 | +except ImportError: |
| 100 | + HuggingFaceEmbeddings = None |
| 101 | + |
| 102 | +try: |
| 103 | + from langchain_ollama import OllamaEmbeddings # type: ignore[import-not-found] |
| 104 | +except ImportError: |
| 105 | + OllamaEmbeddings = None |
| 106 | + |
| 107 | + |
66 | 108 | DATA_FIELDS = { |
67 | 109 | "frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY, |
68 | 110 | "function_call": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS, |
@@ -140,6 +182,16 @@ def setup_once(): |
140 | 182 | AgentExecutor.invoke = _wrap_agent_executor_invoke(AgentExecutor.invoke) |
141 | 183 | AgentExecutor.stream = _wrap_agent_executor_stream(AgentExecutor.stream) |
142 | 184 |
|
| 185 | + # Patch embeddings providers |
| 186 | + _patch_embeddings_provider(OpenAIEmbeddings) |
| 187 | + _patch_embeddings_provider(AzureOpenAIEmbeddings) |
| 188 | + _patch_embeddings_provider(VertexAIEmbeddings) |
| 189 | + _patch_embeddings_provider(BedrockEmbeddings) |
| 190 | + _patch_embeddings_provider(CohereEmbeddings) |
| 191 | + _patch_embeddings_provider(MistralAIEmbeddings) |
| 192 | + _patch_embeddings_provider(HuggingFaceEmbeddings) |
| 193 | + _patch_embeddings_provider(OllamaEmbeddings) |
| 194 | + |
143 | 195 |
|
144 | 196 | class WatchedSpan: |
145 | 197 | span = None # type: Span |
@@ -976,3 +1028,105 @@ async def new_iterator_async(): |
976 | 1028 | return result |
977 | 1029 |
|
978 | 1030 | return new_stream |
| 1031 | + |
| 1032 | + |
def _patch_embeddings_provider(provider_class):
    # type: (Any) -> None
    """Patch an embeddings provider class with monitoring wrappers.

    No-op when the provider's package is not installed (the conditional
    import left the name as ``None``). Each embedding entry point that the
    class exposes is replaced in place with a span-creating wrapper;
    methods the class does not define are skipped.
    """
    if provider_class is None:
        return

    # Map each embedding entry point to the wrapper factory that matches
    # its sync/async flavor, instead of four copy-pasted hasattr stanzas.
    _WRAPPERS = (
        ("embed_documents", _wrap_embedding_method),
        ("embed_query", _wrap_embedding_method),
        ("aembed_documents", _wrap_async_embedding_method),
        ("aembed_query", _wrap_async_embedding_method),
    )
    for method_name, wrapper_factory in _WRAPPERS:
        original = getattr(provider_class, method_name, None)
        if original is not None:
            setattr(provider_class, method_name, wrapper_factory(original))
| 1053 | + |
| 1054 | + |
def _wrap_embedding_method(f):
    # type: (Callable[..., Any]) -> Callable[..., Any]
    """Wrap sync embedding methods (embed_documents and embed_query).

    The returned wrapper opens a ``gen_ai.embeddings`` span around the
    call, recording the model name and — when PII capture is enabled —
    the input text(s). When the integration is not active, the original
    method is called through untouched.
    """

    @wraps(f)
    def new_embedding_method(self, *args, **kwargs):
        # type: (Any, Any, Any) -> Any
        integration = sentry_sdk.get_client().get_integration(LangchainIntegration)
        if integration is None:
            # Integration disabled: no tracing overhead.
            return f(self, *args, **kwargs)

        # Providers expose the model id as either `model` or `model_name`.
        model_name = getattr(self, "model", None) or getattr(self, "model_name", None)
        with sentry_sdk.start_span(
            op=OP.GEN_AI_EMBEDDINGS,
            name=f"embeddings {model_name}" if model_name else "embeddings",
            origin=LangchainIntegration.origin,
        ) as span:
            span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
            if model_name:
                span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name)

            # Capture input if PII is allowed. The input may arrive
            # positionally or as a keyword argument (`texts` for
            # embed_documents, `text` for embed_query).
            if should_send_default_pii() and integration.include_prompts:
                if args:
                    input_data = args[0]
                else:
                    input_data = kwargs.get("texts", kwargs.get("text"))
                if input_data is not None:
                    # Normalize to list format.
                    texts = (
                        input_data if isinstance(input_data, list) else [input_data]
                    )
                    set_data_normalized(
                        span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, texts, unpack=False
                    )

            return f(self, *args, **kwargs)

    return new_embedding_method
| 1093 | + |
| 1094 | + |
def _wrap_async_embedding_method(f):
    # type: (Callable[..., Any]) -> Callable[..., Any]
    """Wrap async embedding methods (aembed_documents and aembed_query).

    Async twin of ``_wrap_embedding_method``: opens a ``gen_ai.embeddings``
    span around the awaited call, recording the model name and — when PII
    capture is enabled — the input text(s). When the integration is not
    active, the original coroutine method is awaited untouched.
    """

    @wraps(f)
    async def new_async_embedding_method(self, *args, **kwargs):
        # type: (Any, Any, Any) -> Any
        integration = sentry_sdk.get_client().get_integration(LangchainIntegration)
        if integration is None:
            # Integration disabled: no tracing overhead.
            return await f(self, *args, **kwargs)

        # Providers expose the model id as either `model` or `model_name`.
        model_name = getattr(self, "model", None) or getattr(self, "model_name", None)
        with sentry_sdk.start_span(
            op=OP.GEN_AI_EMBEDDINGS,
            name=f"embeddings {model_name}" if model_name else "embeddings",
            origin=LangchainIntegration.origin,
        ) as span:
            span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
            if model_name:
                span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name)

            # Capture input if PII is allowed. The input may arrive
            # positionally or as a keyword argument (`texts` for
            # aembed_documents, `text` for aembed_query).
            if should_send_default_pii() and integration.include_prompts:
                if args:
                    input_data = args[0]
                else:
                    input_data = kwargs.get("texts", kwargs.get("text"))
                if input_data is not None:
                    # Normalize to list format.
                    texts = (
                        input_data if isinstance(input_data, list) else [input_data]
                    )
                    set_data_normalized(
                        span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, texts, unpack=False
                    )

            return await f(self, *args, **kwargs)

    return new_async_embedding_method
0 commit comments