diff --git a/packages/phoenix-otel/src/phoenix/otel/otel.py b/packages/phoenix-otel/src/phoenix/otel/otel.py index 04aae6048c..d746d13a63 100644 --- a/packages/phoenix-otel/src/phoenix/otel/otel.py +++ b/packages/phoenix-otel/src/phoenix/otel/otel.py @@ -1,6 +1,6 @@ import inspect import warnings -from typing import Any, Optional +from typing import Any, Dict, Optional from urllib.parse import ParseResult, urlparse from openinference.semconv.resource import ResourceAttributes as _ResourceAttributes @@ -24,7 +24,12 @@ def register( - endpoint: Optional[str] = None, project_name: Optional[str] = None, batch: bool = False + *, + endpoint: Optional[str] = None, + project_name: Optional[str] = None, + batch: bool = False, + set_global=True, + headers=None, ) -> _TracerProvider: """ Globally sets an OpenTelemetry TracerProvider for enabling OpenInference tracing. @@ -49,11 +54,13 @@ def register( tracer_provider = _TracerProvider(resource=resource) span_processor: SpanProcessor if batch: - span_processor = BatchSpanProcessor(endpoint=endpoint) + span_processor = BatchSpanProcessor(endpoint=endpoint, headers=headers) else: - span_processor = SimpleSpanProcessor(endpoint=endpoint) + span_processor = SimpleSpanProcessor(endpoint=endpoint, headers=headers) tracer_provider.add_span_processor(span_processor) - trace_api.set_tracer_provider(tracer_provider) + + if set_global: + trace_api.set_tracer_provider(tracer_provider) return tracer_provider @@ -99,38 +106,48 @@ def add_span_processor(self, *args: Any, **kwargs: Any) -> None: class SimpleSpanProcessor(_SimpleSpanProcessor): - def __init__(self, exporter: Optional[SpanExporter] = None, endpoint: Optional[str] = None): + def __init__( + self, + exporter: Optional[SpanExporter] = None, + endpoint: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + ): if exporter is None: endpoint = endpoint or get_env_collector_endpoint() parsed_url = urlparse(endpoint) assert isinstance(parsed_url, ParseResult) if _maybe_http_endpoint(parsed_url): print("Exporting spans via HTTP.") - exporter = HTTPSpanExporter(endpoint=endpoint) + exporter = HTTPSpanExporter(endpoint=endpoint, headers=headers) elif _maybe_grpc_endpoint(parsed_url): print("Exporting spans via GRPC.") - exporter = GRPCSpanExporter(endpoint=endpoint) + exporter = GRPCSpanExporter(endpoint=endpoint, headers=headers) else: warnings.warn("Could not infer collector endpoint protocol, defaulting to HTTP.") - exporter = HTTPSpanExporter(endpoint=endpoint) + exporter = HTTPSpanExporter(endpoint=endpoint, headers=headers) super().__init__(exporter) class BatchSpanProcessor(_BatchSpanProcessor): - def __init__(self, exporter: Optional[SpanExporter] = None, endpoint: Optional[str] = None): + def __init__( + self, + exporter: Optional[SpanExporter] = None, + endpoint: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + ): if exporter is None: endpoint = endpoint or get_env_collector_endpoint() parsed_url = urlparse(endpoint) assert isinstance(parsed_url, ParseResult) if _maybe_http_endpoint(parsed_url): print("Exporting spans via HTTP.") - exporter = HTTPSpanExporter(endpoint=endpoint) + exporter = HTTPSpanExporter(endpoint=endpoint, headers=headers) elif _maybe_grpc_endpoint(parsed_url): print("Exporting spans via GRPC.") - exporter = GRPCSpanExporter(endpoint=endpoint) + exporter = GRPCSpanExporter(endpoint=endpoint, headers=headers) else: warnings.warn("Could not infer collector endpoint protocol, defaulting to HTTP.") - exporter = HTTPSpanExporter(endpoint=endpoint) + exporter = HTTPSpanExporter(endpoint=endpoint, headers=headers) super().__init__(exporter) @@ -140,9 +157,8 @@ def __init__(self, *args: Any, **kwargs: Any): bound_args = sig.bind_partial(*args, **kwargs) bound_args.apply_defaults() - phoenix_headers = get_env_client_headers() - if phoenix_headers: - bound_args.arguments["headers"] = phoenix_headers + if not bound_args.arguments.get("headers"): + bound_args.arguments["headers"] = get_env_client_headers() if bound_args.arguments.get("endpoint") is None: bound_args.arguments["endpoint"] = get_env_collector_endpoint() @@ -155,9 +171,8 @@ def __init__(self, *args: Any, **kwargs: Any): bound_args = sig.bind_partial(*args, **kwargs) bound_args.apply_defaults() - phoenix_headers = get_env_client_headers() - if phoenix_headers: - bound_args.arguments["headers"] = phoenix_headers + if not bound_args.arguments.get("headers"): + bound_args.arguments["headers"] = get_env_client_headers() if bound_args.arguments.get("endpoint") is None: bound_args.arguments["endpoint"] = get_env_collector_endpoint()