diff --git a/changelog/1077.improvement.md b/changelog/1077.improvement.md new file mode 100644 index 000000000..bea999c4f --- /dev/null +++ b/changelog/1077.improvement.md @@ -0,0 +1 @@ +Instrument `ValidationAction._extract_validation_events` and `FormValidationAction._extract_validation_events` and extract `validated_events` and `slots` attributes. \ No newline at end of file diff --git a/rasa_sdk/tracing/config.py b/rasa_sdk/tracing/config.py index a0079a0b1..080f2eb7a 100644 --- a/rasa_sdk/tracing/config.py +++ b/rasa_sdk/tracing/config.py @@ -14,7 +14,7 @@ from rasa_sdk.tracing.endpoints import EndpointConfig, read_endpoint_config from rasa_sdk.tracing.instrumentation import instrumentation from rasa_sdk.executor import ActionExecutor -from rasa_sdk.forms import ValidationAction +from rasa_sdk.forms import ValidationAction, FormValidationAction TRACING_SERVICE_NAME = os.environ.get("RASA_SDK_TRACING_SERVICE_NAME", "rasa_sdk") @@ -39,6 +39,7 @@ def configure_tracing(tracer_provider: Optional[TracerProvider]) -> None: tracer_provider=tracer_provider, action_executor_class=ActionExecutor, validation_action_class=ValidationAction, + form_validation_action_class=FormValidationAction, ) diff --git a/rasa_sdk/tracing/instrumentation/attribute_extractors.py b/rasa_sdk/tracing/instrumentation/attribute_extractors.py index 432c04603..ff389438c 100644 --- a/rasa_sdk/tracing/instrumentation/attribute_extractors.py +++ b/rasa_sdk/tracing/instrumentation/attribute_extractors.py @@ -36,9 +36,9 @@ def extract_attrs_for_action_executor( def extract_attrs_for_validation_action( self: ValidationAction, - dispatcher: "CollectingDispatcher", - tracker: "Tracker", - domain: "DomainDict", + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, ) -> Dict[Text, Any]: """Extract the attributes for `ValidationAction.run`. diff --git a/rasa_sdk/tracing/instrumentation/instrumentation.py b/rasa_sdk/tracing/instrumentation/instrumentation.py index fe7918f12..47f896475 100644 --- a/rasa_sdk/tracing/instrumentation/instrumentation.py +++ b/rasa_sdk/tracing/instrumentation/instrumentation.py @@ -1,23 +1,30 @@ import functools import inspect +import json import logging from typing import ( Any, Awaitable, Callable, Dict, + List, Optional, Text, Type, TypeVar, + Union, ) from opentelemetry.sdk.trace import TracerProvider from opentelemetry.trace import Tracer -from rasa_sdk.executor import ActionExecutor -from rasa_sdk.forms import ValidationAction +from rasa_sdk import Tracker +from rasa_sdk.executor import ActionExecutor, CollectingDispatcher +from rasa_sdk.events import EventType +from rasa_sdk.forms import ValidationAction, FormValidationAction from rasa_sdk.tracing.instrumentation import attribute_extractors from rasa_sdk.tracing.tracer_register import ActionExecutorTracerRegister +from rasa_sdk.types import DomainDict + # The `TypeVar` representing the return type for a function to be wrapped. S = TypeVar("S") @@ -72,7 +79,9 @@ async def async_wrapper(self: T, *args: Any, **kwargs: Any) -> S: if attr_extractor and should_extract_args else {} ) - if issubclass(self.__class__, ValidationAction): + if issubclass(self.__class__, FormValidationAction): + span_name = f"FormValidationAction.{self.__class__.__name__}.{fn.__name__}" + elif issubclass(self.__class__, ValidationAction): span_name = f"ValidationAction.{self.__class__.__name__}.{fn.__name__}" else: span_name = f"{self.__class__.__name__}.{fn.__name__}" @@ -128,12 +137,16 @@ def wrapper(self: T, *args: Any, **kwargs: Any) -> S: ActionExecutorType = TypeVar("ActionExecutorType", bound=ActionExecutor) ValidationActionType = TypeVar("ValidationActionType", bound=ValidationAction) +FormValidationActionType = TypeVar( + "FormValidationActionType", bound=FormValidationAction +) def instrument( tracer_provider: TracerProvider, action_executor_class: Optional[Type[ActionExecutorType]] = None, validation_action_class: Optional[Type[ValidationActionType]] = None, + form_validation_action_class: Optional[Type[FormValidationActionType]] = None, ) -> None: """Substitute methods to be traced by their traced counterparts. @@ -143,6 +156,8 @@ def instrument( is given, no `ActionExecutor` will be instrumented. :param validation_action_class: The `ValidationAction` to be instrumented. If `None` is given, no `ValidationAction` will be instrumented. + :param form_validation_action_class: The `FormValidationAction` to be instrumented. + If `None` is given, no `FormValidationAction` will be instrumented. """ if action_executor_class is not None and not class_is_instrumented( action_executor_class @@ -172,8 +187,21 @@ def instrument( "run", attribute_extractors.extract_attrs_for_validation_action, ) + _instrument_validation_action_extract_validation_events( + tracer_provider.get_tracer(validation_action_class.__module__), + validation_action_class, + ) mark_class_as_instrumented(validation_action_class) + if form_validation_action_class is not None and not class_is_instrumented( + form_validation_action_class + ): + _instrument_validation_action_extract_validation_events( + tracer_provider.get_tracer(form_validation_action_class.__module__), + form_validation_action_class, + ) + mark_class_as_instrumented(form_validation_action_class) + def _instrument_method( tracer: Tracer, @@ -214,3 +242,58 @@ def mark_class_as_instrumented(instrumented_class: Type) -> None: _mangled_instrumented_boolean_attribute_name(instrumented_class), True, ) + + +def _instrument_validation_action_extract_validation_events( + tracer: Tracer, + validation_action_class: Union[Type[ValidationAction], Type[FormValidationAction]], +) -> None: + def tracing_validation_action_extract_validation_events_wrapper( + fn: Callable, + ) -> Callable: + @functools.wraps(fn) + async def wrapper( + self, + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, + ) -> List[EventType]: + if issubclass(self.__class__, FormValidationAction): + span_name = ( + f"FormValidationAction.{self.__class__.__name__}.{fn.__name__}" + ) + else: + span_name = f"ValidationAction.{self.__class__.__name__}.{fn.__name__}" + + with tracer.start_as_current_span(span_name) as span: + validation_events = await fn(self, dispatcher, tracker, domain) + event_names = [] + slot_names = [] + + if validation_events: + for event in validation_events: + event_names.append(event.get("event")) + if event.get("event") == "slot": + slot_names.append(event.get("name")) + + span.set_attributes( + { + "validation_events": json.dumps( + list(dict.fromkeys(event_names)) + ), + "slots": json.dumps(list(dict.fromkeys(slot_names))), + } + ) + return validation_events + + return wrapper + + validation_action_class._extract_validation_events = ( # type: ignore + tracing_validation_action_extract_validation_events_wrapper( + validation_action_class._extract_validation_events + ) + ) + + logger.debug( + f"Instrumented '{validation_action_class.__name__}._extract_validation_events'." + ) diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py index 29678d3a9..06e779e80 100644 --- a/tests/test_endpoint.py +++ b/tests/test_endpoint.py @@ -23,7 +23,7 @@ def test_server_health_returns_200(): def test_server_list_actions_returns_200(): request, response = app.test_client.get("/actions") assert response.status == 200 - assert len(response.json) == 4 + assert len(response.json) == 5 # ENSURE TO UPDATE AS MORE ACTIONS ARE ADDED IN OTHER TESTS expected = [ @@ -33,6 +33,7 @@ def test_server_list_actions_returns_200(): {"name": "custom_action_exception"}, # defined in tests/tracing/instrumentation/conftest.py {"name": "mock_validation_action"}, + {"name": "mock_form_validation_action"}, ] assert response.json == expected diff --git a/tests/tracing/instrumentation/conftest.py b/tests/tracing/instrumentation/conftest.py index 796113063..5b7570f4b 100644 --- a/tests/tracing/instrumentation/conftest.py +++ b/tests/tracing/instrumentation/conftest.py @@ -6,7 +6,7 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from rasa_sdk.executor import ActionExecutor, CollectingDispatcher -from rasa_sdk.forms import ValidationAction +from rasa_sdk.forms import ValidationAction, FormValidationAction from rasa_sdk.types import ActionCall, DomainDict from rasa_sdk import Tracker @@ -79,3 +79,38 @@ async def run( def name(self) -> Text: return "mock_validation_action" + + async def _extract_validation_events( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> None: + return tracker.events + + +class MockFormValidationAction(FormValidationAction): + def __init__(self) -> None: + self.fail_if_undefined("run") + + def fail_if_undefined(self, method_name: Text) -> None: + if not ( + hasattr(self.__class__.__base__, method_name) + and callable(getattr(self.__class__.__base__, method_name)) + ): + pytest.fail( + f"method '{method_name}' not found in {self.__class__.__base__}. " + f"This likely means the method was renamed, which means the " + f"instrumentation needs to be adapted!" + ) + + async def _extract_validation_events( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> None: + return tracker.events + + def name(self) -> Text: + return "mock_form_validation_action" diff --git a/tests/tracing/instrumentation/test_form_validation_action.py b/tests/tracing/instrumentation/test_form_validation_action.py new file mode 100644 index 000000000..66374c6e3 --- /dev/null +++ b/tests/tracing/instrumentation/test_form_validation_action.py @@ -0,0 +1,128 @@ +from typing import Sequence, Optional + +import pytest +from opentelemetry.sdk.trace import ReadableSpan, TracerProvider +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + +from rasa_sdk.tracing.instrumentation import instrumentation +from tests.tracing.instrumentation.conftest import MockFormValidationAction +from rasa_sdk import Tracker +from rasa_sdk.executor import CollectingDispatcher +from rasa_sdk.events import ActionExecuted, SlotSet + + +@pytest.mark.parametrize( + "events, expected_slots_to_validate", + [ + ([], "[]"), + ([ActionExecuted("my_form")], "[]"), + ( + [SlotSet("name", "Tom"), SlotSet("address", "Berlin")], + '["name", "address"]', + ), + ], +) +@pytest.mark.asyncio +async def test_form_validation_action_run( + tracer_provider: TracerProvider, + span_exporter: InMemorySpanExporter, + previous_num_captured_spans: int, + events: Optional[str], + expected_slots_to_validate: Optional[str], +) -> None: + component_class = MockFormValidationAction + + instrumentation.instrument( + tracer_provider, + validation_action_class=component_class, + ) + + mock_validation_action = component_class() + dispatcher = CollectingDispatcher() + tracker = Tracker.from_dict({"sender_id": "test", "events": events}) + + await mock_validation_action.run(dispatcher, tracker, {}) + + captured_spans: Sequence[ + ReadableSpan + ] = span_exporter.get_finished_spans() # type: ignore + + num_captured_spans = len(captured_spans) - previous_num_captured_spans + # includes the child span for `_extract_validation_events` method call + assert num_captured_spans == 2 + + captured_span = captured_spans[-1] + + assert captured_span.name == "FormValidationAction.MockFormValidationAction.run" + + expected_attributes = { + "class_name": component_class.__name__, + "sender_id": "test", + "slots_to_validate": expected_slots_to_validate, + "action_name": "mock_form_validation_action", + } + + assert captured_span.attributes == expected_attributes + + +@pytest.mark.parametrize( + "events, slots, validation_events", + [ + ([], "[]", "[]"), + ([ActionExecuted("my_form")], "[]", '["action"]'), + ( + [SlotSet("name", "Tom")], + '["name"]', + '["slot"]', + ), + ( + [SlotSet("name", "Tom"), SlotSet("address", "Berlin")], + '["name", "address"]', + '["slot"]', + ), + ], +) +@pytest.mark.asyncio +async def test_form_validation_action_extract_validation_events( + tracer_provider: TracerProvider, + span_exporter: InMemorySpanExporter, + previous_num_captured_spans: int, + events: Optional[str], + slots: Optional[str], + validation_events: Optional[str], +) -> None: + component_class = MockFormValidationAction + + instrumentation.instrument( + tracer_provider, + form_validation_action_class=component_class, + ) + + mock_form_validation_action = component_class() + dispatcher = CollectingDispatcher() + tracker = Tracker.from_dict({"sender_id": "test", "events": events}) + + await mock_form_validation_action._extract_validation_events( + dispatcher, tracker, {} + ) + + captured_spans: Sequence[ + ReadableSpan + ] = span_exporter.get_finished_spans() # type: ignore + + num_captured_spans = len(captured_spans) - previous_num_captured_spans + assert num_captured_spans == 1 + + captured_span = captured_spans[-1] + expected_span_name = ( + "FormValidationAction.MockFormValidationAction._extract_validation_events" + ) + + assert captured_span.name == expected_span_name + + expected_attributes = { + "validation_events": validation_events, + "slots": slots, + } + + assert captured_span.attributes == expected_attributes diff --git a/tests/tracing/instrumentation/test_validation_action.py b/tests/tracing/instrumentation/test_validation_action.py index ce407d8db..860d5032d 100644 --- a/tests/tracing/instrumentation/test_validation_action.py +++ b/tests/tracing/instrumentation/test_validation_action.py @@ -1,20 +1,23 @@ -from typing import List, Sequence +from typing import List, Optional, Sequence import pytest from opentelemetry.sdk.trace import ReadableSpan, TracerProvider from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from rasa_sdk.tracing.instrumentation import instrumentation -from tests.tracing.instrumentation.conftest import MockValidationAction +from tests.tracing.instrumentation.conftest import ( + MockValidationAction, +) from rasa_sdk import Tracker from rasa_sdk.executor import CollectingDispatcher -from rasa_sdk.events import SlotSet, EventType +from rasa_sdk.events import SlotSet, EventType, ActionExecuted @pytest.mark.parametrize( "events, expected_slots_to_validate", [ ([], "[]"), + ([ActionExecuted("my_form")], "[]"), ( [SlotSet("name", "Tom"), SlotSet("address", "Berlin")], '["name", "address"]', @@ -22,7 +25,7 @@ ], ) @pytest.mark.asyncio -async def test_tracing_action_executor_run( +async def test_validation_action_run( tracer_provider: TracerProvider, span_exporter: InMemorySpanExporter, previous_num_captured_spans: int, @@ -61,3 +64,66 @@ async def test_tracing_action_executor_run( } assert captured_span.attributes == expected_attributes + + +@pytest.mark.parametrize( + "events, slots, validation_events", + [ + ([], "[]", "[]"), + ([ActionExecuted("my_form")], "[]", '["action"]'), + ( + [SlotSet("name", "Tom")], + '["name"]', + '["slot"]', + ), + ( + [SlotSet("name", "Tom"), SlotSet("address", "Berlin")], + '["name", "address"]', + '["slot"]', + ), + ], +) +@pytest.mark.asyncio +async def test_validation_action_extract_validation_events( + tracer_provider: TracerProvider, + span_exporter: InMemorySpanExporter, + previous_num_captured_spans: int, + events: Optional[str], + slots: Optional[str], + validation_events: Optional[str], +) -> None: + component_class = MockValidationAction + + instrumentation.instrument( + tracer_provider, + form_validation_action_class=component_class, + ) + + mock_form_validation_action = component_class() + dispatcher = CollectingDispatcher() + tracker = Tracker.from_dict({"sender_id": "test", "events": events}) + + await mock_form_validation_action._extract_validation_events( + dispatcher, tracker, {} + ) + + captured_spans: Sequence[ + ReadableSpan + ] = span_exporter.get_finished_spans() # type: ignore + + num_captured_spans = len(captured_spans) - previous_num_captured_spans + assert num_captured_spans == 1 + + captured_span = captured_spans[-1] + expected_span_name = ( + "ValidationAction.MockValidationAction._extract_validation_events" + ) + + assert captured_span.name == expected_span_name + + expected_attributes = { + "validation_events": validation_events, + "slots": slots, + } + + assert captured_span.attributes == expected_attributes