Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ATO-2103] Instrument FormValidationAction._extract_validation_events #1077

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/1077.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Instrument `ValidationAction._extract_validation_events` and `FormValidationAction._extract_validation_events` and extract `validated_events` and `slots` attributes.
3 changes: 2 additions & 1 deletion rasa_sdk/tracing/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from rasa_sdk.tracing.endpoints import EndpointConfig, read_endpoint_config
from rasa_sdk.tracing.instrumentation import instrumentation
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.forms import ValidationAction
from rasa_sdk.forms import ValidationAction, FormValidationAction

TRACING_SERVICE_NAME = os.environ.get("RASA_SDK_TRACING_SERVICE_NAME", "rasa_sdk")

Expand All @@ -39,6 +39,7 @@ def configure_tracing(tracer_provider: Optional[TracerProvider]) -> None:
tracer_provider=tracer_provider,
action_executor_class=ActionExecutor,
validation_action_class=ValidationAction,
form_validation_action_class=FormValidationAction,
)


Expand Down
6 changes: 3 additions & 3 deletions rasa_sdk/tracing/instrumentation/attribute_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def extract_attrs_for_action_executor(

def extract_attrs_for_validation_action(
self: ValidationAction,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
dispatcher: CollectingDispatcher,
tracker: Tracker,
domain: DomainDict,
) -> Dict[Text, Any]:
"""Extract the attributes for `ValidationAction.run`.

Expand Down
89 changes: 86 additions & 3 deletions rasa_sdk/tracing/instrumentation/instrumentation.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
import functools
import inspect
import json
import logging
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Text,
Type,
TypeVar,
Union,
)

from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import Tracer
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.forms import ValidationAction
from rasa_sdk import Tracker
from rasa_sdk.executor import ActionExecutor, CollectingDispatcher
from rasa_sdk.events import EventType
from rasa_sdk.forms import ValidationAction, FormValidationAction
from rasa_sdk.tracing.instrumentation import attribute_extractors
from rasa_sdk.tracing.tracer_register import ActionExecutorTracerRegister
from rasa_sdk.types import DomainDict


# The `TypeVar` representing the return type for a function to be wrapped.
S = TypeVar("S")
Expand Down Expand Up @@ -72,7 +79,9 @@ async def async_wrapper(self: T, *args: Any, **kwargs: Any) -> S:
if attr_extractor and should_extract_args
else {}
)
if issubclass(self.__class__, ValidationAction):
if issubclass(self.__class__, FormValidationAction):
span_name = f"FormValidationAction.{self.__class__.__name__}.{fn.__name__}"
elif issubclass(self.__class__, ValidationAction):
span_name = f"ValidationAction.{self.__class__.__name__}.{fn.__name__}"
else:
span_name = f"{self.__class__.__name__}.{fn.__name__}"
Expand Down Expand Up @@ -128,12 +137,16 @@ def wrapper(self: T, *args: Any, **kwargs: Any) -> S:

ActionExecutorType = TypeVar("ActionExecutorType", bound=ActionExecutor)
ValidationActionType = TypeVar("ValidationActionType", bound=ValidationAction)
FormValidationActionType = TypeVar(
"FormValidationActionType", bound=FormValidationAction
)


def instrument(
tracer_provider: TracerProvider,
action_executor_class: Optional[Type[ActionExecutorType]] = None,
validation_action_class: Optional[Type[ValidationActionType]] = None,
form_validation_action_class: Optional[Type[FormValidationActionType]] = None,
) -> None:
"""Substitute methods to be traced by their traced counterparts.

Expand All @@ -143,6 +156,8 @@ def instrument(
is given, no `ActionExecutor` will be instrumented.
:param validation_action_class: The `ValidationAction` to be instrumented. If `None`
is given, no `ValidationAction` will be instrumented.
:param form_validation_action_class: The `FormValidationAction` to be instrumented.
If `None` is given, no `FormValidationAction` will be instrumented.
"""
if action_executor_class is not None and not class_is_instrumented(
action_executor_class
Expand Down Expand Up @@ -172,8 +187,21 @@ def instrument(
"run",
attribute_extractors.extract_attrs_for_validation_action,
)
_instrument_validation_action_extract_validation_events(
tracer_provider.get_tracer(validation_action_class.__module__),
validation_action_class,
)
mark_class_as_instrumented(validation_action_class)

if form_validation_action_class is not None and not class_is_instrumented(
form_validation_action_class
):
_instrument_validation_action_extract_validation_events(
tracer_provider.get_tracer(form_validation_action_class.__module__),
form_validation_action_class,
)
mark_class_as_instrumented(form_validation_action_class)


def _instrument_method(
tracer: Tracer,
Expand Down Expand Up @@ -214,3 +242,58 @@ def mark_class_as_instrumented(instrumented_class: Type) -> None:
_mangled_instrumented_boolean_attribute_name(instrumented_class),
True,
)


def _instrument_validation_action_extract_validation_events(
tracer: Tracer,
validation_action_class: Union[Type[ValidationAction], Type[FormValidationAction]],
) -> None:
def tracing_validation_action_extract_validation_events_wrapper(
fn: Callable,
) -> Callable:
@functools.wraps(fn)
async def wrapper(
self,
dispatcher: CollectingDispatcher,
tracker: Tracker,
domain: DomainDict,
) -> List[EventType]:
if issubclass(self.__class__, FormValidationAction):
span_name = (
f"FormValidationAction.{self.__class__.__name__}.{fn.__name__}"
)
else:
span_name = f"ValidationAction.{self.__class__.__name__}.{fn.__name__}"

with tracer.start_as_current_span(span_name) as span:
validation_events = await fn(self, dispatcher, tracker, domain)
event_names = []
slot_names = []

if validation_events:
for event in validation_events:
event_names.append(event.get("event"))
if event.get("event") == "slot":
slot_names.append(event.get("name"))

span.set_attributes(
{
"validation_events": json.dumps(
list(dict.fromkeys(event_names))
),
"slots": json.dumps(list(dict.fromkeys(slot_names))),
}
)
return validation_events

return wrapper

validation_action_class._extract_validation_events = ( # type: ignore
tracing_validation_action_extract_validation_events_wrapper(
validation_action_class._extract_validation_events
)
)

logger.debug(
f"Instrumented '{validation_action_class.__name__}._extract_validation_events'."
)
3 changes: 2 additions & 1 deletion tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_server_health_returns_200():
def test_server_list_actions_returns_200():
request, response = app.test_client.get("/actions")
assert response.status == 200
assert len(response.json) == 4
assert len(response.json) == 5

# ENSURE TO UPDATE AS MORE ACTIONS ARE ADDED IN OTHER TESTS
expected = [
Expand All @@ -33,6 +33,7 @@ def test_server_list_actions_returns_200():
{"name": "custom_action_exception"},
# defined in tests/tracing/instrumentation/conftest.py
{"name": "mock_validation_action"},
{"name": "mock_form_validation_action"},
]
assert response.json == expected

Expand Down
37 changes: 36 additions & 1 deletion tests/tracing/instrumentation/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from rasa_sdk.executor import ActionExecutor, CollectingDispatcher
from rasa_sdk.forms import ValidationAction
from rasa_sdk.forms import ValidationAction, FormValidationAction
from rasa_sdk.types import ActionCall, DomainDict
from rasa_sdk import Tracker

Expand Down Expand Up @@ -79,3 +79,38 @@ async def run(

def name(self) -> Text:
return "mock_validation_action"

async def _extract_validation_events(
self,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> None:
return tracker.events


class MockFormValidationAction(FormValidationAction):
def __init__(self) -> None:
self.fail_if_undefined("run")

def fail_if_undefined(self, method_name: Text) -> None:
if not (
hasattr(self.__class__.__base__, method_name)
and callable(getattr(self.__class__.__base__, method_name))
):
pytest.fail(
f"method '{method_name}' not found in {self.__class__.__base__}. "
f"This likely means the method was renamed, which means the "
f"instrumentation needs to be adapted!"
)

async def _extract_validation_events(
self,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> None:
return tracker.events

def name(self) -> Text:
return "mock_form_validation_action"
128 changes: 128 additions & 0 deletions tests/tracing/instrumentation/test_form_validation_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from typing import Sequence, Optional

import pytest
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from rasa_sdk.tracing.instrumentation import instrumentation
from tests.tracing.instrumentation.conftest import MockFormValidationAction
from rasa_sdk import Tracker
from rasa_sdk.executor import CollectingDispatcher
from rasa_sdk.events import ActionExecuted, SlotSet


@pytest.mark.parametrize(
"events, expected_slots_to_validate",
[
([], "[]"),
([ActionExecuted("my_form")], "[]"),
(
[SlotSet("name", "Tom"), SlotSet("address", "Berlin")],
ancalita marked this conversation as resolved.
Show resolved Hide resolved
'["name", "address"]',
),
],
)
@pytest.mark.asyncio
async def test_form_validation_action_run(
tracer_provider: TracerProvider,
span_exporter: InMemorySpanExporter,
previous_num_captured_spans: int,
events: Optional[str],
expected_slots_to_validate: Optional[str],
) -> None:
component_class = MockFormValidationAction

instrumentation.instrument(
tracer_provider,
validation_action_class=component_class,
)

mock_validation_action = component_class()
dispatcher = CollectingDispatcher()
tracker = Tracker.from_dict({"sender_id": "test", "events": events})

await mock_validation_action.run(dispatcher, tracker, {})

captured_spans: Sequence[
ReadableSpan
] = span_exporter.get_finished_spans() # type: ignore

num_captured_spans = len(captured_spans) - previous_num_captured_spans
# includes the child span for `_extract_validation_events` method call
assert num_captured_spans == 2

captured_span = captured_spans[-1]

assert captured_span.name == "FormValidationAction.MockFormValidationAction.run"

expected_attributes = {
"class_name": component_class.__name__,
"sender_id": "test",
"slots_to_validate": expected_slots_to_validate,
"action_name": "mock_form_validation_action",
}

assert captured_span.attributes == expected_attributes


@pytest.mark.parametrize(
"events, slots, validation_events",
[
([], "[]", "[]"),
([ActionExecuted("my_form")], "[]", '["action"]'),
(
[SlotSet("name", "Tom")],
'["name"]',
'["slot"]',
),
(
[SlotSet("name", "Tom"), SlotSet("address", "Berlin")],
'["name", "address"]',
'["slot"]',
),
],
)
@pytest.mark.asyncio
async def test_form_validation_action_extract_validation_events(
tracer_provider: TracerProvider,
span_exporter: InMemorySpanExporter,
previous_num_captured_spans: int,
events: Optional[str],
slots: Optional[str],
validation_events: Optional[str],
) -> None:
component_class = MockFormValidationAction

instrumentation.instrument(
tracer_provider,
form_validation_action_class=component_class,
)

mock_form_validation_action = component_class()
dispatcher = CollectingDispatcher()
tracker = Tracker.from_dict({"sender_id": "test", "events": events})

await mock_form_validation_action._extract_validation_events(
dispatcher, tracker, {}
)

captured_spans: Sequence[
ReadableSpan
] = span_exporter.get_finished_spans() # type: ignore

num_captured_spans = len(captured_spans) - previous_num_captured_spans
assert num_captured_spans == 1

captured_span = captured_spans[-1]
expected_span_name = (
"FormValidationAction.MockFormValidationAction._extract_validation_events"
)

assert captured_span.name == expected_span_name

expected_attributes = {
"validation_events": validation_events,
"slots": slots,
}

assert captured_span.attributes == expected_attributes
Loading
Loading