Skip to content

Commit

Permalink
fix: Pass the project ID from vertexai.init to CloudTraceSpanExporter…
Browse files Browse the repository at this point in the history
… when enable_tracing=True for LangchainAgent

PiperOrigin-RevId: 652617521
  • Loading branch information
yeesian authored and copybara-github committed Jul 15, 2024
1 parent ecc4f09 commit 3ec043e
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 12 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@

reasoning_engine_extra_require = [
"cloudpickle >= 3.0, < 4.0",
"google-cloud-trace < 2",
"opentelemetry-sdk < 2",
"opentelemetry-exporter-gcp-trace < 2",
"pydantic >= 2.6.3, < 3",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from unittest import mock

from google import auth
from google.auth import credentials as auth_credentials
import vertexai
from google.cloud.aiplatform import initializer
from vertexai.preview import reasoning_engines
Expand Down Expand Up @@ -65,8 +64,10 @@ def place_photo_query(
@pytest.fixture(scope="module")
def google_auth_mock():
with mock.patch.object(auth, "default") as google_auth_mock:
credentials_mock = mock.Mock()
credentials_mock.with_quota_project.return_value = None
google_auth_mock.return_value = (
auth_credentials.AnonymousCredentials(),
credentials_mock,
_TEST_PROJECT,
)
yield google_auth_mock
Expand Down
28 changes: 18 additions & 10 deletions vertexai/preview/reasoning_engines/templates/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,30 +397,38 @@ def set_up(self):
from vertexai.reasoning_engines import _utils

cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
openinference_langchain = _utils._import_openinference_langchain_or_warn()
opentelemetry = _utils._import_opentelemetry_or_warn()
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
if all(
(
cloud_trace_exporter,
cloud_trace_v2,
openinference_langchain,
opentelemetry,
opentelemetry_sdk_trace,
)
):
import google.auth

credentials, _ = google.auth.default()
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
project_id=self._project,
client=cloud_trace_v2.TraceServiceClient(
credentials=credentials.with_quota_project(self._project),
),
)
span_processor = opentelemetry_sdk_trace.export.SimpleSpanProcessor(
span_exporter=span_exporter,
)
tracer_provider = opentelemetry.trace.get_tracer_provider()
if tracer_provider and _utils._is_noop_tracer_provider(tracer_provider):
# Set a trace provider if it has not been set.
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
project_id=self._project,
)
span_processor = opentelemetry_sdk_trace.export.SimpleSpanProcessor(
span_exporter=span_exporter,
)
tracer_provider = opentelemetry_sdk_trace.TracerProvider(
active_span_processor=span_processor,
)
# Avoids AttributeError: 'ProxyTracerProvider' object has no
# attribute 'add_span_processor'
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
opentelemetry.trace.set_tracer_provider(tracer_provider)
tracer_provider.add_span_processor(span_processor)
self._instrumentor = openinference_langchain.LangChainInstrumentor()
self._instrumentor.instrument()
else:
Expand Down
14 changes: 14 additions & 0 deletions vertexai/reasoning_engines/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,20 @@ def _import_opentelemetry_sdk_trace_or_warn() -> Optional[types.ModuleType]:
return None


def _import_cloud_trace_v2_or_warn() -> Optional[types.ModuleType]:
"""Tries to import the google.cloud.trace_v2 module."""
try:
import google.cloud.trace_v2

return google.cloud.trace_v2
except ImportError:
_LOGGER.warning(
"google-cloud-trace is not installed. Please call "
"'pip install google-cloud-aiplatform[reasoningengine]'."
)
return None


def _import_cloud_trace_exporter_or_warn() -> Optional[types.ModuleType]:
"""Tries to import the opentelemetry.exporter.cloud_trace module."""
try:
Expand Down

0 comments on commit 3ec043e

Please sign in to comment.