Skip to content

Commit

Permalink
feat(crewAI): Added trace config, context attributes, suppress tracin…
Browse files Browse the repository at this point in the history
…g for CrewAI (#851)
  • Loading branch information
shreyabsridhar authored Aug 13, 2024
1 parent bda858a commit 4ad22fa
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ instruments = [
"crewai >= 0.41.1",
]
test = [
"crewai == 0.36.0",
"crewai == 0.41.1",
"crewai-tools == 0.4.26",
"opentelemetry-sdk",
"responses",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from importlib import import_module
from typing import Any, Collection

from openinference.instrumentation import (
OITracer,
TraceConfig,
)
from openinference.instrumentation.crewai._wrappers import (
_ExecuteCoreWrapper,
_KickoffWrapper,
Expand All @@ -22,6 +26,7 @@ class CrewAIInstrumentor(BaseInstrumentor): # type: ignore
"_original_execute_core",
"_original_kickoff",
"_original_tool_use",
"_tracer",
)

def instrumentation_dependencies(self) -> Collection[str]:
Expand All @@ -30,25 +35,32 @@ def instrumentation_dependencies(self) -> Collection[str]:
def _instrument(self, **kwargs: Any) -> None:
if not (tracer_provider := kwargs.get("tracer_provider")):
tracer_provider = trace_api.get_tracer_provider()
tracer = trace_api.get_tracer(__name__, __version__, tracer_provider)
if not (config := kwargs.get("config")):
config = TraceConfig()
else:
assert isinstance(config, TraceConfig)
self._tracer = OITracer(
trace_api.get_tracer(__name__, __version__, tracer_provider),
config=config,
)

execute_core_wrapper = _ExecuteCoreWrapper(tracer=tracer)
execute_core_wrapper = _ExecuteCoreWrapper(tracer=self._tracer)
self._original_execute_core = getattr(import_module("crewai").Task, "_execute_core", None)
wrap_function_wrapper(
module="crewai",
name="Task._execute_core",
wrapper=execute_core_wrapper,
)

kickoff_wrapper = _KickoffWrapper(tracer=tracer)
kickoff_wrapper = _KickoffWrapper(tracer=self._tracer)
self._original_kickoff = getattr(import_module("crewai").Crew, "kickoff", None)
wrap_function_wrapper(
module="crewai",
name="Crew.kickoff",
wrapper=kickoff_wrapper,
)

use_wrapper = _ToolUseWrapper(tracer=tracer)
use_wrapper = _ToolUseWrapper(tracer=self._tracer)
self._original_tool_use = getattr(
import_module("crewai.tools.tool_usage").ToolUsage, "_use", None
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from inspect import signature
from typing import Any, Callable, Iterator, List, Mapping, Optional, Tuple

from openinference.instrumentation import safe_json_dumps
from openinference.instrumentation import get_attributes_from_context, safe_json_dumps
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
from opentelemetry import context as context_api
from opentelemetry import trace as trace_api
from opentelemetry.util.types import AttributeValue

Expand Down Expand Up @@ -88,6 +89,8 @@ def __call__(
args: Tuple[Any, ...],
kwargs: Mapping[str, Any],
) -> Any:
if context_api.get_value(context_api._SUPPRESS_INSTRUMENTATION_KEY):
return wrapped(*args, **kwargs)
if instance:
span_name = f"{instance.__class__.__name__}.{wrapped.__name__}"
else:
Expand Down Expand Up @@ -130,6 +133,7 @@ def __call__(
raise
span.set_status(trace_api.StatusCode.OK)
span.set_attribute(OUTPUT_VALUE, response)
span.set_attributes(dict(get_attributes_from_context()))
return response


Expand All @@ -144,6 +148,8 @@ def __call__(
args: Tuple[Any, ...],
kwargs: Mapping[str, Any],
) -> Any:
if context_api.get_value(context_api._SUPPRESS_INSTRUMENTATION_KEY):
return wrapped(*args, **kwargs)
span_name = f"{instance.__class__.__name__}.kickoff"
with self._tracer.start_as_current_span(
span_name,
Expand Down Expand Up @@ -206,6 +212,7 @@ def __call__(
raise
span.set_status(trace_api.StatusCode.OK)
span.set_attribute(OUTPUT_VALUE, response)
span.set_attributes(dict(get_attributes_from_context()))
return response


Expand All @@ -220,6 +227,8 @@ def __call__(
args: Tuple[Any, ...],
kwargs: Mapping[str, Any],
) -> Any:
if context_api.get_value(context_api._SUPPRESS_INSTRUMENTATION_KEY):
return wrapped(*args, **kwargs)
if instance:
span_name = f"{instance.__class__.__name__}.{wrapped.__name__}"
else:
Expand Down Expand Up @@ -255,6 +264,7 @@ def __call__(
raise
span.set_status(trace_api.StatusCode.OK)
span.set_attribute(OUTPUT_VALUE, response)
span.set_attributes(dict(get_attributes_from_context()))
return response


Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from typing import Any, Generator
import json
from typing import Any, Dict, Generator, List, Mapping, cast

import pytest
from openinference.instrumentation import OITracer, using_attributes
from openinference.instrumentation.crewai import CrewAIInstrumentor
from openinference.semconv.trace import (
SpanAttributes,
)
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from opentelemetry.util.types import AttributeValue


@pytest.fixture()
Expand All @@ -30,9 +36,139 @@ def setup_crewai_instrumentation(
CrewAIInstrumentor().uninstrument()


# Ensure we're using the common OITracer from common opeinference-instrumentation pkg
def test_oitracer(
tracer_provider: TracerProvider,
in_memory_span_exporter: InMemorySpanExporter,
setup_crewai_instrumentation: Any,
) -> None:
in_memory_span_exporter.clear()
assert isinstance(CrewAIInstrumentor()._tracer, OITracer)


@pytest.mark.parametrize("use_context_attributes", [False, True])
def test_crewai_instrumentation(
tracer_provider: TracerProvider,
in_memory_span_exporter: InMemorySpanExporter,
setup_crewai_instrumentation: Any,
use_context_attributes: bool,
session_id: str,
user_id: str,
metadata: Dict[str, Any],
tags: List[str],
prompt_template: str,
prompt_template_version: str,
prompt_template_variables: Dict[str, Any],
) -> None:
pass
if use_context_attributes:
with using_attributes(
session_id=session_id,
user_id=user_id,
metadata=metadata,
tags=tags,
prompt_template=prompt_template,
prompt_template_version=prompt_template_version,
prompt_template_variables=prompt_template_variables,
):
return # For now, short-circuiting. Insert CrewAI function calls here
else:
return # For now, short-circuiting. Insert CrewAI function calls here

spans = in_memory_span_exporter.get_finished_spans()
assert len(spans) == 1
span = spans[0]
# Insert CrewAI testing logic here

# Check context attributes logic
attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
if use_context_attributes:
_check_context_attributes(
attributes,
session_id,
user_id,
metadata,
tags,
prompt_template,
prompt_template_version,
prompt_template_variables,
)


def _check_context_attributes(
attributes: Dict[str, Any],
session_id: str,
user_id: str,
metadata: Dict[str, Any],
tags: List[str],
prompt_template: str,
prompt_template_version: str,
prompt_template_variables: Dict[str, Any],
) -> None:
assert attributes.pop(SpanAttributes.SESSION_ID, None) == session_id
assert attributes.pop(SpanAttributes.USER_ID, None) == user_id
attr_metadata = attributes.pop(SpanAttributes.METADATA, None)
assert attr_metadata is not None
assert isinstance(attr_metadata, str) # must be json string
metadata_dict = json.loads(attr_metadata)
assert metadata_dict == metadata
attr_tags = attributes.pop(SpanAttributes.TAG_TAGS, None)
assert attr_tags is not None
assert len(attr_tags) == len(tags)
assert list(attr_tags) == tags
assert attributes.pop(SpanAttributes.LLM_PROMPT_TEMPLATE, None) == prompt_template
assert (
attributes.pop(SpanAttributes.LLM_PROMPT_TEMPLATE_VERSION, None) == prompt_template_version
)
assert attributes.pop(SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES, None) == json.dumps(
prompt_template_variables
)


@pytest.fixture()
def session_id() -> str:
return "my-test-session-id"


@pytest.fixture()
def user_id() -> str:
return "my-test-user-id"


@pytest.fixture()
def metadata() -> Dict[str, Any]:
return {
"test-int": 1,
"test-str": "string",
"test-list": [1, 2, 3],
"test-dict": {
"key-1": "val-1",
"key-2": "val-2",
},
}


@pytest.fixture()
def tags() -> List[str]:
return ["tag-1", "tag-2"]


@pytest.fixture
def prompt_template() -> str:
return (
"This is a test prompt template with int {var_int}, "
"string {var_string}, and list {var_list}"
)


@pytest.fixture
def prompt_template_version() -> str:
return "v1.0"


@pytest.fixture
def prompt_template_variables() -> Dict[str, Any]:
return {
"var_int": 1,
"var_str": "2",
"var_list": [1, 2, 3],
}

0 comments on commit 4ad22fa

Please sign in to comment.