Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: eliminate interference on global tracer provider #2998

Merged
merged 5 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = [
"starlette",
"uvicorn",
"psutil",
"strawberry-graphql==0.208.2",
"strawberry-graphql==0.227.2", # need to pin version because we're monkey-patching
"pyarrow",
"typing-extensions>=4.5; python_version<'3.12'",
# A minimum version of typing-extensions==4.6.0 is needed to avoid this issue on Python 3.12: https://github.com/Azure/azure-sdk-for-python/issues/33442#issuecomment-1847886784
Expand Down Expand Up @@ -73,7 +73,7 @@ dev = [
"pytest-postgresql",
"asyncpg",
"psycopg[binary]",
"strawberry-graphql[debug-server]==0.208.2",
"strawberry-graphql[debug-server,opentelemetry]==0.227.2", # need to pin version because we're monkey-patching
"pre-commit",
"arize[AutoEmbeddings, LLM_Evaluation]",
"llama-index>=0.10.3",
Expand Down Expand Up @@ -104,7 +104,7 @@ container = [
"opentelemetry-exporter-otlp",
"opentelemetry-instrumentation-starlette",
"opentelemetry-instrumentation-sqlalchemy",
"strawberry-graphql[opentelemetry]",
"strawberry-graphql[opentelemetry]==0.227.2", # need to pin version because we're monkey-patching
]

[project.urls]
Expand Down Expand Up @@ -175,7 +175,7 @@ dependencies = [
"opentelemetry-exporter-otlp",
"opentelemetry-instrumentation-starlette",
"opentelemetry-instrumentation-sqlalchemy",
"strawberry-graphql[opentelemetry]",
"strawberry-graphql[opentelemetry]==0.227.2", # need to pin version because we're monkey-patching
]

[tool.hatch.envs.style]
Expand Down Expand Up @@ -274,7 +274,7 @@ check = [

[tool.hatch.envs.gql]
dependencies = [
"strawberry-graphql[cli]==0.208.2",
"strawberry-graphql[cli]==0.227.2", # need to pin version because we're monkey-patching
"requests",
]

Expand Down
2 changes: 1 addition & 1 deletion src/phoenix/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
)


def is_server_instrumentation_enabled() -> bool:
def server_instrumentation_is_enabled() -> bool:
return bool(
os.getenv(ENV_PHOENIX_SERVER_INSTRUMENTATION_OTLP_TRACE_COLLECTOR_HTTP_ENDPOINT)
) or bool(os.getenv(ENV_PHOENIX_SERVER_INSTRUMENTATION_OTLP_TRACE_COLLECTOR_GRPC_ENDPOINT))
Expand Down
13 changes: 5 additions & 8 deletions src/phoenix/server/api/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from strawberry.types import Info
from typing_extensions import Annotated

from phoenix.config import DEFAULT_PROJECT_NAME, is_server_instrumentation_enabled
from phoenix.config import DEFAULT_PROJECT_NAME
from phoenix.db import models
from phoenix.pointcloud.clustering import Hdbscan
from phoenix.server.api.context import Context
Expand Down Expand Up @@ -272,14 +272,11 @@ async def clear_project(self, info: Info[Context, None], id: GlobalID) -> Query:
return Query()


_extensions = []
if is_server_instrumentation_enabled():
from strawberry.extensions.tracing import OpenTelemetryExtension

_extensions.append(OpenTelemetryExtension)

# This is the schema for generating `schema.graphql`.
# See https://strawberry.rocks/docs/guides/schema-export
# It should be kept in sync with the server's runtime-initialized
# instance. To do so, search for the usage of `strawberry.Schema(...)`.
schema = strawberry.Schema(
query=Query,
mutation=Mutation,
extensions=_extensions,
)
49 changes: 40 additions & 9 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import datetime
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
AsyncIterator,
Expand All @@ -13,8 +14,10 @@
Optional,
Tuple,
Union,
cast,
)

import strawberry
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
Expand Down Expand Up @@ -42,7 +45,7 @@
from phoenix.config import (
DEFAULT_PROJECT_NAME,
SERVER_DIR,
is_server_instrumentation_enabled,
server_instrumentation_is_enabled,
)
from phoenix.core.model_schema import Model
from phoenix.db.bulk_inserter import BulkInserter
Expand All @@ -60,6 +63,7 @@
from phoenix.server.api.dataloaders.span_descendants import SpanDescendantsDataLoader
from phoenix.server.api.routers.v1 import V1_ROUTES
from phoenix.server.api.schema import schema
from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider
from phoenix.trace.schemas import Span

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -259,21 +263,45 @@ def create_app(
""
)
raise PhoenixMigrationError(msg) from e

if is_server_instrumentation_enabled():
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor

SQLAlchemyInstrumentor().instrument(engine=engine.sync_engine)

db = _db(engine)
bulk_inserter = BulkInserter(
db,
initial_batch_of_spans=initial_batch_of_spans,
initial_batch_of_evaluations=initial_batch_of_evaluations,
)
tracer_provider = None
strawberry_extensions = schema.get_extensions()
if server_instrumentation_is_enabled():
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from opentelemetry.trace import TracerProvider
from strawberry.extensions.tracing import OpenTelemetryExtension

tracer_provider = initialize_opentelemetry_tracer_provider()
SQLAlchemyInstrumentor().instrument(
engine=engine.sync_engine,
tracer_provider=tracer_provider,
)
if TYPE_CHECKING:
# Type-check the class before monkey-patching its private attribute.
assert OpenTelemetryExtension._tracer

class _OpenTelemetryExtension(OpenTelemetryExtension):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
# Monkey-patch its private tracer to eliminate usage of the global
# TracerProvider, which in a notebook setting could be the one
# used by OpenInference.
self._tracer = cast(TracerProvider, tracer_provider).get_tracer("strawberry")

strawberry_extensions.append(_OpenTelemetryExtension)
graphql = GraphQLWithContext(
db=db,
schema=schema,
schema=strawberry.Schema(
query=schema.query,
mutation=schema.mutation,
subscription=schema.subscription,
extensions=strawberry_extensions,
),
model=model,
corpus=corpus,
export_path=export_path,
Expand All @@ -286,7 +314,6 @@ def create_app(
prometheus_middlewares = [Middleware(PrometheusMiddleware)]
else:
prometheus_middlewares = []

app = Starlette(
lifespan=_lifespan(bulk_inserter),
middleware=[
Expand Down Expand Up @@ -329,4 +356,8 @@ def create_app(
)
app.state.read_only = read_only
app.state.db = db
if tracer_provider:
from opentelemetry.instrumentation.starlette import StarletteInstrumentor

StarletteInstrumentor.instrument_app(app, tracer_provider=tracer_provider)
return app
7 changes: 0 additions & 7 deletions src/phoenix/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
get_env_port,
get_pids_path,
get_working_dir,
is_server_instrumentation_enabled,
)
from phoenix.core.model_schema_adapter import create_model_from_datasets
from phoenix.db import get_printable_db_url
Expand All @@ -31,7 +30,6 @@
UMAPParameters,
)
from phoenix.server.app import create_app
from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider
from phoenix.settings import Settings
from phoenix.trace.fixtures import (
TRACES_FIXTURES,
Expand Down Expand Up @@ -230,11 +228,6 @@ def _get_pid_file() -> Path:
initial_spans=fixture_spans,
initial_evaluations=fixture_evals,
)
if is_server_instrumentation_enabled():
from opentelemetry.instrumentation.starlette import StarletteInstrumentor

initialize_opentelemetry_tracer_provider()
StarletteInstrumentor.instrument_app(app)
server = Server(config=Config(app, host=host, port=port))
Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start()

Expand Down
9 changes: 6 additions & 3 deletions src/phoenix/server/telemetry.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import os
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from opentelemetry.trace import TracerProvider

from phoenix.config import (
ENV_PHOENIX_SERVER_INSTRUMENTATION_OTLP_TRACE_COLLECTOR_GRPC_ENDPOINT,
ENV_PHOENIX_SERVER_INSTRUMENTATION_OTLP_TRACE_COLLECTOR_HTTP_ENDPOINT,
)


def initialize_opentelemetry_tracer_provider() -> None:
from opentelemetry import trace as trace_api
def initialize_opentelemetry_tracer_provider() -> "TracerProvider":
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.trace.export import BatchSpanProcessor

Expand All @@ -28,4 +31,4 @@ def initialize_opentelemetry_tracer_provider() -> None:
)

tracer_provider.add_span_processor(BatchSpanProcessor(GrpcExporter(grpc_endpoint)))
trace_api.set_tracer_provider(tracer_provider)
return tracer_provider
Loading