Skip to content
This repository has been archived by the owner on Apr 28, 2021. It is now read-only.

Commit

Permalink
feat: programmatic required slots
Browse files Browse the repository at this point in the history
  • Loading branch information
pheel authored Sep 18, 2020
2 parents 271a723 + b9a9cda commit a2b6992
Show file tree
Hide file tree
Showing 4 changed files with 347 additions and 113 deletions.
61 changes: 28 additions & 33 deletions rasa_addons/core/actions/action_botfront_form.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import logging
import functools
from typing import Dict, Text, Any, List, Union, Optional, Tuple
from typing import Dict, Text, Any, List, Optional
from rasa.core.slots import Slot

from rasa_addons.core.actions.required_slots_graph_parser import (
RequiredSlotsGraphParser,
)
from rasa_addons.core.actions.slot_rule_validator import validate_with_rule
from rasa_addons.core.actions.submit_form_to_botfront import submit_form_to_botfront

Expand Down Expand Up @@ -41,7 +46,12 @@ def __str__(self) -> Text:
return f"FormAction('{self.name()}')"

def required_slots(self, tracker):
return [s.get("name") for s in self.form_spec.get("slots", [])]
graph = self.form_spec.get("graph_elements")
if not graph:
return [s.get("name") for s in self.form_spec.get("slots", [])]
parser = RequiredSlotsGraphParser(graph)
required_slots = parser.get_required_slots(tracker)
return required_slots

def get_field_for_slot(
self, slot: Text, field: Text, default: Optional[Any] = None,
Expand All @@ -51,32 +61,6 @@ def get_field_for_slot(
return s.get(field, default)
return default

async def validate_prefilled(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
):
# collect values of required slots filled before activation
prefilled_slots = {}
events = []

for slot_name in self.required_slots(tracker):
if not self._should_request_slot(tracker, slot_name):
prefilled_slots[slot_name] = tracker.get_slot(slot_name)

if prefilled_slots:
logger.debug(f"Validating pre-filled required slots: {prefilled_slots}")
events.extend(
await self.validate_slots(
prefilled_slots, output_channel, nlg, tracker, domain
)
)
else:
logger.debug("No pre-filled required slots to validate.")
return events

async def run(
self,
output_channel: "OutputChannel",
Expand All @@ -86,7 +70,9 @@ async def run(
) -> List[Event]:
# attempt retrieving spec
if not len(self.form_spec):
for form in tracker.slots.get("bf_forms", {}).initial_value:
for form in tracker.slots.get(
"bf_forms", Slot("bf_forms", initial_value=[])
).initial_value:
if form.get("name") == self.name():
self.form_spec = clean_none_values(form)
if not len(self.form_spec):
Expand Down Expand Up @@ -150,7 +136,8 @@ async def submit(
template = await nlg.generate(
f"utter_submit_{self.name()}", tracker, output_channel.name(),
)
events += [create_bot_utterance(template)]
if template:
events += [create_bot_utterance(template)]
if collect_in_botfront:
submit_form_to_botfront(tracker)
return events
Expand Down Expand Up @@ -222,11 +209,13 @@ def entity_is_desired(

@staticmethod
def get_entity_value(
name: Text,
name: Optional[Text],
tracker: "DialogueStateTracker",
role: Optional[Text] = None,
group: Optional[Text] = None,
) -> Any:
if not name:
return None
# list is used to cover the case of list slot type
value = list(
tracker.get_latest_entity_values(name, entity_group=group, entity_role=role)
Expand All @@ -246,6 +235,8 @@ def extract_other_slots(
domain: "Domain",
) -> Dict[Text, Any]:
slot_to_fill = tracker.get_slot(REQUESTED_SLOT)
if not slot_to_fill:
return {}

slot_values = {}
for slot in self.required_slots(tracker):
Expand Down Expand Up @@ -300,6 +291,8 @@ def extract_requested_slot(
else return None
"""
slot_to_fill = tracker.get_slot(REQUESTED_SLOT)
if not slot_to_fill:
return {}
logger.debug(f"Trying to extract requested slot '{slot_to_fill}' ...")

# get mapping for requested slot
Expand Down Expand Up @@ -352,14 +345,16 @@ async def utter_post_validation(
and self.get_field_for_slot(slot, "utter_on_new_valid_slot", False) is False
):
return []
valid = "valid" if valid else "invalid"
utter_what = "valid" if valid else "invalid"

# so utter_(in)valid_slot supports {slot} template replacements
temp_tracker = tracker.copy()
temp_tracker.slots[slot].value = value
template = await nlg.generate(
f"utter_{valid}_{slot}", temp_tracker, output_channel.name(),
f"utter_{utter_what}_{slot}", temp_tracker, output_channel.name(),
)
if not template:
return []
return [create_bot_utterance(template)]

async def validate_slots(
Expand Down
56 changes: 56 additions & 0 deletions rasa_addons/core/actions/required_slots_graph_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import Dict, Text, Any
from rasa_addons.core.actions.slot_rule_validator import validate_with_rule


class RequiredSlotsGraphParser:
def __init__(self, required_slots_graph: Dict[Text, Any]) -> None:
self.start = None
self.nodes = {}
for node in required_slots_graph.get("nodes", []):
if node.get("type") == "start":
self.start = node.get("id")
continue
self.nodes[node.get("id")] = node.get("slotName")
self.edges = {}
for edge in required_slots_graph.get("edges", []):
source = edge.get("source")
self.edges[source] = [*self.edges.get(source, []), edge]

def get_required_slots(self, tracker, start=None):
required_slots = []
current_source = start or self.start
current_edges = self.edges.get(current_source, [])
for edge in sorted(current_edges, key=lambda e: e.get("condition") is None):
target, condition = edge.get("target"), edge.get("condition")
if self.check_condition(tracker, condition):
required_slots.append(self.nodes.get(target))
required_slots += self.get_required_slots(tracker, start=target)
break # use first matching condition, that's it
else:
continue
return required_slots

def check_condition(self, tracker, condition):
if condition is None:
return True
props = condition.get("properties", {})
children = condition.get("children1", {}).values()
if condition.get("type") == "rule":
return self.check_atomic_condition(tracker, **props)
conjunction_operator = any if props.get("conjunction") == "OR" else all
polarity = (lambda p: not p) if props.get("not") else (lambda p: p)
return polarity(
conjunction_operator(
self.check_condition(tracker, child) for child in children
)
)

def check_atomic_condition(self, tracker, field, operator, value, **rest):
slot = tracker.slots.get(field)
return validate_with_rule(
slot.value if slot else None,
{
"operator": operator,
"comparatum": [*value, None][0] # value is always a singleton list
},
)
6 changes: 3 additions & 3 deletions rasa_addons/core/actions/slot_rule_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,19 +75,19 @@ def validate_with_rule(value, validation_rule) -> bool:
if operator in NUM_COMPARATUM_OPERATORS:
try:
comparatum = float(comparatum)
except ValueError:
except (ValueError, TypeError):
raise ValueError(
f"Validation operator '{operator}' requires a numerical comparatum."
)
except ValueError as e:
except (ValueError, TypeError) as e:
logger.error(str(e))
return False
if operator in TEXT_VALUE_OPERATORS and not isinstance(value, str):
return False
if operator in NUM_VALUE_OPERATORS:
try:
value = float(value)
except ValueError:
except (ValueError, TypeError):
return False
if operator == "is_in":
return value in comparatum
Expand Down
Loading

0 comments on commit a2b6992

Please sign in to comment.