diff --git a/rasa_addons/core/actions/action_botfront_form.py b/rasa_addons/core/actions/action_botfront_form.py index 85aee491bc..eec45095b9 100644 --- a/rasa_addons/core/actions/action_botfront_form.py +++ b/rasa_addons/core/actions/action_botfront_form.py @@ -1,6 +1,11 @@ import logging import functools -from typing import Dict, Text, Any, List, Union, Optional, Tuple +from typing import Dict, Text, Any, List, Optional +from rasa.core.slots import Slot + +from rasa_addons.core.actions.required_slots_graph_parser import ( + RequiredSlotsGraphParser, +) from rasa_addons.core.actions.slot_rule_validator import validate_with_rule from rasa_addons.core.actions.submit_form_to_botfront import submit_form_to_botfront @@ -41,7 +46,12 @@ def __str__(self) -> Text: return f"FormAction('{self.name()}')" def required_slots(self, tracker): - return [s.get("name") for s in self.form_spec.get("slots", [])] + graph = self.form_spec.get("graph_elements") + if not graph: + return [s.get("name") for s in self.form_spec.get("slots", [])] + parser = RequiredSlotsGraphParser(graph) + required_slots = parser.get_required_slots(tracker) + return required_slots def get_field_for_slot( self, slot: Text, field: Text, default: Optional[Any] = None, @@ -51,32 +61,6 @@ def get_field_for_slot( return s.get(field, default) return default - async def validate_prefilled( - self, - output_channel: "OutputChannel", - nlg: "NaturalLanguageGenerator", - tracker: "DialogueStateTracker", - domain: "Domain", - ): - # collect values of required slots filled before activation - prefilled_slots = {} - events = [] - - for slot_name in self.required_slots(tracker): - if not self._should_request_slot(tracker, slot_name): - prefilled_slots[slot_name] = tracker.get_slot(slot_name) - - if prefilled_slots: - logger.debug(f"Validating pre-filled required slots: {prefilled_slots}") - events.extend( - await self.validate_slots( - prefilled_slots, output_channel, nlg, tracker, domain - ) - ) - else: - logger.debug("No pre-filled required slots to validate.") - return events - async def run( self, output_channel: "OutputChannel", @@ -86,7 +70,9 @@ async def run( ) -> List[Event]: # attempt retrieving spec if not len(self.form_spec): - for form in tracker.slots.get("bf_forms", {}).initial_value: + for form in tracker.slots.get( + "bf_forms", Slot("bf_forms", initial_value=[]) + ).initial_value: if form.get("name") == self.name(): self.form_spec = clean_none_values(form) if not len(self.form_spec): @@ -150,7 +136,8 @@ async def submit( template = await nlg.generate( f"utter_submit_{self.name()}", tracker, output_channel.name(), ) - events += [create_bot_utterance(template)] + if template: + events += [create_bot_utterance(template)] if collect_in_botfront: submit_form_to_botfront(tracker) return events @@ -222,11 +209,13 @@ def entity_is_desired( @staticmethod def get_entity_value( - name: Text, + name: Optional[Text], tracker: "DialogueStateTracker", role: Optional[Text] = None, group: Optional[Text] = None, ) -> Any: + if not name: + return None # list is used to cover the case of list slot type value = list( tracker.get_latest_entity_values(name, entity_group=group, entity_role=role) @@ -246,6 +235,8 @@ def extract_other_slots( domain: "Domain", ) -> Dict[Text, Any]: slot_to_fill = tracker.get_slot(REQUESTED_SLOT) + if not slot_to_fill: + return {} slot_values = {} for slot in self.required_slots(tracker): @@ -300,6 +291,8 @@ def extract_requested_slot( else return None """ slot_to_fill = tracker.get_slot(REQUESTED_SLOT) + if not slot_to_fill: + return {} logger.debug(f"Trying to extract requested slot '{slot_to_fill}' ...") # get mapping for requested slot @@ -352,14 +345,16 @@ async def utter_post_validation( and self.get_field_for_slot(slot, "utter_on_new_valid_slot", False) is False ): return [] - valid = "valid" if valid else "invalid" + utter_what = "valid" if valid else "invalid" # so utter_(in)valid_slot supports {slot} template replacements temp_tracker = tracker.copy() temp_tracker.slots[slot].value = value template = await nlg.generate( - f"utter_{valid}_{slot}", temp_tracker, output_channel.name(), + f"utter_{utter_what}_{slot}", temp_tracker, output_channel.name(), ) + if not template: + return [] return [create_bot_utterance(template)] async def validate_slots( diff --git a/rasa_addons/core/actions/required_slots_graph_parser.py b/rasa_addons/core/actions/required_slots_graph_parser.py new file mode 100644 index 0000000000..6aab7914d7 --- /dev/null +++ b/rasa_addons/core/actions/required_slots_graph_parser.py @@ -0,0 +1,56 @@ +from typing import Dict, Text, Any +from rasa_addons.core.actions.slot_rule_validator import validate_with_rule + + +class RequiredSlotsGraphParser: + def __init__(self, required_slots_graph: Dict[Text, Any]) -> None: + self.start = None + self.nodes = {} + for node in required_slots_graph.get("nodes", []): + if node.get("type") == "start": + self.start = node.get("id") + continue + self.nodes[node.get("id")] = node.get("slotName") + self.edges = {} + for edge in required_slots_graph.get("edges", []): + source = edge.get("source") + self.edges[source] = [*self.edges.get(source, []), edge] + + def get_required_slots(self, tracker, start=None): + required_slots = [] + current_source = start or self.start + current_edges = self.edges.get(current_source, []) + for edge in sorted(current_edges, key=lambda e: e.get("condition") is None): + target, condition = edge.get("target"), edge.get("condition") + if self.check_condition(tracker, condition): + required_slots.append(self.nodes.get(target)) + required_slots += self.get_required_slots(tracker, start=target) + break # use first matching condition, that's it + else: + continue + return required_slots + + def check_condition(self, tracker, condition): + if condition is None: + return True + props = condition.get("properties", {}) + children = condition.get("children1", {}).values() + if condition.get("type") == "rule": + return self.check_atomic_condition(tracker, **props) + conjunction_operator = any if props.get("conjunction") == "OR" else all + polarity = (lambda p: not p) if props.get("not") else (lambda p: p) + return polarity( + conjunction_operator( + self.check_condition(tracker, child) for child in children + ) + ) + + def check_atomic_condition(self, tracker, field, operator, value, **rest): + slot = tracker.slots.get(field) + return validate_with_rule( + slot.value if slot else None, + { + "operator": operator, + "comparatum": [*value, None][0] # value is always a singleton list + }, + ) diff --git a/rasa_addons/core/actions/slot_rule_validator.py b/rasa_addons/core/actions/slot_rule_validator.py index 51c8e85704..0f78a488eb 100644 --- a/rasa_addons/core/actions/slot_rule_validator.py +++ b/rasa_addons/core/actions/slot_rule_validator.py @@ -75,11 +75,11 @@ def validate_with_rule(value, validation_rule) -> bool: if operator in NUM_COMPARATUM_OPERATORS: try: comparatum = float(comparatum) - except ValueError: + except (ValueError, TypeError): raise ValueError( f"Validation operator '{operator}' requires a numerical comparatum." ) - except ValueError as e: + except (ValueError, TypeError) as e: logger.error(str(e)) return False if operator in TEXT_VALUE_OPERATORS and not isinstance(value, str): @@ -87,7 +87,7 @@ def validate_with_rule(value, validation_rule) -> bool: if operator in NUM_VALUE_OPERATORS: try: value = float(value) - except ValueError: + except (ValueError, TypeError): return False if operator == "is_in": return value in comparatum diff --git a/tests/addons/core/test_action_botfront_form.py b/tests/addons/core/test_action_botfront_form.py index d8ed5a9ea4..2372cdb46c 100644 --- a/tests/addons/core/test_action_botfront_form.py +++ b/tests/addons/core/test_action_botfront_form.py @@ -15,23 +15,110 @@ import pytest nlg = BotfrontTemplatedNaturalLanguageGenerator() -def new_form_and_tracker(form_spec, requested_slot): + + +def new_form_and_tracker(form_spec, requested_slot, additional_slots=[]): form = ActionBotfrontForm(form_spec.get("form_name")) tracker = DialogueStateTracker.from_dict( "default", [], [ Slot(name=requested_slot), + *[Slot(name=name) for name in additional_slots], Slot(name="requested_slot", initial_value=requested_slot), - Slot( - name="bf_forms", - initial_value=[form_spec] - ) - ] + Slot(name="bf_forms", initial_value=[form_spec]), + ], ) - form.form_spec = form_spec # load spec manually + form.form_spec = form_spec # load spec manually return form, tracker + +def required_slots_graph(conjunction="OR", negated=False): + return { + "nodes": [ + {"id": "0", "type": "start"}, + {"id": "1", "type": "slot", "slotName": "age"}, + {"id": "2", "type": "slot", "slotName": "authorization"}, + {"id": "3", "type": "slot", "slotName": "comments"}, + ], + "edges": [ + { + "id": "a", + "type": "condition", + "source": "0", + "target": "1", + "condition": None, + }, + { + "id": "d", + "type": "condition", + "source": "1", + "target": "3", + "condition": None, + }, + { + "id": "b", + "type": "condition", + "source": "1", + "target": "2", + "condition": { + "type": "group", + "id": "9a99988a-0123-4456-b89a-b1607f326fd8", + "children1": { + "a98ab9b9-cdef-4012-b456-71607f326fd9": { + "type": "rule", + "properties": { + "field": "age", + "operator": "lt", + "value": ["18"], + "valueSrc": ["value"], + "valueType": ["text"], + "valueError": [None], + }, + }, + "98a8a9ba-0123-4456-b89a-b16e721c8cd0": { + "type": "rule", + "properties": { + "field": "age", + "operator": "gt", + "value": ["65"], + "valueSrc": ["value"], + "valueType": ["text"], + "valueError": [None], + }, + }, + }, + "properties": {"conjunction": conjunction, "not": negated}, + }, + }, + { + "id": "c", + "type": "condition", + "source": "2", + "target": "3", + "condition": { + "type": "group", + "id": "9a99988a-0123-4456-b89a-b1607f326fd8", + "children1": { + "a98ab9b9-cdef-4012-b456-71607f326fd9": { + "type": "rule", + "properties": { + "field": "authorization", + "operator": "is_exactly", + "value": ["true"], + "valueSrc": ["value"], + "valueType": ["text"], + "valueError": [None], + }, + } + }, + "properties": {"conjunction": "OR", "not": None}, + }, + }, + ], + } + + def test_extract_requested_slot_default(): """Test default extraction of a slot value from entity with the same name """ @@ -39,11 +126,16 @@ def test_extract_requested_slot_default(): spec = {"name": "default_form"} form, tracker = new_form_and_tracker(spec, "some_slot") - tracker.update(UserUttered(entities=[{"entity": "some_slot", "value": "some_value"}])) + tracker.update( + UserUttered(entities=[{"entity": "some_slot", "value": "some_value"}]) + ) - slot_values = form.extract_requested_slot(OutputChannel(), nlg, tracker, Domain.empty()) + slot_values = form.extract_requested_slot( + OutputChannel(), nlg, tracker, Domain.empty() + ) assert slot_values == {"some_slot": "some_value"} + def test_extract_requested_slot_from_entity_no_intent(): """Test extraction of a slot value from entity with the different name and any intent @@ -51,21 +143,25 @@ def test_extract_requested_slot_from_entity_no_intent(): spec = { "name": "default_form", - "slots": [{ - "name": "some_slot", - "filling": [{ - "type": "from_entity", - "entity": ["some_entity"] - }] - }] + "slots": [ + { + "name": "some_slot", + "filling": [{"type": "from_entity", "entity": ["some_entity"]}], + } + ], } form, tracker = new_form_and_tracker(spec, "some_slot") - tracker.update(UserUttered(entities=[{"entity": "some_entity", "value": "some_value"}])) + tracker.update( + UserUttered(entities=[{"entity": "some_entity", "value": "some_value"}]) + ) - slot_values = form.extract_requested_slot(OutputChannel(), nlg, tracker, Domain.empty()) + slot_values = form.extract_requested_slot( + OutputChannel(), nlg, tracker, Domain.empty() + ) assert slot_values == {"some_slot": "some_value"} + def test_extract_requested_slot_from_entity_with_intent(): """Test extraction of a slot value from entity with the different name and certain intent @@ -73,98 +169,121 @@ def test_extract_requested_slot_from_entity_with_intent(): spec = { "name": "default_form", - "slots": [{ - "name": "some_slot", - "filling": [{ - "type": "from_entity", - "entity": ["some_entity"], - "intent": ["some_intent"] - }] - }] + "slots": [ + { + "name": "some_slot", + "filling": [ + { + "type": "from_entity", + "entity": ["some_entity"], + "intent": ["some_intent"], + } + ], + } + ], } form, tracker = new_form_and_tracker(spec, "some_slot") - tracker.update(UserUttered( - intent={"name": "some_intent", "confidence": 1.0}, - entities=[{"entity": "some_entity", "value": "some_value"}] - )) + tracker.update( + UserUttered( + intent={"name": "some_intent", "confidence": 1.0}, + entities=[{"entity": "some_entity", "value": "some_value"}], + ) + ) - slot_values = form.extract_requested_slot(OutputChannel(), nlg, tracker, Domain.empty()) + slot_values = form.extract_requested_slot( + OutputChannel(), nlg, tracker, Domain.empty() + ) assert slot_values == {"some_slot": "some_value"} - tracker.update(UserUttered( - intent={"name": "some_other_intent", "confidence": 1.0}, - entities=[{"entity": "some_entity", "value": "some_value"}] - )) + tracker.update( + UserUttered( + intent={"name": "some_other_intent", "confidence": 1.0}, + entities=[{"entity": "some_entity", "value": "some_value"}], + ) + ) - slot_values = form.extract_requested_slot(OutputChannel(), nlg, tracker, Domain.empty()) + slot_values = form.extract_requested_slot( + OutputChannel(), nlg, tracker, Domain.empty() + ) assert slot_values == {} + def test_extract_requested_slot_from_intent(): """Test extraction of a slot value from certain intent """ spec = { "name": "default_form", - "slots": [{ - "name": "some_slot", - "filling": [{ - "type": "from_intent", - "intent": ["some_intent"], - "value": "some_value" - }] - }] + "slots": [ + { + "name": "some_slot", + "filling": [ + { + "type": "from_intent", + "intent": ["some_intent"], + "value": "some_value", + } + ], + } + ], } form, tracker = new_form_and_tracker(spec, "some_slot") - tracker.update(UserUttered( - intent={"name": "some_intent", "confidence": 1.0} - )) + tracker.update(UserUttered(intent={"name": "some_intent", "confidence": 1.0})) - slot_values = form.extract_requested_slot(OutputChannel(), nlg, tracker, Domain.empty()) + slot_values = form.extract_requested_slot( + OutputChannel(), nlg, tracker, Domain.empty() + ) assert slot_values == {"some_slot": "some_value"} - tracker.update(UserUttered( - intent={"name": "some_other_intent", "confidence": 1.0} - )) + tracker.update(UserUttered(intent={"name": "some_other_intent", "confidence": 1.0})) - slot_values = form.extract_requested_slot(OutputChannel(), nlg, tracker, Domain.empty()) + slot_values = form.extract_requested_slot( + OutputChannel(), nlg, tracker, Domain.empty() + ) assert slot_values == {} + def test_extract_requested_slot_from_text_with_not_intent(): """Test extraction of a slot value from text with certain intent """ spec = { "name": "default_form", - "slots": [{ - "name": "some_slot", - "filling": [{ - "type": "from_text", - "not_intent": ["some_intent"], - }] - }] + "slots": [ + { + "name": "some_slot", + "filling": [{"type": "from_text", "not_intent": ["some_intent"],}], + } + ], } form, tracker = new_form_and_tracker(spec, "some_slot") - tracker.update(UserUttered( - intent={"name": "some_intent", "confidence": 1.0}, - text="some_text" - )) + tracker.update( + UserUttered(intent={"name": "some_intent", "confidence": 1.0}, text="some_text") + ) - slot_values = form.extract_requested_slot(OutputChannel(), nlg, tracker, Domain.empty()) + slot_values = form.extract_requested_slot( + OutputChannel(), nlg, tracker, Domain.empty() + ) assert slot_values == {} - tracker.update(UserUttered( - intent={"name": "some_other_intent", "confidence": 1.0}, - text="some_text" - )) + tracker.update( + UserUttered( + intent={"name": "some_other_intent", "confidence": 1.0}, text="some_text" + ) + ) - slot_values = form.extract_requested_slot(OutputChannel(), nlg, tracker, Domain.empty()) + slot_values = form.extract_requested_slot( + OutputChannel(), nlg, tracker, Domain.empty() + ) assert slot_values == {"some_slot": "some_text"} + @pytest.mark.parametrize( - "operator, value, comparatum, result", [ + "operator, value, comparatum, result", + [ ("is_in", "hey", ["hey", "ho", "fee"], True), ("is_exactly", "aheya", "hey", False), ("contains", "aheya", "hey", True), @@ -184,13 +303,12 @@ def test_extract_requested_slot_from_text_with_not_intent(): async def test_validation(value, operator, comparatum, result, caplog): spec = { "name": "default_form", - "slots": [{ - "name": "some_slot", - "validation": { - "operator": operator, - "comparatum": comparatum, + "slots": [ + { + "name": "some_slot", + "validation": {"operator": operator, "comparatum": comparatum,}, } - }] + ], } form, tracker = new_form_and_tracker(spec, "some_slot") @@ -204,6 +322,71 @@ async def test_validation(value, operator, comparatum, result, caplog): else: assert len(events) == 2 assert isinstance(events[0], SlotSet) and events[0].value == None - assert isinstance(events[1], BotUttered) and events[1].text == "utter_invalid_some_slot" + assert ( + isinstance(events[1], BotUttered) + and events[1].text == "utter_invalid_some_slot" + ) if result is None: assert f"Validation operator '{operator}' requires" in caplog.messages[0] + + +@pytest.mark.parametrize( + "graph, age, authorization_req", + [ + # under 18 or over 65 + (required_slots_graph("OR", False), 17, True), + (required_slots_graph("OR", False), 30, False), + (required_slots_graph("OR", False), 66, True), + # at least 18 and at most 65 + (required_slots_graph("OR", True), 17, False), + (required_slots_graph("OR", True), 30, True), + (required_slots_graph("OR", True), 66, False), + # under 18 and over 65 (contradiction) + (required_slots_graph("AND", False), 17, False), + (required_slots_graph("AND", False), 30, False), + (required_slots_graph("AND", False), 66, False), + # at least 18 or at most 65 (tautology) + (required_slots_graph("AND", True), 17, True), + (required_slots_graph("AND", True), 30, True), + (required_slots_graph("AND", True), 66, True), + ], +) +async def test_required_slots(graph, age, authorization_req): + """ + (start) + | + AGE --- fail age condition --- AUTHORIZATION --- fail authorization + | | condition + | pass authorization | + pass age condition condition | + | / | + | / / + COMMENTS ---------------------- / + | / + (end) --------------------------------------------------- + """ + + spec = {"name": "default_form", "graph_elements": graph} + + form, tracker = new_form_and_tracker( + spec, "age", ["authorization", "comments"] + ) + tracker.update(SlotSet("age", age)) + + # first test with no authorization + tracker.update(SlotSet("authorization", "false")) + assert form.required_slots(tracker) == [ + "age", + *(["authorization"] if authorization_req else []), + # here comments is only required if authorization is not required + *(["comments"] if not authorization_req else []) + ] + + # then with authorization + tracker.update(SlotSet("authorization", "true")) + assert form.required_slots(tracker) == [ + "age", + *(["authorization"] if authorization_req else []), + # now comments is always required + "comments" + ]