Skip to content

Commit

Permalink
[OPIK-444] [SDK] Update OpenAI integration to support logging model a…
Browse files Browse the repository at this point in the history
…nd provider in a new format (#766)

* refactor SpanData creation

* add fields to StartSpanParameters

* add `provider` attribute to base decorator

* handle `provider` and `model` attributes in openai track decorator

* fix linter warnings

* fix linter warnings

* update e2e tests

* remove span response payload workaround

* move handling of distributed_trace_headers to _create_span()

* move _create_span() to helpers

* use only hostname during openai provider detection

* use less strict check for openai model version in tests

* fix linter warnings
  • Loading branch information
japdubengsub authored Dec 2, 2024
1 parent 6ecd7bc commit 330ea28
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 96 deletions.
16 changes: 1 addition & 15 deletions sdks/python/src/opik/api_objects/opik_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,21 +632,7 @@ def get_span_content(self, id: str) -> span_public.SpanPublic:
span_public.SpanPublic: pydantic model object with all the data associated with the span found.
Raises an error if span was not found.
"""
result = self._rest_client.spans.get_span_by_id(id)

# fixme temporary fix for wrong response payload
# because span_public.SpanPublic is frozen we will create a copy and update it
new_values: Dict[str, Any] = {}

if result.model == "":
new_values["model"] = None
if result.provider == "":
new_values["provider"] = None

if len(new_values) > 0:
result = result.model_copy(update=new_values)

return result
return self._rest_client.spans.get_span_by_id(id)

def get_project(self, id: str) -> project_public.ProjectPublic:
"""
Expand Down
27 changes: 27 additions & 0 deletions sdks/python/src/opik/decorator/arguments_helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Optional, Any, Dict, List, Callable

from .. import datetime_helpers
from ..api_objects import helpers, span
from ..types import SpanType

import dataclasses
Expand Down Expand Up @@ -45,6 +48,8 @@ class StartSpanParameters(BaseArguments):
metadata: Optional[Dict[str, Any]] = None
input: Optional[Dict[str, Any]] = None
project_name: Optional[str] = None
model: Optional[str] = None
provider: Optional[str] = None


@dataclasses.dataclass
Expand All @@ -63,3 +68,25 @@ class TrackOptions(BaseArguments):
generations_aggregator: Optional[Callable[[List[Any]], Any]]
flush: bool
project_name: Optional[str]


def create_span_data(
start_span_arguments: StartSpanParameters,
trace_id: str,
parent_span_id: Optional[str] = None,
) -> span.SpanData:
span_data = span.SpanData(
id=helpers.generate_id(),
parent_span_id=parent_span_id,
trace_id=trace_id,
start_time=datetime_helpers.local_timestamp(),
name=start_span_arguments.name,
type=start_span_arguments.type,
tags=start_span_arguments.tags,
metadata=start_span_arguments.metadata,
input=start_span_arguments.input,
project_name=start_span_arguments.project_name,
model=start_span_arguments.model,
provider=start_span_arguments.provider,
)
return span_data
86 changes: 30 additions & 56 deletions sdks/python/src/opik/decorator/base_track_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class BaseTrackDecorator(abc.ABC):
Overriding other methods of this class is not recommended.
"""

def __init__(self) -> None:
self.provider: Optional[str] = None
""" Name of the LLM provider. Used in subclasses in integrations track decorators. """

def track(
self,
name: Optional[Union[Callable, str]] = None,
Expand Down Expand Up @@ -239,12 +243,10 @@ def _before_call(
kwargs=kwargs,
)

if opik_distributed_trace_headers is None:
self._create_span(start_span_arguments)
else:
self._create_distributed_node_root_span(
start_span_arguments, opik_distributed_trace_headers
)
self._create_span(
start_span_arguments,
opik_distributed_trace_headers,
)

except Exception as exception:
LOGGER.error(
Expand All @@ -256,17 +258,28 @@ def _before_call(
)

def _create_span(
self, start_span_arguments: arguments_helpers.StartSpanParameters
self,
start_span_arguments: arguments_helpers.StartSpanParameters,
distributed_trace_headers: Optional[DistributedTraceHeadersDict] = None,
) -> None:
"""
Handles different span creation flows.
"""
current_span_data = context_storage.top_span_data()
current_trace_data = context_storage.get_trace_data()

span_data: span.SpanData
trace_data: trace.TraceData

if distributed_trace_headers:
span_data = arguments_helpers.create_span_data(
start_span_arguments=start_span_arguments,
parent_span_id=distributed_trace_headers["opik_parent_span_id"],
trace_id=distributed_trace_headers["opik_trace_id"],
)
context_storage.add_span_data(span_data)
return

current_span_data = context_storage.top_span_data()
current_trace_data = context_storage.get_trace_data()

if current_span_data is not None:
# There is already at least one span in current context.
# Simply attach a new span to it.
Expand All @@ -280,17 +293,10 @@ def _create_span(

start_span_arguments.project_name = project_name

span_data = span.SpanData(
id=helpers.generate_id(),
span_data = arguments_helpers.create_span_data(
start_span_arguments=start_span_arguments,
parent_span_id=current_span_data.id,
trace_id=current_span_data.trace_id,
start_time=datetime_helpers.local_timestamp(),
name=start_span_arguments.name,
type=start_span_arguments.type,
tags=start_span_arguments.tags,
metadata=start_span_arguments.metadata,
input=start_span_arguments.input,
project_name=start_span_arguments.project_name,
)
context_storage.add_span_data(span_data)
return
Expand All @@ -309,17 +315,10 @@ def _create_span(

start_span_arguments.project_name = project_name

span_data = span.SpanData(
id=helpers.generate_id(),
span_data = arguments_helpers.create_span_data(
start_span_arguments=start_span_arguments,
parent_span_id=None,
trace_id=current_trace_data.id,
start_time=datetime_helpers.local_timestamp(),
name=start_span_arguments.name,
type=start_span_arguments.type,
tags=start_span_arguments.tags,
metadata=start_span_arguments.metadata,
input=start_span_arguments.input,
project_name=start_span_arguments.project_name,
)
context_storage.add_span_data(span_data)
return
Expand All @@ -338,41 +337,16 @@ def _create_span(
)
TRACES_CREATED_BY_DECORATOR.add(trace_data.id)

span_data = span.SpanData(
id=helpers.generate_id(),
span_data = arguments_helpers.create_span_data(
start_span_arguments=start_span_arguments,
parent_span_id=None,
trace_id=trace_data.id,
start_time=datetime_helpers.local_timestamp(),
name=start_span_arguments.name,
type=start_span_arguments.type,
tags=start_span_arguments.tags,
metadata=start_span_arguments.metadata,
input=start_span_arguments.input,
project_name=start_span_arguments.project_name,
)

context_storage.set_trace_data(trace_data)
context_storage.add_span_data(span_data)
return

def _create_distributed_node_root_span(
self,
start_span_arguments: arguments_helpers.StartSpanParameters,
distributed_trace_headers: DistributedTraceHeadersDict,
) -> None:
span_data = span.SpanData(
id=helpers.generate_id(),
parent_span_id=distributed_trace_headers["opik_parent_span_id"],
trace_id=distributed_trace_headers["opik_trace_id"],
name=start_span_arguments.name,
input=start_span_arguments.input,
metadata=start_span_arguments.metadata,
tags=start_span_arguments.tags,
type=start_span_arguments.type,
project_name=start_span_arguments.project_name,
)
context_storage.add_span_data(span_data)

def _after_call(
self,
output: Optional[Any],
Expand Down Expand Up @@ -438,7 +412,7 @@ def _generators_handler(
However, sometimes the function might return an instance of some specific class which
is not a python generator itself, but implements some API for iterating through data chunks.
In that case `_generators_handler` must be fully overriden in the subclass.
In that case `_generators_handler` must be fully overridden in the subclass.
This is usually the case when creating an integration with some LLM library.
"""
Expand Down
9 changes: 9 additions & 0 deletions sdks/python/src/opik/integrations/openai/openai_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class OpenaiTrackDecorator(base_track_decorator.BaseTrackDecorator):
openai.Stream and openai.AsyncStream objects.
"""

def __init__(self) -> None:
super().__init__()
self.provider = "openai"

def _start_span_inputs_preprocessor(
self,
func: Callable,
Expand Down Expand Up @@ -77,6 +81,8 @@ def _start_span_inputs_preprocessor(
tags=tags,
metadata=metadata,
project_name=track_options.project_name,
model=kwargs.get("model", None),
provider=self.provider,
)

return result
Expand All @@ -92,11 +98,14 @@ def _end_span_inputs_preprocessor(
result_dict = output.model_dump(mode="json")
output, metadata = dict_utils.split_dict_by_keys(result_dict, ["choices"])
usage = result_dict["usage"]
model = result_dict["model"]

result = arguments_helpers.EndSpanParameters(
output=output,
usage=usage,
metadata=metadata,
model=model,
provider=self.provider,
)

return result
Expand Down
3 changes: 3 additions & 0 deletions sdks/python/src/opik/integrations/openai/opik_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def track_openai(

decorator_factory = openai_decorator.OpenaiTrackDecorator()

if openai_client.base_url.host != "api.openai.com":
decorator_factory.provider = openai_client.base_url.host

completions_create_decorator = decorator_factory.track(
type="llm",
name="chat_completion_create",
Expand Down
Loading

0 comments on commit 330ea28

Please sign in to comment.