Skip to content

Commit

Permalink
Tweak ergonomics
Browse files Browse the repository at this point in the history
  • Loading branch information
anticorrelator committed Aug 16, 2024
1 parent 9d23413 commit 151af78
Showing 1 changed file with 34 additions and 19 deletions.
53 changes: 34 additions & 19 deletions packages/phoenix-otel/src/phoenix/otel/otel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
import warnings
from typing import Any, Optional
from typing import Any, Dict, Optional
from urllib.parse import ParseResult, urlparse

from openinference.semconv.resource import ResourceAttributes as _ResourceAttributes
Expand All @@ -24,7 +24,12 @@


def register(
endpoint: Optional[str] = None, project_name: Optional[str] = None, batch: bool = False
*,
endpoint: Optional[str] = None,
project_name: Optional[str] = None,
batch: bool = False,
set_global=True,
headers=None,
) -> _TracerProvider:
"""
Globally sets an OpenTelemetry TracerProvider for enabling OpenInference tracing.
Expand All @@ -49,11 +54,13 @@ def register(
tracer_provider = _TracerProvider(resource=resource)
span_processor: SpanProcessor
if batch:
span_processor = BatchSpanProcessor(endpoint=endpoint)
span_processor = BatchSpanProcessor(endpoint=endpoint, headers=headers)
else:
span_processor = SimpleSpanProcessor(endpoint=endpoint)
span_processor = SimpleSpanProcessor(endpoint=endpoint, headers=headers)
tracer_provider.add_span_processor(span_processor)
trace_api.set_tracer_provider(tracer_provider)

if set_global:
trace_api.set_tracer_provider(tracer_provider)
return tracer_provider


Expand Down Expand Up @@ -99,38 +106,48 @@ def add_span_processor(self, *args: Any, **kwargs: Any) -> None:


class SimpleSpanProcessor(_SimpleSpanProcessor):
def __init__(self, exporter: Optional[SpanExporter] = None, endpoint: Optional[str] = None):
def __init__(
self,
exporter: Optional[SpanExporter] = None,
endpoint: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
):
if exporter is None:
endpoint = endpoint or get_env_collector_endpoint()
parsed_url = urlparse(endpoint)
assert isinstance(parsed_url, ParseResult)
if _maybe_http_endpoint(parsed_url):
print("Exporting spans via HTTP.")
exporter = HTTPSpanExporter(endpoint=endpoint)
exporter = HTTPSpanExporter(endpoint=endpoint, headers=headers)
elif _maybe_grpc_endpoint(parsed_url):
print("Exporting spans via GRPC.")
exporter = GRPCSpanExporter(endpoint=endpoint)
exporter = GRPCSpanExporter(endpoint=endpoint, headers=headers)
else:
warnings.warn("Could not infer collector endpoint protocol, defaulting to HTTP.")
exporter = HTTPSpanExporter(endpoint=endpoint)
exporter = HTTPSpanExporter(endpoint=endpoint, headers=headers)
super().__init__(exporter)


class BatchSpanProcessor(_BatchSpanProcessor):
def __init__(self, exporter: Optional[SpanExporter] = None, endpoint: Optional[str] = None):
def __init__(
self,
exporter: Optional[SpanExporter] = None,
endpoint: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
):
if exporter is None:
endpoint = endpoint or get_env_collector_endpoint()
parsed_url = urlparse(endpoint)
assert isinstance(parsed_url, ParseResult)
if _maybe_http_endpoint(parsed_url):
print("Exporting spans via HTTP.")
exporter = HTTPSpanExporter(endpoint=endpoint)
exporter = HTTPSpanExporter(endpoint=endpoint, headers=headers)
elif _maybe_grpc_endpoint(parsed_url):
print("Exporting spans via GRPC.")
exporter = GRPCSpanExporter(endpoint=endpoint)
exporter = GRPCSpanExporter(endpoint=endpoint, headers=headers)
else:
warnings.warn("Could not infer collector endpoint protocol, defaulting to HTTP.")
exporter = HTTPSpanExporter(endpoint=endpoint)
exporter = HTTPSpanExporter(endpoint=endpoint, headers=headers)
super().__init__(exporter)


Expand All @@ -140,9 +157,8 @@ def __init__(self, *args: Any, **kwargs: Any):
bound_args = sig.bind_partial(*args, **kwargs)
bound_args.apply_defaults()

phoenix_headers = get_env_client_headers()
if phoenix_headers:
bound_args.arguments["headers"] = phoenix_headers
if not bound_args.arguments.get("headers"):
bound_args.arguments["headers"] = get_env_client_headers()

if bound_args.arguments.get("endpoint") is None:
bound_args.arguments["endpoint"] = get_env_collector_endpoint()
Expand All @@ -155,9 +171,8 @@ def __init__(self, *args: Any, **kwargs: Any):
bound_args = sig.bind_partial(*args, **kwargs)
bound_args.apply_defaults()

phoenix_headers = get_env_client_headers()
if phoenix_headers:
bound_args.arguments["headers"] = phoenix_headers
if not bound_args.arguments.get("headers"):
bound_args.arguments["headers"] = get_env_client_headers()

if bound_args.arguments.get("endpoint") is None:
bound_args.arguments["endpoint"] = get_env_collector_endpoint()
Expand Down

0 comments on commit 151af78

Please sign in to comment.