Skip to content

Commit

Permalink
Make functions accept str along with Message (#337)
Browse files Browse the repository at this point in the history
# Description

- Made `dff.script.conditions.exact_match` and
`dff.utils.testing.common.check_happy_path` accept `str` along with
`Message`
- Tests added for `exact_match()`, but not for `check_happy_path()`.
- Changed all instances of `exact_match(Message("some text"))` to
`exact_match("some text")`
# Checklist

- [x] I have performed a self-review of the changes

# To Consider

- Update tutorials / guides

---------

Co-authored-by: Roman Zlobin <RLKRo@proton.me>
  • Loading branch information
ZergLev and RLKRo authored Jun 19, 2024
1 parent b5ed32e commit 4c087cf
Show file tree
Hide file tree
Showing 22 changed files with 277 additions and 325 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ import dff.script.conditions.std_conditions as cnd
script = {
GLOBAL: {
TRANSITIONS: {
("flow", "node_hi"): cnd.exact_match(Message("Hi")),
("flow", "node_hi"): cnd.exact_match("Hi"),
("flow", "node_ok"): cnd.true()
}
},
Expand Down
10 changes: 7 additions & 3 deletions dff/script/conditions/std_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,22 @@


@validate_call
def exact_match(match: Message, skip_none: bool = True) -> Callable[[Context, Pipeline], bool]:
def exact_match(match: Union[str, Message], skip_none: bool = True) -> Callable[[Context, Pipeline], bool]:
"""
Return function handler. This handler returns `True` only if the last user phrase
is the same Message as the `match`.
is the same `Message` as the `match`.
If `skip_none` the handler will not compare `None` fields of `match`.
:param match: A Message variable to compare user request with.
:param match: A `Message` variable to compare user request with.
Can also accept `str`, which will be converted into a `Message` with its text field equal to `match`.
:param skip_none: Whether fields should be compared if they are `None` in :py:const:`match`.
"""

def exact_match_condition_handler(ctx: Context, pipeline: Pipeline) -> bool:
request = ctx.last_request
nonlocal match
if isinstance(match, str):
match = Message(text=match)
if request is None:
return False
for field in match.model_fields:
Expand Down
10 changes: 7 additions & 3 deletions dff/utils/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

from os import getenv
from typing import Callable, Tuple, Optional
from typing import Callable, Tuple, Optional, Union
from uuid import uuid4

from dff.script import Context, Message
Expand All @@ -32,7 +32,7 @@ def is_interactive_mode() -> bool: # pragma: no cover

def check_happy_path(
pipeline: Pipeline,
happy_path: Tuple[Tuple[Message, Message], ...],
happy_path: Tuple[Tuple[Union[str, Message], Union[str, Message]], ...],
# This optional argument is used for additional processing of candidate responses and reference responses
response_comparer: Callable[[Message, Message, Context], Optional[str]] = default_comparer,
printout_enable: bool = True,
Expand All @@ -51,7 +51,11 @@ def check_happy_path(
"""

ctx_id = uuid4() # get random ID for current context
for step_id, (request, reference_response) in enumerate(happy_path):
for step_id, (request_raw, reference_response_raw) in enumerate(happy_path):
request = Message(text=request_raw) if isinstance(request_raw, str) else request_raw
reference_response = (
Message(text=reference_response_raw) if isinstance(reference_response_raw, str) else reference_response_raw
)
ctx = pipeline(request, ctx_id)
candidate_response = ctx.last_response
if printout_enable:
Expand Down
74 changes: 37 additions & 37 deletions dff/utils/testing/toy_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,24 @@
"greeting_flow": {
"start_node": {
RESPONSE: Message(),
TRANSITIONS: {"node1": exact_match(Message("Hi"))},
TRANSITIONS: {"node1": exact_match("Hi")},
},
"node1": {
RESPONSE: Message("Hi, how are you?"),
TRANSITIONS: {"node2": exact_match(Message("i'm fine, how are you?"))},
TRANSITIONS: {"node2": exact_match("i'm fine, how are you?")},
},
"node2": {
RESPONSE: Message("Good. What do you want to talk about?"),
TRANSITIONS: {"node3": exact_match(Message("Let's talk about music."))},
TRANSITIONS: {"node3": exact_match("Let's talk about music.")},
},
"node3": {
RESPONSE: Message("Sorry, I can not talk about music now."),
TRANSITIONS: {"node4": exact_match(Message("Ok, goodbye."))},
TRANSITIONS: {"node4": exact_match("Ok, goodbye.")},
},
"node4": {RESPONSE: Message("bye"), TRANSITIONS: {"node1": exact_match(Message("Hi"))}},
"node4": {RESPONSE: Message("bye"), TRANSITIONS: {"node1": exact_match("Hi")}},
"fallback_node": {
RESPONSE: Message("Ooops"),
TRANSITIONS: {"node1": exact_match(Message("Hi"))},
TRANSITIONS: {"node1": exact_match("Hi")},
},
}
}
Expand All @@ -52,11 +52,11 @@
"""

HAPPY_PATH = (
(Message("Hi"), Message("Hi, how are you?")),
(Message("i'm fine, how are you?"), Message("Good. What do you want to talk about?")),
(Message("Let's talk about music."), Message("Sorry, I can not talk about music now.")),
(Message("Ok, goodbye."), Message("bye")),
(Message("Hi"), Message("Hi, how are you?")),
("Hi", "Hi, how are you?"),
("i'm fine, how are you?", "Good. What do you want to talk about?"),
("Let's talk about music.", "Sorry, I can not talk about music now."),
("Ok, goodbye.", "bye"),
("Hi", "Hi, how are you?"),
)
"""
An example of a simple dialog.
Expand All @@ -69,37 +69,37 @@
"start": {
RESPONSE: Message("Hi"),
TRANSITIONS: {
("small_talk", "ask_some_questions"): exact_match(Message("hi")),
("animals", "have_pets"): exact_match(Message("i like animals")),
("animals", "like_animals"): exact_match(Message("let's talk about animals")),
("news", "what_news"): exact_match(Message("let's talk about news")),
("small_talk", "ask_some_questions"): exact_match("hi"),
("animals", "have_pets"): exact_match("i like animals"),
("animals", "like_animals"): exact_match("let's talk about animals"),
("news", "what_news"): exact_match("let's talk about news"),
},
},
"fallback": {RESPONSE: Message("Oops")},
},
"animals": {
"have_pets": {
RESPONSE: Message("do you have pets?"),
TRANSITIONS: {"what_animal": exact_match(Message("yes"))},
TRANSITIONS: {"what_animal": exact_match("yes")},
},
"like_animals": {
RESPONSE: Message("do you like it?"),
TRANSITIONS: {"what_animal": exact_match(Message("yes"))},
TRANSITIONS: {"what_animal": exact_match("yes")},
},
"what_animal": {
RESPONSE: Message("what animals do you have?"),
TRANSITIONS: {
"ask_about_color": exact_match(Message("bird")),
"ask_about_breed": exact_match(Message("dog")),
"ask_about_color": exact_match("bird"),
"ask_about_breed": exact_match("dog"),
},
},
"ask_about_color": {RESPONSE: Message("what color is it")},
"ask_about_breed": {
RESPONSE: Message("what is this breed?"),
TRANSITIONS: {
"ask_about_breed": exact_match(Message("pereat")),
"tell_fact_about_breed": exact_match(Message("bulldog")),
"ask_about_training": exact_match(Message("I don't know")),
"ask_about_breed": exact_match("pereat"),
"tell_fact_about_breed": exact_match("bulldog"),
"ask_about_training": exact_match("I don't know"),
},
},
"tell_fact_about_breed": {
Expand All @@ -111,53 +111,53 @@
"what_news": {
RESPONSE: Message("what kind of news do you prefer?"),
TRANSITIONS: {
"ask_about_science": exact_match(Message("science")),
"ask_about_sport": exact_match(Message("sport")),
"ask_about_science": exact_match("science"),
"ask_about_sport": exact_match("sport"),
},
},
"ask_about_science": {
RESPONSE: Message("i got news about science, do you want to hear?"),
TRANSITIONS: {
"science_news": exact_match(Message("yes")),
("small_talk", "ask_some_questions"): exact_match(Message("let's change the topic")),
"science_news": exact_match("yes"),
("small_talk", "ask_some_questions"): exact_match("let's change the topic"),
},
},
"science_news": {
RESPONSE: Message("This is science news"),
TRANSITIONS: {
"what_news": exact_match(Message("ok")),
("small_talk", "ask_some_questions"): exact_match(Message("let's change the topic")),
"what_news": exact_match("ok"),
("small_talk", "ask_some_questions"): exact_match("let's change the topic"),
},
},
"ask_about_sport": {
RESPONSE: Message("i got news about sport, do you want to hear?"),
TRANSITIONS: {
"sport_news": exact_match(Message("yes")),
("small_talk", "ask_some_questions"): exact_match(Message("let's change the topic")),
"sport_news": exact_match("yes"),
("small_talk", "ask_some_questions"): exact_match("let's change the topic"),
},
},
"sport_news": {
RESPONSE: Message("This is sport news"),
TRANSITIONS: {
"what_news": exact_match(Message("ok")),
("small_talk", "ask_some_questions"): exact_match(Message("let's change the topic")),
"what_news": exact_match("ok"),
("small_talk", "ask_some_questions"): exact_match("let's change the topic"),
},
},
},
"small_talk": {
"ask_some_questions": {
RESPONSE: Message("how are you"),
TRANSITIONS: {
"ask_talk_about": exact_match(Message("fine")),
("animals", "like_animals"): exact_match(Message("let's talk about animals")),
("news", "what_news"): exact_match(Message("let's talk about news")),
"ask_talk_about": exact_match("fine"),
("animals", "like_animals"): exact_match("let's talk about animals"),
("news", "what_news"): exact_match("let's talk about news"),
},
},
"ask_talk_about": {
RESPONSE: Message("what do you want to talk about"),
TRANSITIONS: {
("animals", "like_animals"): exact_match(Message("dog")),
("news", "what_news"): exact_match(Message("let's talk about news")),
("animals", "like_animals"): exact_match("dog"),
("news", "what_news"): exact_match("let's talk about news"),
},
},
},
Expand Down
16 changes: 8 additions & 8 deletions docs/source/user_guides/basic_conceptions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,14 @@ Example flow & script
RESPONSE: Message(), # the response of the initial node is skipped
TRANSITIONS: {
("greeting_flow", "greeting_node"):
cnd.exact_match(Message("/start")),
cnd.exact_match("/start"),
},
},
"greeting_node": {
RESPONSE: Message("Hi!"),
TRANSITIONS: {
("ping_pong_flow", "game_start_node"):
cnd.exact_match(Message("Hello!"))
cnd.exact_match("Hello!")
}
},
"fallback_node": {
Expand All @@ -111,14 +111,14 @@ Example flow & script
RESPONSE: Message("Let's play ping-pong!"),
TRANSITIONS: {
("ping_pong_flow", "response_node"):
cnd.exact_match(Message("Ping!")),
cnd.exact_match("Ping!"),
},
},
"response_node": {
RESPONSE: Message("Pong!"),
TRANSITIONS: {
("ping_pong_flow", "response_node"):
cnd.exact_match(Message("Ping!")),
cnd.exact_match("Ping!"),
},
},
},
Expand Down Expand Up @@ -289,9 +289,9 @@ conversational service.
.. code-block:: python
happy_path = (
(Message("/start"), Message("Hi!")),
(Message("Hello!"), Message("Let's play ping-pong!")),
(Message("Ping!"), Message("Pong!"))
("/start", "Hi!"),
("Hello!", "Let's play ping-pong!"),
("Ping!", "Pong!")
)
A special function is then used to ascertain complete identity of the messages taken from
Expand Down Expand Up @@ -384,4 +384,4 @@ Further reading
* `Guide on Context <../user_guides/context_guide.html>`_
* `Tutorial on global transitions <../tutorials/tutorials.script.core.5_global_transitions.html>`_
* `Tutorial on context serialization <../tutorials/tutorials.script.core.6_context_serialization.html>`_
* `Tutorial on script MISC <../tutorials/tutorials.script.core.8_misc.html>`_
* `Tutorial on script MISC <../tutorials/tutorials.script.core.8_misc.html>`_
6 changes: 3 additions & 3 deletions tests/pipeline/test_messenger_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@
RESPONSE: {
"text": "",
},
TRANSITIONS: {"node1": cnd.exact_match(Message("Ping"))},
TRANSITIONS: {"node1": cnd.exact_match("Ping")},
},
"node1": {
RESPONSE: {
"text": "Pong",
},
TRANSITIONS: {"node1": cnd.exact_match(Message("Ping"))},
TRANSITIONS: {"node1": cnd.exact_match("Ping")},
},
"fallback_node": {
RESPONSE: {
"text": "Ooops",
},
TRANSITIONS: {"node1": cnd.exact_match(Message("Ping"))},
TRANSITIONS: {"node1": cnd.exact_match("Ping")},
},
}
}
Expand Down
22 changes: 12 additions & 10 deletions tests/script/conditions/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ def test_conditions():
failed_ctx.add_label(label)
pipeline = Pipeline.from_script(script={"flow": {"node": {}}}, start_label=("flow", "node"))

assert cnd.exact_match(Message("text"))(ctx, pipeline)
assert cnd.exact_match("text")(ctx, pipeline)
assert cnd.exact_match(Message("text", misc={}))(ctx, pipeline)
assert not cnd.exact_match(Message("text", misc={1: 1}))(ctx, pipeline)
assert not cnd.exact_match(Message("text1"))(ctx, pipeline)
assert not cnd.exact_match("text1")(ctx, pipeline)
assert cnd.exact_match(Message())(ctx, pipeline)
assert not cnd.exact_match(Message(), skip_none=False)(ctx, pipeline)
assert cnd.exact_match("text")(ctx, pipeline)
assert not cnd.exact_match("text1")(ctx, pipeline)

assert cnd.has_text("text")(ctx, pipeline)
assert cnd.has_text("te")(ctx, pipeline)
Expand All @@ -30,17 +32,17 @@ def test_conditions():
assert not cnd.regexp("t.*t1")(ctx, pipeline)
assert not cnd.regexp("t.*t1")(failed_ctx, pipeline)

assert cnd.agg([cnd.regexp("t.*t"), cnd.exact_match(Message("text"))], aggregate_func=all)(ctx, pipeline)
assert not cnd.agg([cnd.regexp("t.*t1"), cnd.exact_match(Message("text"))], aggregate_func=all)(ctx, pipeline)
assert cnd.agg([cnd.regexp("t.*t"), cnd.exact_match("text")], aggregate_func=all)(ctx, pipeline)
assert not cnd.agg([cnd.regexp("t.*t1"), cnd.exact_match("text")], aggregate_func=all)(ctx, pipeline)

assert cnd.any([cnd.regexp("t.*t1"), cnd.exact_match(Message("text"))])(ctx, pipeline)
assert not cnd.any([cnd.regexp("t.*t1"), cnd.exact_match(Message("text1"))])(ctx, pipeline)
assert cnd.any([cnd.regexp("t.*t1"), cnd.exact_match("text")])(ctx, pipeline)
assert not cnd.any([cnd.regexp("t.*t1"), cnd.exact_match("text1")])(ctx, pipeline)

assert cnd.all([cnd.regexp("t.*t"), cnd.exact_match(Message("text"))])(ctx, pipeline)
assert not cnd.all([cnd.regexp("t.*t1"), cnd.exact_match(Message("text"))])(ctx, pipeline)
assert cnd.all([cnd.regexp("t.*t"), cnd.exact_match("text")])(ctx, pipeline)
assert not cnd.all([cnd.regexp("t.*t1"), cnd.exact_match("text")])(ctx, pipeline)

assert cnd.neg(cnd.exact_match(Message("text1")))(ctx, pipeline)
assert not cnd.neg(cnd.exact_match(Message("text")))(ctx, pipeline)
assert cnd.neg(cnd.exact_match("text1"))(ctx, pipeline)
assert not cnd.neg(cnd.exact_match("text"))(ctx, pipeline)

assert cnd.has_last_labels(flow_labels=["flow"])(ctx, pipeline)
assert not cnd.has_last_labels(flow_labels=["flow1"])(ctx, pipeline)
Expand Down
3 changes: 1 addition & 2 deletions tests/tutorials/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
import re

import pytest
from dff.script import Message
from dff.utils.testing.common import check_happy_path, is_interactive_mode
from tests.pipeline.test_messenger_interface import pipeline


def test_unhappy_path():
with pytest.raises(Exception) as e:
check_happy_path(pipeline, ((Message("Hi"), Message("false_response")),))
check_happy_path(pipeline, (("Hi", "false_response"),))
assert e
msg = str(e)
assert msg
Expand Down
4 changes: 2 additions & 2 deletions tutorials/messengers/telegram/1_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ class and [telebot](https://pytba.readthedocs.io/en/latest/index.html)
script = {
"greeting_flow": {
"start_node": {
TRANSITIONS: {"greeting_node": cnd.exact_match(Message("/start"))},
TRANSITIONS: {"greeting_node": cnd.exact_match("/start")},
},
"greeting_node": {
RESPONSE: Message("Hi"),
TRANSITIONS: {lbl.repeat(): cnd.true()},
},
"fallback_node": {
RESPONSE: Message("Please, repeat the request"),
TRANSITIONS: {"greeting_node": cnd.exact_match(Message("/start"))},
TRANSITIONS: {"greeting_node": cnd.exact_match("/start")},
},
}
}
Expand Down
Loading

0 comments on commit 4c087cf

Please sign in to comment.