From 16999061268ba4958d4f2146eadc430e04709e8a Mon Sep 17 00:00:00 2001 From: m-vdb Date: Mon, 31 Aug 2020 12:22:35 +0200 Subject: [PATCH 01/42] cleanup type error: [dict-item] --- rasa/core/actions/forms.py | 7 ++++--- rasa/core/channels/hangouts.py | 2 +- rasa/nlu/classifiers/diet_classifier.py | 2 +- rasa/nlu/training_data/lookup_tables_parser.py | 8 ++++++-- setup.cfg | 3 --- 5 files changed, 12 insertions(+), 10 deletions(-) diff --git a/rasa/core/actions/forms.py b/rasa/core/actions/forms.py index 67bfdff83c88..1c48aab9550c 100644 --- a/rasa/core/actions/forms.py +++ b/rasa/core/actions/forms.py @@ -249,12 +249,11 @@ def extract_other_slots( return slot_values def extract_requested_slot( - self, tracker: "DialogueStateTracker", domain: Domain + self, tracker: "DialogueStateTracker", domain: Domain, slot_to_fill: Text, ) -> Dict[Text, Any]: """Extract the value of requested slot from a user input else return `None`. """ - slot_to_fill = tracker.get_slot(REQUESTED_SLOT) logger.debug(f"Trying to extract requested slot '{slot_to_fill}' ...") # get mapping for requested slot @@ -377,7 +376,9 @@ async def validate( # extract requested slot slot_to_fill = tracker.get_slot(REQUESTED_SLOT) if slot_to_fill: - slot_values.update(self.extract_requested_slot(tracker, domain)) + slot_values.update( + self.extract_requested_slot(tracker, domain, slot_to_fill) + ) if not slot_values: # reject to execute the form action diff --git a/rasa/core/channels/hangouts.py b/rasa/core/channels/hangouts.py index 8377680ff92d..078347ce88fc 100644 --- a/rasa/core/channels/hangouts.py +++ b/rasa/core/channels/hangouts.py @@ -132,7 +132,7 @@ async def _persist_message(self, message: Dict) -> None: new_messages = self._combine_cards(text_card, message) elif msg_new == "text": - new_messages = {"text": message.get("text")} + new_messages = {"text": message["text"]} else: new_messages = message diff --git a/rasa/nlu/classifiers/diet_classifier.py b/rasa/nlu/classifiers/diet_classifier.py index 4ba2f49ab0d8..c7d66d764ce3 100644 --- a/rasa/nlu/classifiers/diet_classifier.py +++ b/rasa/nlu/classifiers/diet_classifier.py @@ -851,7 +851,7 @@ def _predict_label( ) -> Tuple[Dict[Text, Any], List[Dict[Text, Any]]]: """Predicts the intent of the provided message.""" - label = {"name": None, "id": None, "confidence": 0.0} + label: Dict[Text, Any] = {"name": None, "id": None, "confidence": 0.0} label_ranking = [] if predict_out is None: diff --git a/rasa/nlu/training_data/lookup_tables_parser.py b/rasa/nlu/training_data/lookup_tables_parser.py index c860b259ea80..9f2b6969c42f 100644 --- a/rasa/nlu/training_data/lookup_tables_parser.py +++ b/rasa/nlu/training_data/lookup_tables_parser.py @@ -1,8 +1,10 @@ -from typing import Any, Text, List, Dict +from typing import Any, Text, List, Dict, Union def add_item_to_lookup_tables( - title: Text, item: Text, existing_lookup_tables: List[Dict[Text, List[Text]]] + title: Text, + item: Text, + existing_lookup_tables: List[Dict[Text, Union[Text, List[Text]]]], ) -> None: """Takes a list of lookup table dictionaries. Finds the one associated with the current lookup, then adds the item to the list. @@ -17,4 +19,6 @@ def add_item_to_lookup_tables( existing_lookup_tables.append({"name": title, "elements": [item]}) else: elements = matches[0]["elements"] + if not isinstance(elements, list): + elements = matches[0]["elements"] = [elements] elements.append(item) diff --git a/setup.cfg b/setup.cfg index 2bc2e31910f6..6b79289b6d9e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -380,9 +380,6 @@ ignore_errors = True [mypy-rasa.nlu.training_data.loading] ignore_errors = True -[mypy-rasa.nlu.training_data.lookup_tables_parser] -ignore_errors = True - [mypy-rasa.nlu.training_data.message] ignore_errors = True From 0da7e776cf837a5289c04c3362ad5678f143d1ed Mon Sep 17 00:00:00 2001 From: m-vdb Date: Mon, 31 Aug 2020 14:53:48 +0200 Subject: [PATCH 02/42] cleanup type error: [list-item] --- rasa/core/actions/action.py | 11 +++++++---- rasa/core/actions/two_stage_fallback.py | 4 ++-- rasa/core/featurizers.py | 2 +- rasa/core/policies/policy.py | 3 +-- rasa/core/policies/rule_policy.py | 14 +++++++------- rasa/core/test.py | 7 ++++--- rasa/core/training/interactive.py | 12 ++++++------ 7 files changed, 28 insertions(+), 25 deletions(-) diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index 45f38061b31b..2bc1c06de466 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -677,10 +677,6 @@ def has_user_affirmed(tracker: "DialogueStateTracker") -> bool: def _revert_affirmation_events(tracker: "DialogueStateTracker") -> List[Event]: revert_events = _revert_single_affirmation_events() - last_user_event = tracker.get_last_event_for(UserUttered) - last_user_event = copy.deepcopy(last_user_event) - last_user_event.parse_data["intent"]["confidence"] = 1.0 - # User affirms the rephrased intent rephrased_intent = tracker.last_executed_action_has( name=ACTION_DEFAULT_ASK_REPHRASE_NAME, skip=1 @@ -688,6 +684,13 @@ def _revert_affirmation_events(tracker: "DialogueStateTracker") -> List[Event]: if rephrased_intent: revert_events += _revert_rephrasing_events() + last_user_event = tracker.get_last_event_for(UserUttered) + if not last_user_event: + return revert_events + + last_user_event = copy.deepcopy(last_user_event) + last_user_event.parse_data["intent"]["confidence"] = 1.0 + return revert_events + [last_user_event] diff --git a/rasa/core/actions/two_stage_fallback.py b/rasa/core/actions/two_stage_fallback.py index 0a7db875ed27..27406cc557de 100644 --- a/rasa/core/actions/two_stage_fallback.py +++ b/rasa/core/actions/two_stage_fallback.py @@ -1,6 +1,6 @@ import copy import time -from typing import List, Text, Optional +from typing import List, Text, Optional, cast from rasa.constants import DEFAULT_NLU_FALLBACK_INTENT_NAME from rasa.core.actions import action @@ -195,7 +195,7 @@ def _second_affirmation_failed(tracker: DialogueStateTracker) -> bool: def _message_clarification(tracker: DialogueStateTracker) -> List[Event]: - clarification = copy.deepcopy(tracker.latest_message) + clarification = copy.deepcopy(cast(Event, tracker.latest_message)) clarification.parse_data["intent"]["confidence"] = 1.0 clarification.timestamp = time.time() return [ActionExecuted(ACTION_LISTEN_NAME), clarification] diff --git a/rasa/core/featurizers.py b/rasa/core/featurizers.py index 39f12f086d90..8cf8e401f609 100644 --- a/rasa/core/featurizers.py +++ b/rasa/core/featurizers.py @@ -563,7 +563,7 @@ def __init__( ) -> None: super().__init__(state_featurizer, use_intent_probabilities) - self.max_history = max_history or self.MAX_HISTORY_DEFAULT + self.max_history: Optional[int] = max_history or self.MAX_HISTORY_DEFAULT self.remove_duplicates = remove_duplicates @staticmethod diff --git a/rasa/core/policies/policy.py b/rasa/core/policies/policy.py index 026bb207febe..a1e7e873530f 100644 --- a/rasa/core/policies/policy.py +++ b/rasa/core/policies/policy.py @@ -208,8 +208,7 @@ def load(cls, path: Text) -> "Policy": raise NotImplementedError("Policy must have the capacity to load itself.") - @staticmethod - def _default_predictions(domain: Domain) -> List[float]: + def _default_predictions(self, domain: Domain) -> List[float]: """Creates a list of zeros. Args: diff --git a/rasa/core/policies/rule_policy.py b/rasa/core/policies/rule_policy.py index cf1f4d5028f4..a9ebc23abacf 100644 --- a/rasa/core/policies/rule_policy.py +++ b/rasa/core/policies/rule_policy.py @@ -1,5 +1,5 @@ import logging -from typing import List, Dict, Text, Optional, Any, Set, TYPE_CHECKING +from typing import List, Dict, Text, Optional, Any, Set, TYPE_CHECKING, Union import re from collections import defaultdict @@ -152,8 +152,8 @@ def _prev_action_listen_in_state(state: Dict[Text, float]) -> bool: @staticmethod def _modified_states( - states: List[Dict[Text, float]] - ) -> List[Optional[Dict[Text, float]]]: + states: List[Dict[Text, Union[int, float]]] + ) -> List[Optional[Dict[Text, Union[int, float]]]]: """Modifies the states to create feature keys for form unhappy path conditions. Args: @@ -165,7 +165,7 @@ def _modified_states( """ indicator = PREV_PREFIX + RULE_SNIPPET_ACTION_NAME - state_only_with_action = {indicator: 1} + state_only_with_action: Dict[Text, Union[int, float]] = {indicator: 1} # leave only last 2 dialogue turns to # - capture previous meaningful action before action_listen # - ignore previous intent @@ -266,9 +266,7 @@ def train( # only consider original trackers (no augmented ones) training_trackers = [ - t - for t in training_trackers - if not hasattr(t, "is_augmented") or not t.is_augmented + t for t in training_trackers if not getattr(t, "is_augmented", False) ] # only use trackers from rule-based training data rule_trackers = [t for t in training_trackers if t.is_rule_tracker] @@ -414,6 +412,8 @@ def _find_action_from_form_happy_path( ) return ACTION_LISTEN_NAME + return None + def _find_action_from_rules( self, tracker: DialogueStateTracker, domain: Domain ) -> Optional[Text]: diff --git a/rasa/core/test.py b/rasa/core/test.py index 92cda8dc2c04..c654829fa0c2 100644 --- a/rasa/core/test.py +++ b/rasa/core/test.py @@ -275,9 +275,10 @@ def _collect_user_uttered_predictions( intent_gold = event.intent.get("name") predicted_intent = predicted.get(INTENT, {}).get("name") - user_uttered_eval_store.add_to_store( - intent_predictions=[predicted_intent], intent_targets=[intent_gold] - ) + if intent_gold: + user_uttered_eval_store.add_to_store(intent_targets=[intent_gold]) + if predicted_intent: + user_uttered_eval_store.add_to_store(intent_predictions=[predicted_intent]) entity_gold = event.entities predicted_entities = predicted.get(ENTITIES) diff --git a/rasa/core/training/interactive.py b/rasa/core/training/interactive.py index 3ac7560369ee..de0ddc7f728d 100644 --- a/rasa/core/training/interactive.py +++ b/rasa/core/training/interactive.py @@ -6,7 +6,7 @@ import uuid from functools import partial from multiprocessing import Process -from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union, Set +from typing import Any, Callable, Deque, Dict, List, Optional, Text, Tuple, Union, Set import numpy as np from aiohttp import ClientError @@ -1474,7 +1474,7 @@ async def record_messages( async def _get_tracker_events_to_plot( domain: Dict[Text, Any], file_importer: TrainingDataImporter, conversation_id: Text -) -> List[Union[Text, List[Event]]]: +) -> List[Union[Text, Deque[Event]]]: training_trackers = await _get_training_trackers(file_importer, domain) number_of_trackers = len(training_trackers) if number_of_trackers > MAX_NUMBER_OF_TRAINING_STORIES_FOR_VISUALIZATION: @@ -1487,10 +1487,10 @@ async def _get_tracker_events_to_plot( ) training_trackers = [] - training_data_events = [t.events for t in training_trackers] - events_including_current_user_id = training_data_events + [conversation_id] - - return events_including_current_user_id + training_data_events: List[Union[Text, Deque[Event]]] = [ + t.events for t in training_trackers + ] + return training_data_events + [conversation_id] async def _get_training_trackers( From a797f7b81202c947426e5280662588d5d05341c4 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Mon, 31 Aug 2020 17:35:02 +0200 Subject: [PATCH 03/42] cleanup type error: [call-arg] --- poetry.lock | 3 +-- pyproject.toml | 1 + rasa/cli/train.py | 20 +++++++------------- rasa/server.py | 27 +++++++++++++++++++++------ rasa/test.py | 6 +++--- setup.cfg | 3 --- 6 files changed, 33 insertions(+), 27 deletions(-) diff --git a/poetry.lock b/poetry.lock index d10518cb0f2b..8e3d4801c4a7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3053,7 +3053,6 @@ version = ">=3.7.4" [[package]] category = "main" description = "Backport of pathlib-compatible object wrapper for zip files" -marker = "sys_platform != \"win32\" and python_version < \"3.8\" or python_version < \"3.8\" or python_version < \"3.7\" and python_version != \"3.4\"" name = "zipp" optional = false python-versions = ">=3.6" @@ -3071,7 +3070,7 @@ spacy = ["spacy"] transformers = ["transformers"] [metadata] -content-hash = "01760c0b388a7deeceb6127efa0dc9f8c3a495e7046e2776e3073d1919e66867" +content-hash = "7c826413e4ee6a6c55df8f25b5b41217b75f9235c9305892c1fb112beabcbe35" python-versions = ">=3.6,<3.9" [metadata.files] diff --git a/pyproject.toml b/pyproject.toml index f40d42e13482..a917d70317ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -150,6 +150,7 @@ toml = "^0.10.0" pep440-version-utils = "^0.3.0" pydoc-markdown = "3.3.0.post1" mypy = "^0.782" +typing-extensions = "^3.7.4" [tool.poetry.extras] spacy = [ "spacy",] diff --git a/rasa/cli/train.py b/rasa/cli/train.py index 0ffff340f074..d5bb21726216 100644 --- a/rasa/cli/train.py +++ b/rasa/cli/train.py @@ -1,9 +1,12 @@ import argparse +import asyncio import os from typing import List, Optional, Text, Dict -import rasa.cli.arguments.train as train_arguments +import rasa.train +import rasa.cli.arguments.train as train_arguments from rasa.cli.utils import get_validated_path, missing_config_keys, print_error +from rasa.core.train import do_compare_training from rasa.constants import ( DEFAULT_CONFIG_PATH, DEFAULT_DATA_PATH, @@ -52,8 +55,6 @@ def add_subparser( def train(args: argparse.Namespace) -> Optional[Text]: - import rasa - domain = get_validated_path( args.domain, "domain", DEFAULT_DOMAIN_PATH, none_is_valid=True ) @@ -65,7 +66,7 @@ def train(args: argparse.Namespace) -> Optional[Text]: for f in args.data ] - return rasa.train( + return rasa.train.train( domain=domain, config=config, training_files=training_files, @@ -81,9 +82,6 @@ def train(args: argparse.Namespace) -> Optional[Text]: def train_core( args: argparse.Namespace, train_path: Optional[Text] = None ) -> Optional[Text]: - from rasa.train import train_core - import asyncio - loop = asyncio.get_event_loop() output = train_path or args.out @@ -103,7 +101,7 @@ def train_core( config = _get_valid_config(args.config, CONFIG_MANDATORY_KEYS_CORE) - return train_core( + return rasa.train.train_core( domain=args.domain, config=config, stories=story_file, @@ -113,8 +111,6 @@ def train_core( additional_arguments=additional_arguments, ) else: - from rasa.core.train import do_compare_training - loop.run_until_complete( do_compare_training(args, story_file, additional_arguments) ) @@ -123,8 +119,6 @@ def train_core( def train_nlu( args: argparse.Namespace, train_path: Optional[Text] = None ) -> Optional[Text]: - from rasa.train import train_nlu - output = train_path or args.out config = _get_valid_config(args.config, CONFIG_MANDATORY_KEYS_NLU) @@ -132,7 +126,7 @@ def train_nlu( args.nlu, "nlu", DEFAULT_DATA_PATH, none_is_valid=True ) - return train_nlu( + return rasa.train.train_nlu( config=config, nlu_data=nlu_data, output=output, diff --git a/rasa/server.py b/rasa/server.py index ffc4d9e7b0b1..8738416e21de 100644 --- a/rasa/server.py +++ b/rasa/server.py @@ -9,7 +9,7 @@ from functools import reduce, wraps from inspect import isawaitable from pathlib import Path -from typing import Any, Callable, List, Optional, Text, Union, Dict +from typing import Any, Callable, List, Optional, Text, Union, Dict, cast from rasa.core.training.story_writer.yaml_story_writer import YAMLStoryWriter from rasa.nlu.training_data.formats import RasaYAMLReader @@ -52,8 +52,16 @@ if typing.TYPE_CHECKING: from ssl import SSLContext + from typing_extensions import Protocol from rasa.core.processor import MessageProcessor + class SanicView(Protocol): + def __call__( + self, request: Request, *args: Any, **kwargs: Any + ) -> response.BaseHTTPResponse: + ... + + logger = logging.getLogger(__name__) JSON_CONTENT_TYPE = "application/json" @@ -123,7 +131,7 @@ def decorated(*args, **kwargs): def requires_auth(app: Sanic, token: Optional[Text] = None) -> Callable[[Any], Any]: """Wraps a request handler with token authentication.""" - def decorator(f: Callable[[Any, Any], Any]) -> Callable[[Any, Any], Any]: + def decorator(f: "SanicView") -> "SanicView": def conversation_id_from_args(args: Any, kwargs: Any) -> Optional[Text]: argnames = common_utils.arguments_of(f) @@ -153,7 +161,9 @@ def sufficient_scope(request, *args: Any, **kwargs: Any) -> Optional[bool]: return False @wraps(f) - async def decorated(request: Request, *args: Any, **kwargs: Any) -> Any: + async def decorated( + request: Request, *args: Any, **kwargs: Any + ) -> response.BaseHTTPResponse: provided = request.args.get("token", None) @@ -227,7 +237,7 @@ async def get_tracker( _validate_tracker(tracker, conversation_id) # `_validate_tracker` ensures we can't return `None` so `Optional` is not needed - return tracker # pytype: disable=bad-return-type + return cast(DialogueStateTracker, tracker) def _validate_tracker( @@ -631,7 +641,7 @@ async def execute_action(request: Request, conversation_id: Text): tracker = await get_tracker(app.agent.create_processor(), conversation_id) state = tracker.current_state(verbosity) - response_body = {"tracker": state} + response_body: Dict[Text, Any] = {"tracker": state} if isinstance(output_channel, CollectingOutputChannel): response_body["messages"] = output_channel.messages @@ -685,7 +695,7 @@ async def trigger_intent(request: Request, conversation_id: Text) -> HTTPRespons state = tracker.current_state(verbosity) - response_body = {"tracker": state} + response_body: Dict[Text, Any] = {"tracker": state} if isinstance(output_channel, CollectingOutputChannel): response_body["messages"] = output_channel.messages @@ -870,6 +880,11 @@ async def evaluate_intents(request: Request) -> HTTPResponse: model_directory = eval_agent.model_directory _, nlu_model = model.get_model_subdirectories(model_directory) + if nlu_model is None: + raise ErrorResponse( + 500, "TestingError", "Missing NLU model directory.", + ) + try: evaluation = run_evaluation(data_path, nlu_model, disable_plotting=True) return response.json(evaluation) diff --git a/rasa/test.py b/rasa/test.py index bdeee1ff796c..cac9fefb8ead 100644 --- a/rasa/test.py +++ b/rasa/test.py @@ -147,7 +147,7 @@ def test_core( "to train a NLU model first, e.g. using `rasa train`." ) - from rasa.core.test import test + from rasa.core.test import test as core_test kwargs = utils.minimal_kwargs(additional_arguments, test, ["stories", "agent"]) @@ -157,11 +157,11 @@ def test_core( def _test_core( stories: Optional[Text], agent: "Agent", output_directory: Text, **kwargs: Any ) -> None: - from rasa.core.test import test + from rasa.core.test import test as core_test loop = asyncio.get_event_loop() loop.run_until_complete( - test(stories, agent, out_directory=output_directory, **kwargs) + core_test(stories, agent, out_directory=output_directory, **kwargs) ) diff --git a/setup.cfg b/setup.cfg index 6b79289b6d9e..81ebaaab50fc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -398,9 +398,6 @@ ignore_errors = True [mypy-rasa.run] ignore_errors = True -[mypy-rasa.server] -ignore_errors = True - [mypy-rasa.test] ignore_errors = True From 4e0fa91406a2d27aaae5135614659c96f6ea253f Mon Sep 17 00:00:00 2001 From: m-vdb Date: Mon, 31 Aug 2020 18:24:54 +0200 Subject: [PATCH 04/42] fix import errors --- rasa/cli/train.py | 12 ++++++++---- rasa/test.py | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/rasa/cli/train.py b/rasa/cli/train.py index d5bb21726216..8f222d0380c8 100644 --- a/rasa/cli/train.py +++ b/rasa/cli/train.py @@ -3,8 +3,12 @@ import os from typing import List, Optional, Text, Dict -import rasa.train import rasa.cli.arguments.train as train_arguments +from rasa.train import ( + train as rasa_train, + train_core as rasa_train_core, + train_nlu as rasa_train_nlu, +) from rasa.cli.utils import get_validated_path, missing_config_keys, print_error from rasa.core.train import do_compare_training from rasa.constants import ( @@ -66,7 +70,7 @@ def train(args: argparse.Namespace) -> Optional[Text]: for f in args.data ] - return rasa.train.train( + return rasa_train( domain=domain, config=config, training_files=training_files, @@ -101,7 +105,7 @@ def train_core( config = _get_valid_config(args.config, CONFIG_MANDATORY_KEYS_CORE) - return rasa.train.train_core( + return rasa_train_core( domain=args.domain, config=config, stories=story_file, @@ -126,7 +130,7 @@ def train_nlu( args.nlu, "nlu", DEFAULT_DATA_PATH, none_is_valid=True ) - return rasa.train.train_nlu( + return rasa_train_nlu( config=config, nlu_data=nlu_data, output=output, diff --git a/rasa/test.py b/rasa/test.py index cac9fefb8ead..e8f1c461b846 100644 --- a/rasa/test.py +++ b/rasa/test.py @@ -149,7 +149,7 @@ def test_core( from rasa.core.test import test as core_test - kwargs = utils.minimal_kwargs(additional_arguments, test, ["stories", "agent"]) + kwargs = utils.minimal_kwargs(additional_arguments, core_test, ["stories", "agent"]) _test_core(stories, _agent, output, **kwargs) From 19811418695e3c5957c04ffca777d157c2facefc Mon Sep 17 00:00:00 2001 From: m-vdb Date: Tue, 1 Sep 2020 09:37:03 +0200 Subject: [PATCH 05/42] fix tests after import/signature changes --- tests/cli/test_rasa_interactive.py | 10 ++++------ tests/core/actions/test_forms.py | 10 +++++----- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/tests/cli/test_rasa_interactive.py b/tests/cli/test_rasa_interactive.py index 7ae7164e2bd7..5fac169a9880 100644 --- a/tests/cli/test_rasa_interactive.py +++ b/tests/cli/test_rasa_interactive.py @@ -1,12 +1,13 @@ import argparse -import pytest from typing import Callable, Text from unittest.mock import Mock, ANY +import pytest from _pytest.monkeypatch import MonkeyPatch from _pytest.pytester import RunResult -import rasa +from rasa.core.train import do_interactive_learning +from rasa.core.training import interactive as interactive_learning from rasa.cli import interactive, train from tests.conftest import DEFAULT_NLU_DATA @@ -59,7 +60,7 @@ def test_pass_arguments_to_rasa_train( # Mock actual training mock = Mock() - monkeypatch.setattr(rasa, "train", mock.method) + monkeypatch.setattr(train, "rasa_train", mock.method) # If the `Namespace` object does not have all required fields this will throw train.train(args) @@ -155,9 +156,6 @@ def test_no_interactive_without_core_data( def test_pass_conversation_id_to_interactive_learning(monkeypatch: MonkeyPatch): - from rasa.core.train import do_interactive_learning - from rasa.core.training import interactive as interactive_learning - parser = argparse.ArgumentParser() sub_parser = parser.add_subparsers() interactive.add_subparser(sub_parser, []) diff --git a/tests/core/actions/test_forms.py b/tests/core/actions/test_forms.py index a29958dd493a..a1ee67d4522e 100644 --- a/tests/core/actions/test_forms.py +++ b/tests/core/actions/test_forms.py @@ -436,7 +436,7 @@ def test_extract_requested_slot_default(): ], ) - slot_values = form.extract_requested_slot(tracker, Domain.empty()) + slot_values = form.extract_requested_slot(tracker, Domain.empty(), "some_slot") assert slot_values == {"some_slot": "some_value"} @@ -478,7 +478,7 @@ def test_extract_requested_slot_when_mapping_applies( ], ) - slot_values = form.extract_requested_slot(tracker, domain) + slot_values = form.extract_requested_slot(tracker, domain, "some_slot") # check that the value was extracted for correct intent assert slot_values == {"some_slot": expected_value} @@ -513,7 +513,7 @@ def test_extract_requested_slot_mapping_does_not_apply(slot_mapping: Dict): ], ) - slot_values = form.extract_requested_slot(tracker, domain) + slot_values = form.extract_requested_slot(tracker, domain, "some_slot") # check that the value was not extracted for incorrect intent assert slot_values == {} @@ -767,7 +767,7 @@ def test_extract_requested_slot_from_entity( ], ) - slot_values = form.extract_requested_slot(tracker, domain) + slot_values = form.extract_requested_slot(tracker, domain, "some_slot") assert slot_values == expected_slot_values @@ -784,7 +784,7 @@ def test_invalid_slot_mapping(): ) with pytest.raises(ValueError): - form.extract_requested_slot(tracker, domain) + form.extract_requested_slot(tracker, domain, slot_name) @pytest.mark.parametrize( From b4657adc0ba502764d563eeea2da1204a0bb1d27 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Tue, 1 Sep 2020 10:45:24 +0200 Subject: [PATCH 06/42] cleanup some type errors: [index] --- Makefile | 2 +- rasa/core/channels/rasa_chat.py | 9 +++++++-- rasa/core/domain.py | 17 ++++++++++++++--- rasa/core/featurizers.py | 10 ++++++++-- rasa/core/policies/ted_policy.py | 2 +- rasa/nlu/classifiers/diet_classifier.py | 4 +++- setup.cfg | 3 --- stubs/sanic/__init__.pyi | 5 +++++ stubs/sanic/exceptions.pyi | 3 +++ 9 files changed, 42 insertions(+), 13 deletions(-) create mode 100644 stubs/sanic/__init__.pyi create mode 100644 stubs/sanic/exceptions.pyi diff --git a/Makefile b/Makefile index c1245f329e71..5071e41d51e2 100644 --- a/Makefile +++ b/Makefile @@ -68,7 +68,7 @@ lint: poetry run black --check rasa tests types: - poetry run mypy rasa + MYPYPATH=./stubs poetry run mypy rasa prepare-tests-files: poetry install -E spacy diff --git a/rasa/core/channels/rasa_chat.py b/rasa/core/channels/rasa_chat.py index 12ce77aea68e..995378286e2b 100644 --- a/rasa/core/channels/rasa_chat.py +++ b/rasa/core/channels/rasa_chat.py @@ -5,6 +5,7 @@ import logging from sanic.exceptions import abort import jwt +import jwt.exceptions from rasa.core import constants from rasa.core.channels.channel import InputChannel @@ -35,7 +36,7 @@ def from_credentials(cls, credentials: Optional[Dict[Text, Any]]) -> InputChanne def __init__(self, url: Optional[Text]) -> None: self.base_url = url - self.jwt_key = None + self.jwt_key: Optional[Text] = None self.jwt_algorithm = None async def _fetch_public_key(self) -> None: @@ -71,6 +72,9 @@ async def _fetch_public_key(self) -> None: ) def _decode_jwt(self, bearer_token: Text) -> Dict: + if self.jwt_key is None: + raise TypeError("JWT public key is none.") + authorization_header_value = bearer_token.replace( constants.BEARER_TOKEN_PREFIX, "" ) @@ -85,12 +89,13 @@ async def _decode_bearer_token(self, bearer_token: Text) -> Optional[Dict]: # noinspection PyBroadException try: return self._decode_jwt(bearer_token) - except jwt.exceptions.InvalidSignatureError: + except jwt.InvalidSignatureError: logger.error("JWT public key invalid, fetching new one.") await self._fetch_public_key() return self._decode_jwt(bearer_token) except Exception: logger.exception("Failed to decode bearer token.") + return None async def _extract_sender(self, req: Request) -> Optional[Text]: """Fetch user from the Rasa X Admin API.""" diff --git a/rasa/core/domain.py b/rasa/core/domain.py index f355be4ef945..8fd28b4e9cd9 100644 --- a/rasa/core/domain.py +++ b/rasa/core/domain.py @@ -5,7 +5,18 @@ import os import typing from pathlib import Path -from typing import Any, Dict, List, NamedTuple, Optional, Set, Text, Tuple, Union +from typing import ( + Any, + Dict, + List, + NamedTuple, + NoReturn, + Optional, + Set, + Text, + Tuple, + Union, +) from ruamel.yaml import YAMLError @@ -562,7 +573,7 @@ def actions(self, action_endpoint) -> List[Optional[Action]]: self.action_for_name(name, action_endpoint) for name in self.action_names ] - def index_for_action(self, action_name: Text) -> Optional[int]: + def index_for_action(self, action_name: Text) -> Union[int, NoReturn]: """Look up which action index corresponds to this action name.""" try: @@ -570,7 +581,7 @@ def index_for_action(self, action_name: Text) -> Optional[int]: except ValueError: self._raise_action_not_found_exception(action_name) - def _raise_action_not_found_exception(self, action_name) -> typing.NoReturn: + def _raise_action_not_found_exception(self, action_name) -> NoReturn: action_names = "\n".join([f"\t - {a}" for a in self.action_names]) raise NameError( f"Cannot access action '{action_name}', " diff --git a/rasa/core/featurizers.py b/rasa/core/featurizers.py index 8cf8e401f609..8ac9f294a814 100644 --- a/rasa/core/featurizers.py +++ b/rasa/core/featurizers.py @@ -161,8 +161,8 @@ def __init__( self.slot_labels = [] self.bot_labels = [] - self.bot_vocab = None - self.user_vocab = None + self.bot_vocab: Optional[Dict[Text, int]] = None + self.user_vocab: Optional[Dict[Text, int]] = None @staticmethod def _create_label_token_dict(labels, split_symbol="_") -> Dict[Text, int]: @@ -250,6 +250,12 @@ def encode(self, state: Dict[Text, float]) -> np.ndarray: def create_encoded_all_actions(self, domain: Domain) -> np.ndarray: """Create matrix with all actions from domain encoded in rows as bag of words""" + if self.bot_vocab is None: + raise Exception( + "LabelTokenizerSingleStateFeaturizer " + "was not prepared before encoding." + ) + encoded_all_actions = np.zeros( (domain.num_actions, len(self.bot_vocab)), dtype=np.int32 ) diff --git a/rasa/core/policies/ted_policy.py b/rasa/core/policies/ted_policy.py index 963db1480258..7679c3c181e4 100644 --- a/rasa/core/policies/ted_policy.py +++ b/rasa/core/policies/ted_policy.py @@ -485,7 +485,7 @@ def __init__( # optimizer self.optimizer = tf.keras.optimizers.Adam() - self.all_labels_embed = None + self.all_labels_embed: Optional[tf.Tensor] = None label_batch = label_data.prepare_batch() self.tf_label_data = self.batch_to_model_data_format( diff --git a/rasa/nlu/classifiers/diet_classifier.py b/rasa/nlu/classifiers/diet_classifier.py index c7d66d764ce3..a35e201d9184 100644 --- a/rasa/nlu/classifiers/diet_classifier.py +++ b/rasa/nlu/classifiers/diet_classifier.py @@ -1186,7 +1186,9 @@ def __init__( self._create_metrics() self._update_metrics_to_log() - self.all_labels_embed = None # needed for efficient prediction + self.all_labels_embed: Optional[ + tf.Tensor + ] = None # needed for efficient prediction @staticmethod def _ordered_tag_specs( diff --git a/setup.cfg b/setup.cfg index 81ebaaab50fc..7829f8f193c4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -107,9 +107,6 @@ ignore_errors = True [mypy-rasa.core.channels.mattermost] ignore_errors = True -[mypy-rasa.core.channels.rasa_chat] -ignore_errors = True - [mypy-rasa.core.channels.rest] ignore_errors = True diff --git a/stubs/sanic/__init__.pyi b/stubs/sanic/__init__.pyi new file mode 100644 index 000000000000..d5861c3d9ec9 --- /dev/null +++ b/stubs/sanic/__init__.pyi @@ -0,0 +1,5 @@ +from sanic.__version__ import __version__ +from sanic.app import Sanic +from sanic.blueprints import Blueprint + +__all__ = ["Sanic", "Blueprint", "__version__"] diff --git a/stubs/sanic/exceptions.pyi b/stubs/sanic/exceptions.pyi new file mode 100644 index 000000000000..e7f461144d50 --- /dev/null +++ b/stubs/sanic/exceptions.pyi @@ -0,0 +1,3 @@ +from typing import NoReturn, Optional, Text + +def abort(status_code: int, message: Optional[Text] = None) -> NoReturn: ... From 2d5d621273a14e53f03b4639b1f69c6016806125 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Wed, 2 Sep 2020 14:30:34 +0200 Subject: [PATCH 07/42] clear errors in rasa.nlu.model --- rasa/core/agent.py | 6 ++---- rasa/nlu/components.py | 2 +- rasa/nlu/model.py | 21 ++++++++++----------- setup.cfg | 3 --- tests/nlu/test_components.py | 2 +- 5 files changed, 14 insertions(+), 20 deletions(-) diff --git a/rasa/core/agent.py b/rasa/core/agent.py index c41a530ef743..936e0f7f38ef 100644 --- a/rasa/core/agent.py +++ b/rasa/core/agent.py @@ -500,11 +500,9 @@ async def handle_message( "not supported anymore. Rather use `agent.handle_text(...)`." ) - def noop(_: Any) -> None: - logger.info("Ignoring message as there is no agent to handle it.") - if not self.is_ready(): - return noop(message) + logger.info("Ignoring message as there is no agent to handle it.") + return None processor = self.create_processor(message_preprocessor) diff --git a/rasa/nlu/components.py b/rasa/nlu/components.py index da438ec0770f..c2a008f37385 100644 --- a/rasa/nlu/components.py +++ b/rasa/nlu/components.py @@ -752,7 +752,7 @@ def create_component( try: component, cache_key = self.__get_cached_component( - component_config, Metadata(cfg.as_dict(), None) + component_config, Metadata(cfg.as_dict()) ) if component is None: component = registry.create_component_by_config(component_config, cfg) diff --git a/rasa/nlu/model.py b/rasa/nlu/model.py index a5957ec514f4..73dbae0d7f1f 100644 --- a/rasa/nlu/model.py +++ b/rasa/nlu/model.py @@ -63,17 +63,16 @@ def load(model_dir: Text): try: metadata_file = os.path.join(model_dir, "metadata.json") data = rasa.utils.io.read_json_file(metadata_file) - return Metadata(data, model_dir) + return Metadata(data) except Exception as e: abspath = os.path.abspath(os.path.join(model_dir, "metadata.json")) raise InvalidModelError( f"Failed to load model metadata from '{abspath}'. {e}" ) - def __init__(self, metadata: Dict[Text, Any], model_dir: Optional[Text]): + def __init__(self, metadata: Dict[Text, Any]): self.metadata = metadata - self.model_dir = model_dir def get(self, property_name: Text, default: Any = None) -> Any: return self.metadata.get(property_name, default) @@ -189,10 +188,8 @@ def train(self, data: TrainingData, **kwargs: Any) -> "Interpreter": for i, component in enumerate(self.pipeline): logger.info(f"Starting to train component {component.name}") component.prepare_partial_processing(self.pipeline[:i], context) - updates = component.train(working_data, self.config, **context) + component.train(working_data, self.config, **context) logger.info("Finished training component.") - if updates: - context.update(updates) return Interpreter(self.pipeline, context) @@ -237,7 +234,7 @@ def persist( metadata["pipeline"].append(component_meta) - Metadata(metadata, dir_name).persist(dir_name) + Metadata(metadata).persist(dir_name) if persistor is not None: persistor.persist(dir_name, model_name) @@ -299,17 +296,18 @@ def load( model_metadata = Metadata.load(model_dir) Interpreter.ensure_model_compatibility(model_metadata) - return Interpreter.create(model_metadata, component_builder, skip_validation) + return Interpreter.create(model_dir, model_metadata, component_builder, skip_validation) @staticmethod def create( + model_dir: Text, model_metadata: Metadata, component_builder: Optional[ComponentBuilder] = None, skip_validation: bool = False, ) -> "Interpreter": """Load stored model and components defined by the provided metadata.""" - context = {} + context: Dict[Text, Any] = {} if component_builder is None: # If no builder is passed, every interpreter creation will result @@ -326,7 +324,7 @@ def create( for i in range(model_metadata.number_of_components): component_meta = model_metadata.for_component(i) component = component_builder.load_component( - component_meta, model_metadata.model_dir, model_metadata, **context + component_meta, model_dir, model_metadata, **context ) try: updates = component.provide_context() @@ -371,7 +369,8 @@ def parse( output["text"] = "" return output - message = Message(text, self.default_output_attributes(), time=time) + timestamp = str(int(time.timestamp())) if time else None + message = Message(text, self.default_output_attributes(), time=timestamp) for component in self.pipeline: component.process(message, **self.context) diff --git a/setup.cfg b/setup.cfg index 7829f8f193c4..7f62dfcecff9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -329,9 +329,6 @@ ignore_errors = True [mypy-rasa.nlu.featurizers.sparse_featurizer.regex_featurizer] ignore_errors = True -[mypy-rasa.nlu.model] -ignore_errors = True - [mypy-rasa.nlu.persistor] ignore_errors = True diff --git a/tests/nlu/test_components.py b/tests/nlu/test_components.py index ec90bcfee2d8..b67107e3e058 100644 --- a/tests/nlu/test_components.py +++ b/tests/nlu/test_components.py @@ -78,7 +78,7 @@ def test_create_component_exception_messages( def test_builder_load_unknown(component_builder): with pytest.raises(Exception) as excinfo: component_meta = {"name": "my_made_up_componment"} - component_builder.load_component(component_meta, "", Metadata({}, None)) + component_builder.load_component(component_meta, "", Metadata({})) assert "Cannot find class" in str(excinfo.value) From 31f728c8f92e813dacf13a60f5f10ecbb959f4cf Mon Sep 17 00:00:00 2001 From: m-vdb Date: Mon, 12 Oct 2020 18:26:27 +0200 Subject: [PATCH 08/42] fix import issues in rasa.cli.train --- rasa/cli/train.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/rasa/cli/train.py b/rasa/cli/train.py index a713fee5d37f..c17457366923 100644 --- a/rasa/cli/train.py +++ b/rasa/cli/train.py @@ -7,6 +7,7 @@ import rasa.cli.arguments.train as train_arguments import rasa.cli.utils +import rasa.train from rasa.shared.utils.cli import print_error from rasa.shared.constants import ( CONFIG_MANDATORY_KEYS_CORE, @@ -94,7 +95,6 @@ def train(args: argparse.Namespace) -> Optional[Text]: def train_core( args: argparse.Namespace, train_path: Optional[Text] = None ) -> Optional[Text]: - from rasa.train import train_core output = train_path or args.out @@ -114,7 +114,7 @@ def train_core( config = _get_valid_config(args.config, CONFIG_MANDATORY_KEYS_CORE) - return train_core( + return rasa.train.train_core( domain=args.domain, config=config, stories=story_file, @@ -134,7 +134,6 @@ def train_core( def train_nlu( args: argparse.Namespace, train_path: Optional[Text] = None ) -> Optional[Text]: - from rasa.train import train_nlu output = train_path or args.out @@ -148,7 +147,7 @@ def train_nlu( args.domain, "domain", DEFAULT_DOMAIN_PATH, none_is_valid=True ) - return train_nlu( + return rasa.train.train_nlu( config=config, nlu_data=nlu_data, output=output, From 54b8b111201c478045714ee136fc57f23211446f Mon Sep 17 00:00:00 2001 From: m-vdb Date: Mon, 12 Oct 2020 18:26:35 +0200 Subject: [PATCH 09/42] disable some error codes --- Makefile | 3 --- 1 file changed, 3 deletions(-) diff --git a/Makefile b/Makefile index d0b941d9b346..1e66a1fb3474 100644 --- a/Makefile +++ b/Makefile @@ -77,12 +77,9 @@ types: --disable-error-code index \ --disable-error-code misc \ --disable-error-code return \ - --disable-error-code call-arg \ --disable-error-code type-var \ - --disable-error-code list-item \ --disable-error-code has-type \ --disable-error-code valid-type \ - --disable-error-code dict-item \ --disable-error-code no-redef \ --disable-error-code func-returns-value From c55264f825e7a2720db405806bd495d5dffc7863 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Thu, 15 Oct 2020 10:37:22 +0200 Subject: [PATCH 10/42] update extract_requested_slot() docstring --- rasa/core/actions/forms.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/rasa/core/actions/forms.py b/rasa/core/actions/forms.py index 81be9ad34397..9c9dd79d651a 100644 --- a/rasa/core/actions/forms.py +++ b/rasa/core/actions/forms.py @@ -317,8 +317,19 @@ def extract_other_slots( def extract_requested_slot( self, tracker: "DialogueStateTracker", domain: Domain, slot_to_fill: Text, ) -> Dict[Text, Any]: - """Extract the value of requested slot from a user input + """ + Extract the value of requested slot from a user input else return `None`. + + Args: + tracker: a DialogueStateTracker instance + domain: the current domain + slot_to_fill: the name of the slot to fill + + Returns: + a dictionary with one key being the name of the slot to fill + and its value being the slot value, or an empty dictionary + if no slot value was found. """ logger.debug(f"Trying to extract requested slot '{slot_to_fill}' ...") From 04b3ea45def5fb4aa66a23610f4f566ddee2c98e Mon Sep 17 00:00:00 2001 From: m-vdb Date: Thu, 15 Oct 2020 10:39:16 +0200 Subject: [PATCH 11/42] more precise exception message --- rasa/core/channels/rasa_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/channels/rasa_chat.py b/rasa/core/channels/rasa_chat.py index 1d1b2652d6e0..534966ac447d 100644 --- a/rasa/core/channels/rasa_chat.py +++ b/rasa/core/channels/rasa_chat.py @@ -73,7 +73,7 @@ async def _fetch_public_key(self) -> None: def _decode_jwt(self, bearer_token: Text) -> Dict: if self.jwt_key is None: - raise TypeError("JWT public key is none.") + raise TypeError("JWT public key is `None`.") authorization_header_value = bearer_token.replace( constants.BEARER_TOKEN_PREFIX, "" From 39a4d79c0e38d5ef5edafa1113383b210daacafb Mon Sep 17 00:00:00 2001 From: m-vdb Date: Thu, 15 Oct 2020 10:50:44 +0200 Subject: [PATCH 12/42] use HTTPStatus constants in rasa.server --- rasa/server.py | 107 +++++++++++++++++++++++++++++++------------------ 1 file changed, 68 insertions(+), 39 deletions(-) diff --git a/rasa/server.py b/rasa/server.py index 92820f235356..4bd29744dbfe 100644 --- a/rasa/server.py +++ b/rasa/server.py @@ -7,6 +7,7 @@ import traceback import typing from functools import reduce, wraps +from http import HTTPStatus from inspect import isawaitable from pathlib import Path from typing import Any, Callable, List, Optional, Text, Union, Dict, cast @@ -122,7 +123,7 @@ def decorated(*args, **kwargs): else app.agent.is_ready() ): raise ErrorResponse( - 409, + HTTPStatus.CONFLICT, "Conflict", "No agent loaded. To continue processing, a " "model of a trained agent needs to be loaded.", @@ -190,7 +191,7 @@ async def decorated( result = await result return result raise ErrorResponse( - 403, + HTTPStatus.FORBIDDEN, "NotAuthorized", "User has insufficient permissions.", help_url=_docs( @@ -204,7 +205,7 @@ async def decorated( result = await result return result raise ErrorResponse( - 401, + HTTPStatus.UNAUTHORIZED, "NotAuthenticated", "User is not authenticated.", help_url=_docs( @@ -229,7 +230,7 @@ def event_verbosity_parameter( except KeyError: enum_values = ", ".join([e.name for e in EventVerbosity]) raise ErrorResponse( - 400, + HTTPStatus.BAD_REQUEST, "BadRequest", "Invalid parameter value for 'include_events'. " "Should be one of {}".format(enum_values), @@ -253,7 +254,7 @@ def _validate_tracker( ) -> None: if not tracker: raise ErrorResponse( - 409, + HTTPStatus.CONFLICT, "Conflict", f"Could not retrieve tracker with ID '{conversation_id}'. Most likely " f"because there is no domain set on the agent.", @@ -263,7 +264,7 @@ def _validate_tracker( def validate_request_body(request: Request, error_message: Text): """Check if `request` has a body.""" if not request.body: - raise ErrorResponse(400, "BadRequest", error_message) + raise ErrorResponse(HTTPStatus.BAD_REQUEST, "BadRequest", error_message) async def authenticate(request: Request): @@ -328,7 +329,7 @@ def _create_emulator(mode: Optional[Text]) -> NoEmulator: return DialogflowEmulator() else: raise ErrorResponse( - 400, + HTTPStatus.BAD_REQUEST, "BadRequest", "Invalid parameter value for 'emulation_mode'. " "Should be one of 'WIT', 'LUIS', 'DIALOGFLOW'.", @@ -370,12 +371,14 @@ async def _load_agent( except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( - 500, "LoadingError", f"An unexpected error occurred. Error: {e}" + HTTPStatus.INTERNAL_SERVER_ERROR, + "LoadingError", + f"An unexpected error occurred. Error: {e}", ) if not loaded_agent: raise ErrorResponse( - 400, + HTTPStatus.BAD_REQUEST, "BadRequest", f"Agent with name '{model_path}' could not be loaded.", {"parameter": "model", "in": "query"}, @@ -495,7 +498,9 @@ async def retrieve_tracker(request: Request, conversation_id: Text): except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( - 500, "ConversationError", f"An unexpected error occurred. Error: {e}" + HTTPStatus.INTERNAL_SERVER_ERROR, + "ConversationError", + f"An unexpected error occurred. Error: {e}", ) @app.post("/conversations//tracker/events") @@ -534,7 +539,9 @@ async def append_events(request: Request, conversation_id: Text): except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( - 500, "ConversationError", f"An unexpected error occurred. Error: {e}" + HTTPStatus.INTERNAL_SERVER_ERROR, + "ConversationError", + f"An unexpected error occurred. Error: {e}", ) def _get_events_from_request_body(request: Request) -> List[Event]: @@ -552,7 +559,7 @@ def _get_events_from_request_body(request: Request) -> List[Event]: f"Request JSON: {request.json}" ) raise ErrorResponse( - 400, + HTTPStatus.BAD_REQUEST, "BadRequest", "Couldn't extract a proper event from the request body.", {"parameter": "", "in": "body"}, @@ -586,7 +593,9 @@ async def replace_events(request: Request, conversation_id: Text): except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( - 500, "ConversationError", f"An unexpected error occurred. Error: {e}" + HTTPStatus.INTERNAL_SERVER_ERROR, + "ConversationError", + f"An unexpected error occurred. Error: {e}", ) @app.get("/conversations//story") @@ -610,7 +619,9 @@ async def retrieve_story(request: Request, conversation_id: Text): except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( - 500, "ConversationError", f"An unexpected error occurred. Error: {e}" + HTTPStatus.INTERNAL_SERVER_ERROR, + "ConversationError", + f"An unexpected error occurred. Error: {e}", ) @app.post("/conversations//execute") @@ -623,7 +634,7 @@ async def execute_action(request: Request, conversation_id: Text): if not action_to_execute: raise ErrorResponse( - 400, + HTTPStatus.BAD_REQUEST, "BadRequest", "Name of the action not provided in request body.", {"parameter": "name", "in": "body"}, @@ -650,7 +661,9 @@ async def execute_action(request: Request, conversation_id: Text): except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( - 500, "ConversationError", f"An unexpected error occurred. Error: {e}" + HTTPStatus.INTERNAL_SERVER_ERROR, + "ConversationError", + f"An unexpected error occurred. Error: {e}", ) tracker = await get_tracker(app.agent.create_processor(), conversation_id) @@ -674,7 +687,7 @@ async def trigger_intent(request: Request, conversation_id: Text) -> HTTPRespons if not intent_to_trigger: raise ErrorResponse( - 400, + HTTPStatus.BAD_REQUEST, "BadRequest", "Name of the intent not provided in request body.", {"parameter": "name", "in": "body"}, @@ -690,7 +703,7 @@ async def trigger_intent(request: Request, conversation_id: Text) -> HTTPRespons output_channel = _get_output_channel(request, tracker) if intent_to_trigger not in app.agent.domain.intents: raise ErrorResponse( - 404, + HTTPStatus.NOT_FOUND, "NotFound", f"The intent {trigger_intent} does not exist in the domain.", ) @@ -705,7 +718,9 @@ async def trigger_intent(request: Request, conversation_id: Text) -> HTTPRespons except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( - 500, "ConversationError", f"An unexpected error occurred. Error: {e}" + HTTPStatus.INTERNAL_SERVER_ERROR, + "ConversationError", + f"An unexpected error occurred. Error: {e}", ) state = tracker.current_state(verbosity) @@ -731,7 +746,9 @@ async def predict(request: Request, conversation_id: Text) -> HTTPResponse: except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( - 500, "ConversationError", f"An unexpected error occurred. Error: {e}" + HTTPStatus.INTERNAL_SERVER_ERROR, + "ConversationError", + f"An unexpected error occurred. Error: {e}", ) @app.post("/conversations//messages") @@ -755,7 +772,7 @@ async def add_message(request: Request, conversation_id: Text): # TODO: implement for agent / bot if sender != "user": raise ErrorResponse( - 400, + HTTPStatus.BAD_REQUEST, "BadRequest", "Currently, only user messages can be passed to this endpoint. " "Messages of sender '{}' cannot be handled.".format(sender), @@ -771,7 +788,9 @@ async def add_message(request: Request, conversation_id: Text): except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( - 500, "ConversationError", f"An unexpected error occurred. Error: {e}" + HTTPStatus.INTERNAL_SERVER_ERROR, + "ConversationError", + f"An unexpected error occurred. Error: {e}", ) @app.post("/model/train") @@ -811,7 +830,7 @@ async def train(request: Request) -> HTTPResponse: ) else: raise ErrorResponse( - 500, + HTTPStatus.INTERNAL_SERVER_ERROR, "TrainingError", "Ran training, but it finished without a trained model.", ) @@ -819,14 +838,14 @@ async def train(request: Request) -> HTTPResponse: raise e except InvalidDomain as e: raise ErrorResponse( - 400, + HTTPStatus.BAD_REQUEST, "InvalidDomainError", f"Provided domain file is invalid. Error: {e}", ) except Exception as e: logger.error(traceback.format_exc()) raise ErrorResponse( - 500, + HTTPStatus.INTERNAL_SERVER_ERROR, "TrainingError", f"An unexpected error occurred during training. Error: {e}", ) @@ -857,7 +876,7 @@ async def evaluate_stories(request: Request) -> HTTPResponse: except Exception as e: logger.error(traceback.format_exc()) raise ErrorResponse( - 500, + HTTPStatus.INTERNAL_SERVER_ERROR, "TestingError", f"An unexpected error occurred during evaluation. Error: {e}", ) @@ -890,14 +909,18 @@ async def evaluate_intents(request: Request) -> HTTPResponse: if not eval_agent.model_directory or not os.path.exists( eval_agent.model_directory ): - raise ErrorResponse(409, "Conflict", "Loaded model file not found.") + raise ErrorResponse( + HTTPStatus.CONFLICT, "Conflict", "Loaded model file not found." + ) model_directory = eval_agent.model_directory _, nlu_model = model.get_model_subdirectories(model_directory) if nlu_model is None: raise ErrorResponse( - 500, "TestingError", "Missing NLU model directory.", + HTTPStatus.INTERNAL_SERVER_ERROR, + "TestingError", + "Missing NLU model directory.", ) try: @@ -906,7 +929,7 @@ async def evaluate_intents(request: Request) -> HTTPResponse: except Exception as e: logger.error(traceback.format_exc()) raise ErrorResponse( - 500, + HTTPStatus.INTERNAL_SERVER_ERROR, "TestingError", f"An unexpected error occurred during evaluation. Error: {e}", ) @@ -931,7 +954,7 @@ async def tracker_predict(request: Request) -> HTTPResponse: except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( - 400, + HTTPStatus.BAD_REQUEST, "BadRequest", f"Supplied events are not valid. {e}", {"parameter": "", "in": "body"}, @@ -958,7 +981,9 @@ async def tracker_predict(request: Request) -> HTTPResponse: except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( - 500, "PredictionError", f"An unexpected error occurred. Error: {e}" + HTTPStatus.INTERNAL_SERVER_ERROR, + "PredictionError", + f"An unexpected error occurred. Error: {e}", ) @app.post("/model/parse") @@ -982,7 +1007,9 @@ async def parse(request: Request) -> HTTPResponse: except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( - 400, "ParsingError", f"An unexpected error occurred. Error: {e}" + HTTPStatus.BAD_REQUEST, + "ParsingError", + f"An unexpected error occurred. Error: {e}", ) response_data = emulator.normalise_response_json(parsed_data) @@ -991,7 +1018,9 @@ async def parse(request: Request) -> HTTPResponse: except Exception as e: logger.debug(traceback.format_exc()) raise ErrorResponse( - 500, "ParsingError", f"An unexpected error occurred. Error: {e}" + HTTPStatus.INTERNAL_SERVER_ERROR, + "ParsingError", + f"An unexpected error occurred. Error: {e}", ) @app.put("/model") @@ -1009,7 +1038,7 @@ async def load_model(request: Request) -> HTTPResponse: except TypeError as e: logger.debug(traceback.format_exc()) raise ErrorResponse( - 400, + HTTPStatus.BAD_REQUEST, "BadRequest", f"Supplied 'model_server' is not valid. Error: {e}", {"parameter": "model_server", "in": "body"}, @@ -1049,7 +1078,7 @@ async def get_domain(request: Request) -> HTTPResponse: ) else: raise ErrorResponse( - 406, + HTTPStatus.NOT_ACCEPTABLE, "NotAcceptable", f"Invalid Accept header. Domain can be " f"provided as " @@ -1167,7 +1196,7 @@ def _training_payload_from_json(request: Request) -> Dict[Text, Union[Text, bool def _validate_json_training_payload(rjs: Dict): if "config" not in rjs: raise ErrorResponse( - 400, + HTTPStatus.BAD_REQUEST, "BadRequest", "The training request is missing the required key `config`.", {"parameter": "config", "in": "body"}, @@ -1175,7 +1204,7 @@ def _validate_json_training_payload(rjs: Dict): if "nlu" not in rjs and "stories" not in rjs: raise ErrorResponse( - 400, + HTTPStatus.BAD_REQUEST, "BadRequest", "To train a Rasa model you need to specify at least one type of " "training data. Add `nlu` and/or `stories` to the request.", @@ -1184,7 +1213,7 @@ def _validate_json_training_payload(rjs: Dict): if "stories" in rjs and "domain" not in rjs: raise ErrorResponse( - 400, + HTTPStatus.BAD_REQUEST, "BadRequest", "To train a Rasa model with story training data, you also need to " "specify the `domain`.", @@ -1235,7 +1264,7 @@ def _validate_yaml_training_payload(yaml_text: Text) -> None: RasaYAMLReader().validate(yaml_text) except Exception as e: raise ErrorResponse( - 400, + HTTPStatus.BAD_REQUEST, "BadRequest", f"The request body does not contain valid YAML. Error: {e}", help_url=DOCS_URL_TRAINING_DATA, From 33834e98a15dea0ebe74800e694637d866c3ba3c Mon Sep 17 00:00:00 2001 From: m-vdb Date: Thu, 15 Oct 2020 10:54:28 +0200 Subject: [PATCH 13/42] /model/test/intents returns the right status code in case of a conflict --- changelog/6511.misc.md | 3 +++ rasa/server.py | 4 +--- 2 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 changelog/6511.misc.md diff --git a/changelog/6511.misc.md b/changelog/6511.misc.md new file mode 100644 index 000000000000..008ff8814709 --- /dev/null +++ b/changelog/6511.misc.md @@ -0,0 +1,3 @@ +API endpoint `POST /model/test/intents` now returns HTTP 409 status +code in case it cannot find the NLU model directory, instead of an +HTTP 500 status. diff --git a/rasa/server.py b/rasa/server.py index 4bd29744dbfe..d84eb39cea78 100644 --- a/rasa/server.py +++ b/rasa/server.py @@ -918,9 +918,7 @@ async def evaluate_intents(request: Request) -> HTTPResponse: if nlu_model is None: raise ErrorResponse( - HTTPStatus.INTERNAL_SERVER_ERROR, - "TestingError", - "Missing NLU model directory.", + HTTPStatus.CONFLICT, "TestingError", "Missing NLU model directory.", ) try: From f8a817aa23d6d37df9ffcaa58e9d07402379121a Mon Sep 17 00:00:00 2001 From: m-vdb Date: Fri, 16 Oct 2020 15:16:02 +0200 Subject: [PATCH 14/42] rename methods to avoid obfuscation --- rasa/cli/interactive.py | 6 +++++- rasa/cli/train.py | 28 ++++++++++++---------------- tests/cli/test_rasa_interactive.py | 10 +++++----- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/rasa/cli/interactive.py b/rasa/cli/interactive.py index 30b70e886b9c..d2dfe962fa2b 100644 --- a/rasa/cli/interactive.py +++ b/rasa/cli/interactive.py @@ -65,7 +65,11 @@ def interactive(args: argparse.Namespace) -> None: "data or a model containing core data." ) - zipped_model = train.train_core(args) if args.core_only else train.train(args) + zipped_model = ( + train.run_core_training(args) + if args.core_only + else train.run_training(args) + ) if not zipped_model: rasa.shared.utils.cli.print_error_and_exit( "Could not train an initial model. Either pass paths " diff --git a/rasa/cli/train.py b/rasa/cli/train.py index 4e08e3a784cb..9fea96b67425 100644 --- a/rasa/cli/train.py +++ b/rasa/cli/train.py @@ -7,7 +7,8 @@ import rasa.cli.arguments.train as train_arguments import rasa.cli.utils -import rasa.train +import rasa.utils.common +from rasa.core.train import do_compare_training from rasa.shared.utils.cli import print_error from rasa.shared.constants import ( CONFIG_MANDATORY_KEYS_CORE, @@ -17,8 +18,7 @@ DEFAULT_DOMAIN_PATH, DEFAULT_DATA_PATH, ) - -import rasa.utils.common +from rasa.train import train, train_core, train_nlu def add_subparser( @@ -47,7 +47,7 @@ def add_subparser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, help="Trains a Rasa Core model using your stories.", ) - train_core_parser.set_defaults(func=train_core) + train_core_parser.set_defaults(func=run_core_training) train_nlu_parser = train_subparsers.add_parser( "nlu", @@ -55,17 +55,15 @@ def add_subparser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, help="Trains a Rasa NLU model using your NLU data.", ) - train_nlu_parser.set_defaults(func=train_nlu) + train_nlu_parser.set_defaults(func=run_nlu_training) - train_parser.set_defaults(func=train) + train_parser.set_defaults(func=run_training) train_arguments.set_train_core_arguments(train_core_parser) train_arguments.set_train_nlu_arguments(train_nlu_parser) -def train(args: argparse.Namespace) -> Optional[Text]: - import rasa - +def run_training(args: argparse.Namespace) -> Optional[Text]: domain = rasa.cli.utils.get_validated_path( args.domain, "domain", DEFAULT_DOMAIN_PATH, none_is_valid=True ) @@ -79,7 +77,7 @@ def train(args: argparse.Namespace) -> Optional[Text]: for f in args.data ] - return rasa.train( + return train( domain=domain, config=config, training_files=training_files, @@ -92,7 +90,7 @@ def train(args: argparse.Namespace) -> Optional[Text]: ) -def train_core( +def run_core_training( args: argparse.Namespace, train_path: Optional[Text] = None ) -> Optional[Text]: @@ -114,7 +112,7 @@ def train_core( config = _get_valid_config(args.config, CONFIG_MANDATORY_KEYS_CORE) - return rasa.train.train_core( + return train_core( domain=args.domain, config=config, stories=story_file, @@ -124,14 +122,12 @@ def train_core( additional_arguments=additional_arguments, ) else: - from rasa.core.train import do_compare_training - rasa.utils.common.run_in_loop( do_compare_training(args, story_file, additional_arguments) ) -def train_nlu( +def run_nlu_training( args: argparse.Namespace, train_path: Optional[Text] = None ) -> Optional[Text]: @@ -147,7 +143,7 @@ def train_nlu( args.domain, "domain", DEFAULT_DOMAIN_PATH, none_is_valid=True ) - return rasa.train.train_nlu( + return train_nlu( config=config, nlu_data=nlu_data, output=output, diff --git a/tests/cli/test_rasa_interactive.py b/tests/cli/test_rasa_interactive.py index f9b09bc7ed76..3f93b0d455d4 100644 --- a/tests/cli/test_rasa_interactive.py +++ b/tests/cli/test_rasa_interactive.py @@ -62,7 +62,7 @@ def test_pass_arguments_to_rasa_train( # Mock actual training mock = Mock() - monkeypatch.setattr(train, "rasa_train", mock.method) + monkeypatch.setattr(train, "train", mock.method) # If the `Namespace` object does not have all required fields this will throw train.train(args) @@ -91,7 +91,7 @@ def test_train_called_when_no_model_passed( # Mock actual training and interactive learning methods mock = Mock() - monkeypatch.setattr(train, "train", mock.train_model) + monkeypatch.setattr(train, "run_training", mock.train_model) monkeypatch.setattr( interactive, "perform_interactive_learning", mock.perform_interactive_learning ) @@ -123,13 +123,13 @@ def test_train_core_called_when_no_model_passed_and_core( # Mock actual training and interactive learning methods mock = Mock() - monkeypatch.setattr(train, "train_core", mock.train_core) + monkeypatch.setattr(train, "run_core_training", mock.run_core_training) monkeypatch.setattr( interactive, "perform_interactive_learning", mock.perform_interactive_learning ) interactive.interactive(args) - mock.train_core.assert_called_once() + mock.run_core_training.assert_called_once() def test_no_interactive_without_core_data( @@ -145,7 +145,7 @@ def test_no_interactive_without_core_data( interactive._set_not_required_args(args) mock = Mock() - monkeypatch.setattr(train, "train", mock.train_model) + monkeypatch.setattr(train, "run_training", mock.train_model) monkeypatch.setattr( interactive, "perform_interactive_learning", mock.perform_interactive_learning ) From 5eae1b492ca33dd0bda91d66f95f95ca68812f3f Mon Sep 17 00:00:00 2001 From: m-vdb Date: Fri, 16 Oct 2020 15:26:49 +0200 Subject: [PATCH 15/42] add type hint for tracker.latest_message --- rasa/shared/core/trackers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/shared/core/trackers.py b/rasa/shared/core/trackers.py index 3bffc2fc3a2d..fdb459dd1c82 100644 --- a/rasa/shared/core/trackers.py +++ b/rasa/shared/core/trackers.py @@ -180,7 +180,7 @@ def __init__( self.followup_action = ACTION_LISTEN_NAME self.latest_action = None # Stores the most recent message sent by the user - self.latest_message = None + self.latest_message: Optional[Event] = None self.latest_bot_utterance = None self._reset() self.active_loop: Dict[Text, Union[Text, bool, Dict, None]] = {} From 2c676685ea772abbc9b23710571810b625c1919e Mon Sep 17 00:00:00 2001 From: m-vdb Date: Mon, 19 Oct 2020 17:09:53 +0200 Subject: [PATCH 16/42] add precise type definition for DialogueStateTracker.active_loop --- rasa/shared/core/trackers.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/rasa/shared/core/trackers.py b/rasa/shared/core/trackers.py index eddc74ff8d8d..4da885266b3c 100644 --- a/rasa/shared/core/trackers.py +++ b/rasa/shared/core/trackers.py @@ -17,6 +17,7 @@ Union, FrozenSet, Tuple, + TypedDict, TYPE_CHECKING, ) @@ -69,6 +70,18 @@ # same as State but with Dict[...] substituted with FrozenSet[Tuple[...]] FrozenState = FrozenSet[Tuple[Text, FrozenSet[Tuple[Text, Tuple[Union[float, Text]]]]]] +# precise type definition for `DialogueStateTracker.active_loop` +TrackerActiveLoop = TypedDict( + "TrackerActiveLoop", + { + LOOP_NAME: Text, + LOOP_INTERRUPTED: bool, + LOOP_REJECTED: bool, + TRIGGER_MESSAGE: Dict, + }, + total=False, +) + class EventVerbosity(Enum): """Filter on which events to include in tracker dumps.""" @@ -185,7 +198,7 @@ def __init__( self.latest_message: Optional[Event] = None self.latest_bot_utterance = None self._reset() - self.active_loop: Dict[Text, Union[Text, bool, Dict, None]] = {} + self.active_loop: TrackerActiveLoop = {} ### # Public tracker interface From 1b9e44c1215769b9e7552f7382d347ed04f7f13c Mon Sep 17 00:00:00 2001 From: m-vdb Date: Wed, 11 Nov 2020 16:17:43 +0100 Subject: [PATCH 17/42] add FIXMEs for type issues for now --- rasa/core/featurizers/single_state_featurizer.py | 7 +++++-- rasa/core/test.py | 5 +++-- rasa/model.py | 2 +- rasa/shared/core/events.py | 6 ++++-- rasa/shared/core/training_data/structures.py | 4 +++- 5 files changed, 16 insertions(+), 8 deletions(-) diff --git a/rasa/core/featurizers/single_state_featurizer.py b/rasa/core/featurizers/single_state_featurizer.py index 3b86d4962cd3..3a7ef23b2691 100644 --- a/rasa/core/featurizers/single_state_featurizer.py +++ b/rasa/core/featurizers/single_state_featurizer.py @@ -59,12 +59,15 @@ def convert_to_dict(feature_states: List[Text]) -> Dict[Text, int]: def _state_features_for_attribute( self, sub_state: SubState, attribute: Text ) -> Dict[Text, int]: + # FIXME: the code below is not type-safe, but fixing it + # would require more refactoring, for instance using + # data classes in our states if attribute in {INTENT, ACTION_NAME}: - return {sub_state[attribute]: 1} + return {sub_state[attribute]: 1} # type: ignore[dict-item] elif attribute == ENTITIES: return {entity: 1 for entity in sub_state.get(ENTITIES, [])} elif attribute == ACTIVE_LOOP: - return {sub_state["name"]: 1} + return {sub_state["name"]: 1} # type: ignore[dict-item] elif attribute == SLOTS: return { f"{slot_name}_{i}": value diff --git a/rasa/core/test.py b/rasa/core/test.py index 02a09ae78dfc..4464b18edbd2 100644 --- a/rasa/core/test.py +++ b/rasa/core/test.py @@ -3,7 +3,7 @@ import warnings import typing from collections import defaultdict, namedtuple -from typing import Any, Dict, List, Optional, Text, Tuple +from typing import Any, Dict, List, Optional, Text, Tuple, cast from rasa import telemetry from rasa.core.policies.policy import PolicyPrediction @@ -422,7 +422,8 @@ def _collect_action_executed_predictions( action_executed_eval_store = EvaluationStore() - gold = event.action_name or event.action_text + # FIXME: mypy doesn't pick up typing guard in `ActionExecuted.__init__` + gold = cast(Text, event.action_name or event.action_text) if circuit_breaker_tripped: prediction = PolicyPrediction([], policy_name=None) diff --git a/rasa/model.py b/rasa/model.py index 69cdd6129dfc..d5b6a03e76ca 100644 --- a/rasa/model.py +++ b/rasa/model.py @@ -36,7 +36,7 @@ # Type alias for the fingerprint -Fingerprint = Dict[Text, Union[Text, List[Text], int, float]] +Fingerprint = Dict[Text, Union[Optional[Text], List[Text], int, float]] FINGERPRINT_FILE_PATH = "fingerprint.json" diff --git a/rasa/shared/core/events.py b/rasa/shared/core/events.py index 076261e0edc2..42718b007ee9 100644 --- a/rasa/shared/core/events.py +++ b/rasa/shared/core/events.py @@ -7,7 +7,7 @@ import uuid from dateutil import parser from datetime import datetime -from typing import List, Dict, Text, Any, Type, Optional, TYPE_CHECKING, Iterable +from typing import List, Dict, Text, Any, Type, Optional, TYPE_CHECKING, Iterable, cast import rasa.shared.utils.common from typing import Union @@ -1210,7 +1210,9 @@ def as_sub_state(self) -> Dict[Text, Text]: if self.action_name: return {ACTION_NAME: self.action_name} else: - return {ACTION_TEXT: self.action_text} + # FIXME: we should define the type better here, and require either + # `action_name` or `action_text` + return {ACTION_TEXT: cast(Text, self.action_text)} def apply_to(self, tracker: "DialogueStateTracker") -> None: tracker.set_latest_action(self.as_sub_state()) diff --git a/rasa/shared/core/training_data/structures.py b/rasa/shared/core/training_data/structures.py index e5cfdb4f39d1..172c708a1f9f 100644 --- a/rasa/shared/core/training_data/structures.py +++ b/rasa/shared/core/training_data/structures.py @@ -149,7 +149,9 @@ def _or_string(story_step_element: List[Event], e2e: bool) -> Text: ) result = " OR ".join( - [element.as_story_string(e2e) for element in story_step_element] + # FIXME: this breaks below because not + # all `as_story_string()` take a `e2e` argument. + [element.as_story_string(e2e) for element in story_step_element] # type: ignore[call-arg] ) return f"* {result}\n" From cfe194f0bcef153605634c29c15fc06129317690 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Wed, 11 Nov 2020 16:35:38 +0100 Subject: [PATCH 18/42] fix black formatting --- rasa/shared/core/training_data/structures.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/rasa/shared/core/training_data/structures.py b/rasa/shared/core/training_data/structures.py index 172c708a1f9f..5405569d3c8d 100644 --- a/rasa/shared/core/training_data/structures.py +++ b/rasa/shared/core/training_data/structures.py @@ -150,8 +150,11 @@ def _or_string(story_step_element: List[Event], e2e: bool) -> Text: result = " OR ".join( # FIXME: this breaks below because not - # all `as_story_string()` take a `e2e` argument. - [element.as_story_string(e2e) for element in story_step_element] # type: ignore[call-arg] + # all `as_story_string()` take a `e2e` argument. + [ + element.as_story_string(e2e) # type: ignore[call-arg] + for element in story_step_element + ] ) return f"* {result}\n" From ed9f7d38644ae15c1b3592db99bf3421583a275c Mon Sep 17 00:00:00 2001 From: m-vdb Date: Wed, 11 Nov 2020 16:37:14 +0100 Subject: [PATCH 19/42] fix TypedDict import --- rasa/shared/core/trackers.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/rasa/shared/core/trackers.py b/rasa/shared/core/trackers.py index d130097cb849..126f99fca489 100644 --- a/rasa/shared/core/trackers.py +++ b/rasa/shared/core/trackers.py @@ -63,27 +63,29 @@ from rasa.shared.core.slots import Slot if TYPE_CHECKING: + from typing_extension import TypedDict + from rasa.shared.core.training_data.structures import Story from rasa.shared.core.training_data.story_writer.story_writer import StoryWriter + # precise type definition for `DialogueStateTracker.active_loop` + TrackerActiveLoop = TypedDict( + "TrackerActiveLoop", + { + LOOP_NAME: Text, + LOOP_INTERRUPTED: bool, + LOOP_REJECTED: bool, + TRIGGER_MESSAGE: Dict, + }, + total=False, + ) + logger = logging.getLogger(__name__) # same as State but with Dict[...] substituted with FrozenSet[Tuple[...]] FrozenState = FrozenSet[Tuple[Text, FrozenSet[Tuple[Text, Tuple[Union[float, Text]]]]]] -# precise type definition for `DialogueStateTracker.active_loop` -TrackerActiveLoop = TypedDict( - "TrackerActiveLoop", - { - LOOP_NAME: Text, - LOOP_INTERRUPTED: bool, - LOOP_REJECTED: bool, - TRIGGER_MESSAGE: Dict, - }, - total=False, -) - class EventVerbosity(Enum): """Filter on which events to include in tracker dumps.""" @@ -200,7 +202,7 @@ def __init__( self.latest_message: Optional[Event] = None self.latest_bot_utterance = None self._reset() - self.active_loop: TrackerActiveLoop = {} + self.active_loop: "TrackerActiveLoop" = {} ### # Public tracker interface From 297c1dcca4aad000a91a5a811e385b479d489a55 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Wed, 18 Nov 2020 18:54:59 +0100 Subject: [PATCH 20/42] simpler type annotation for SanicView --- rasa/server.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/rasa/server.py b/rasa/server.py index ce2fbe0b8ca4..fae1aa178def 100644 --- a/rasa/server.py +++ b/rasa/server.py @@ -68,14 +68,10 @@ if TYPE_CHECKING: from ssl import SSLContext - from typing_extensions import Protocol from rasa.core.processor import MessageProcessor + from mypy_extensions import VarArg, KwArg - class SanicView(Protocol): - def __call__( - self, request: Request, *args: Any, **kwargs: Any - ) -> response.BaseHTTPResponse: - ... + SanicView = Callable[[Request, VarArg(), KwArg()], response.BaseHTTPResponse] logger = logging.getLogger(__name__) @@ -147,7 +143,9 @@ def decorated(*args, **kwargs): return decorator -def requires_auth(app: Sanic, token: Optional[Text] = None) -> Callable[[Any], Any]: +def requires_auth( + app: Sanic, token: Optional[Text] = None +) -> Callable[["SanicView"], "SanicView"]: """Wraps a request handler with token authentication.""" def decorator(f: "SanicView") -> "SanicView": From 89b5c6a8199acdc6e99e59a413117fabba19bf62 Mon Sep 17 00:00:00 2001 From: Maxime Vdb Date: Wed, 18 Nov 2020 18:58:42 +0100 Subject: [PATCH 21/42] Better error message when raising TypeError for missing JWT Co-authored-by: Tobias Wochinger --- rasa/core/channels/rasa_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/core/channels/rasa_chat.py b/rasa/core/channels/rasa_chat.py index aa904ca9dc3d..d5ce737ebf74 100644 --- a/rasa/core/channels/rasa_chat.py +++ b/rasa/core/channels/rasa_chat.py @@ -73,7 +73,7 @@ async def _fetch_public_key(self) -> None: def _decode_jwt(self, bearer_token: Text) -> Dict: if self.jwt_key is None: - raise TypeError("JWT public key is `None`.") + raise TypeError("JWT public key is `None`. This is likely caused by an error when retrieving the public key from Rasa X.") authorization_header_value = bearer_token.replace( constants.BEARER_TOKEN_PREFIX, "" From f3c2bfb425b9b7d528e7d570c8b5905a4d876857 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Wed, 18 Nov 2020 19:12:35 +0100 Subject: [PATCH 22/42] raise exception when failing to decode JWT key in Rasa chat --- rasa/core/channels/rasa_chat.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/rasa/core/channels/rasa_chat.py b/rasa/core/channels/rasa_chat.py index d5ce737ebf74..888aa822408b 100644 --- a/rasa/core/channels/rasa_chat.py +++ b/rasa/core/channels/rasa_chat.py @@ -73,7 +73,10 @@ async def _fetch_public_key(self) -> None: def _decode_jwt(self, bearer_token: Text) -> Dict: if self.jwt_key is None: - raise TypeError("JWT public key is `None`. This is likely caused by an error when retrieving the public key from Rasa X.") + raise TypeError( + "JWT public key is `None`. This is likely caused " + "by an error when retrieving the public key from Rasa X." + ) authorization_header_value = bearer_token.replace( constants.BEARER_TOKEN_PREFIX, "" @@ -86,16 +89,12 @@ async def _decode_bearer_token(self, bearer_token: Text) -> Optional[Dict]: if self.jwt_key is None: await self._fetch_public_key() - # noinspection PyBroadException try: return self._decode_jwt(bearer_token) except jwt.InvalidSignatureError: logger.error("JWT public key invalid, fetching new one.") await self._fetch_public_key() return self._decode_jwt(bearer_token) - except Exception: - logger.exception("Failed to decode bearer token.") - return None async def _extract_sender(self, req: Request) -> Optional[Text]: """Fetch user from the Rasa X Admin API.""" From eb897f6b8e35626e98c3c6e5d9a4ddd892e94182 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Wed, 18 Nov 2020 19:16:32 +0100 Subject: [PATCH 23/42] change type of Message.time to int --- rasa/nlu/extractors/duckling_entity_extractor.py | 2 +- rasa/nlu/model.py | 2 +- rasa/shared/nlu/training_data/message.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/rasa/nlu/extractors/duckling_entity_extractor.py b/rasa/nlu/extractors/duckling_entity_extractor.py index 4216d1796275..35c960a15ed3 100644 --- a/rasa/nlu/extractors/duckling_entity_extractor.py +++ b/rasa/nlu/extractors/duckling_entity_extractor.py @@ -162,7 +162,7 @@ def _duckling_parse(self, text: Text, reference_time: int) -> List[Dict[Text, An def _reference_time_from_message(message: Message) -> int: if message.time is not None: try: - return int(message.time) * 1000 + return message.time * 1000 except ValueError as e: logging.warning( "Could not parse timestamp {}. Instead " diff --git a/rasa/nlu/model.py b/rasa/nlu/model.py index 2b018deea5af..b238e663b935 100644 --- a/rasa/nlu/model.py +++ b/rasa/nlu/model.py @@ -387,7 +387,7 @@ def parse( output["text"] = "" return output - timestamp = str(int(time.timestamp())) if time else None + timestamp = int(time.timestamp()) if time else None data = self.default_output_attributes() data[TEXT] = text diff --git a/rasa/shared/nlu/training_data/message.py b/rasa/shared/nlu/training_data/message.py index 6e1fb8569c52..f61c65b0a508 100644 --- a/rasa/shared/nlu/training_data/message.py +++ b/rasa/shared/nlu/training_data/message.py @@ -30,7 +30,7 @@ def __init__( self, data: Optional[Dict[Text, Any]] = None, output_properties: Optional[Set] = None, - time: Optional[Text] = None, + time: Optional[int] = None, features: Optional[List["Features"]] = None, **kwargs: Any, ) -> None: From 0226600c68b865eef550361c306999bf0b8ff607 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Thu, 19 Nov 2020 10:59:37 +0100 Subject: [PATCH 24/42] remove useless import --- rasa/shared/core/trackers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/rasa/shared/core/trackers.py b/rasa/shared/core/trackers.py index 126f99fca489..1e52d3568c43 100644 --- a/rasa/shared/core/trackers.py +++ b/rasa/shared/core/trackers.py @@ -18,7 +18,6 @@ Union, FrozenSet, Tuple, - TypedDict, TYPE_CHECKING, ) From 210d3a9dad4f968a7bcb82b615c9138704807f1d Mon Sep 17 00:00:00 2001 From: m-vdb Date: Thu, 19 Nov 2020 11:01:07 +0100 Subject: [PATCH 25/42] more precise type for DialogStateTracker.latest_message --- rasa/shared/core/trackers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/shared/core/trackers.py b/rasa/shared/core/trackers.py index 1e52d3568c43..8c67680fc3a4 100644 --- a/rasa/shared/core/trackers.py +++ b/rasa/shared/core/trackers.py @@ -198,7 +198,7 @@ def __init__( self.followup_action = ACTION_LISTEN_NAME self.latest_action = None # Stores the most recent message sent by the user - self.latest_message: Optional[Event] = None + self.latest_message: Optional[UserUttered] = None self.latest_bot_utterance = None self._reset() self.active_loop: "TrackerActiveLoop" = {} From ab748f6307abae4df86813948d4a91e314536ce5 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Thu, 19 Nov 2020 11:22:01 +0100 Subject: [PATCH 26/42] explicit cast() in StoryStep._or_string() --- rasa/shared/core/training_data/structures.py | 60 ++++++++++++-------- 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/rasa/shared/core/training_data/structures.py b/rasa/shared/core/training_data/structures.py index 5405569d3c8d..4637bd8cfbab 100644 --- a/rasa/shared/core/training_data/structures.py +++ b/rasa/shared/core/training_data/structures.py @@ -4,7 +4,19 @@ import uuid import typing -from typing import List, Text, Dict, Optional, Tuple, Any, Set, ValuesView, Union +from typing import ( + List, + Text, + Dict, + Optional, + Tuple, + Any, + Set, + ValuesView, + Union, + Sequence, + cast, +) import rasa.shared.utils.io from rasa.shared.core.constants import ACTION_LISTEN_NAME, ACTION_SESSION_START_NAME @@ -141,20 +153,18 @@ def _bot_string(story_step_element: Event) -> Text: return f" - {story_step_element.as_story_string()}\n" @staticmethod - def _or_string(story_step_element: List[Event], e2e: bool) -> Text: + def _or_string(story_step_element: Sequence[Event], e2e: bool) -> Text: for event in story_step_element: if not isinstance(event, UserUttered): raise EventTypeError( "OR statement events must be of type `UserUttered`." ) + # FIXME: https://github.com/python/mypy/issues/7853 + story_step_element = cast(Sequence[UserUttered], story_step_element) + result = " OR ".join( - # FIXME: this breaks below because not - # all `as_story_string()` take a `e2e` argument. - [ - element.as_story_string(e2e) # type: ignore[call-arg] - for element in story_step_element - ] + [element.as_story_string(e2e) for element in story_step_element] ) return f"* {result}\n" @@ -165,35 +175,35 @@ def as_story_string(self, flat: bool = False, e2e: bool = False) -> Text: result = "" else: result = f"\n## {self.block_name}\n" - for s in self.start_checkpoints: - if s.name != STORY_START: - result += self._checkpoint_string(s) + for checkpoint in self.start_checkpoints: + if checkpoint.name != STORY_START: + result += self._checkpoint_string(checkpoint) - for s in self.events: + for event in self.events: if ( - self.is_action_listen(s) - or self.is_action_session_start(s) - or isinstance(s, SessionStarted) + self.is_action_listen(event) + or self.is_action_session_start(event) + or isinstance(event, SessionStarted) ): continue - if isinstance(s, UserUttered): - result += self._user_string(s, e2e) - elif isinstance(s, Event): - converted = s.as_story_string() + if isinstance(event, UserUttered): + result += self._user_string(event, e2e) + elif isinstance(event, Event): + converted = event.as_story_string() if converted: - result += self._bot_string(s) - elif isinstance(s, list): + result += self._bot_string(event) + elif isinstance(event, list): # The story reader classes support reading stories in # conversion mode. When this mode is enabled, OR statements # are represented as lists of events. - result += self._or_string(s, e2e) + result += self._or_string(event, e2e) else: - raise Exception(f"Unexpected element in story step: {s}") + raise Exception(f"Unexpected element in story step: {event}") if not flat: - for s in self.end_checkpoints: - result += self._checkpoint_string(s) + for checkpoint in self.end_checkpoints: + result += self._checkpoint_string(checkpoint) return result @staticmethod From de5b4286a668f2122795e84c33a345d50e4db9cb Mon Sep 17 00:00:00 2001 From: m-vdb Date: Wed, 10 Mar 2021 08:57:47 +0100 Subject: [PATCH 27/42] fix merge issue in pyproject.toml --- poetry.lock | 44 +++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 1 + 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 0aadd2a15b09..c0a1144f84eb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3138,7 +3138,7 @@ transformers = ["transformers"] [metadata] lock-version = "1.1" python-versions = ">=3.6,<3.9" -content-hash = "944cb8d58cb966c982422593c6e806650b14e8d52f444aab4063af0dcfdb2eb4" +content-hash = "b3ae92ed96ebf4bb6bf085337a0bdbaf2e7220fbdf3aab42b4e728fadd5ca150" [metadata.files] absl-py = [ @@ -3870,20 +3870,39 @@ markupsafe = [ {file = "MarkupSafe-1.1.1-cp35-cp35m-win32.whl", hash = "sha256:6dd73240d2af64df90aa7c4e7481e23825ea70af4b4922f8ede5b9e35f78a3b1"}, {file = "MarkupSafe-1.1.1-cp35-cp35m-win_amd64.whl", hash = "sha256:9add70b36c5666a2ed02b43b335fe19002ee5235efd4b8a89bfcf9005bebac0d"}, {file = "MarkupSafe-1.1.1-cp36-cp36m-macosx_10_6_intel.whl", hash = "sha256:24982cc2533820871eba85ba648cd53d8623687ff11cbb805be4ff7b4c971aff"}, + {file = "MarkupSafe-1.1.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:d53bc011414228441014aa71dbec320c66468c1030aae3a6e29778a3382d96e5"}, {file = "MarkupSafe-1.1.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:00bc623926325b26bb9605ae9eae8a215691f33cae5df11ca5424f06f2d1f473"}, {file = "MarkupSafe-1.1.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:717ba8fe3ae9cc0006d7c451f0bb265ee07739daf76355d06366154ee68d221e"}, + {file = "MarkupSafe-1.1.1-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:3b8a6499709d29c2e2399569d96719a1b21dcd94410a586a18526b143ec8470f"}, + {file = "MarkupSafe-1.1.1-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:84dee80c15f1b560d55bcfe6d47b27d070b4681c699c572af2e3c7cc90a3b8e0"}, + {file = "MarkupSafe-1.1.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:b1dba4527182c95a0db8b6060cc98ac49b9e2f5e64320e2b56e47cb2831978c7"}, {file = "MarkupSafe-1.1.1-cp36-cp36m-win32.whl", hash = "sha256:535f6fc4d397c1563d08b88e485c3496cf5784e927af890fb3c3aac7f933ec66"}, {file = "MarkupSafe-1.1.1-cp36-cp36m-win_amd64.whl", hash = "sha256:b1282f8c00509d99fef04d8ba936b156d419be841854fe901d8ae224c59f0be5"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-macosx_10_6_intel.whl", hash = "sha256:8defac2f2ccd6805ebf65f5eeb132adcf2ab57aa11fdf4c0dd5169a004710e7d"}, + {file = "MarkupSafe-1.1.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:bf5aa3cbcfdf57fa2ee9cd1822c862ef23037f5c832ad09cfea57fa846dec193"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:46c99d2de99945ec5cb54f23c8cd5689f6d7177305ebff350a58ce5f8de1669e"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6"}, + {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:6fffc775d90dcc9aed1b89219549b329a9250d918fd0b8fa8d93d154918422e1"}, + {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:a6a744282b7718a2a62d2ed9d993cad6f5f585605ad352c11de459f4108df0a1"}, + {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:195d7d2c4fbb0ee8139a6cf67194f3973a6b3042d742ebe0a9ed36d8b6f0c07f"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-win32.whl", hash = "sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c"}, {file = "MarkupSafe-1.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6788b695d50a51edb699cb55e35487e430fa21f1ed838122d722e0ff0ac5ba15"}, {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:cdb132fc825c38e1aeec2c8aa9338310d29d337bebbd7baa06889d09a60a1fa2"}, {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:13d3144e1e340870b25e7b10b98d779608c02016d5184cfb9927a9f10c689f42"}, + {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:acf08ac40292838b3cbbb06cfe9b2cb9ec78fce8baca31ddb87aaac2e2dc3bc2"}, + {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:d9be0ba6c527163cbed5e0857c451fcd092ce83947944d6c14bc95441203f032"}, + {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:caabedc8323f1e93231b52fc32bdcde6db817623d33e100708d9a68e1f53b26b"}, {file = "MarkupSafe-1.1.1-cp38-cp38-win32.whl", hash = "sha256:596510de112c685489095da617b5bcbbac7dd6384aeebeda4df6025d0256a81b"}, {file = "MarkupSafe-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be"}, + {file = "MarkupSafe-1.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d73a845f227b0bfe8a7455ee623525ee656a9e2e749e4742706d80a6065d5e2c"}, + {file = "MarkupSafe-1.1.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:98bae9582248d6cf62321dcb52aaf5d9adf0bad3b40582925ef7c7f0ed85fceb"}, + {file = "MarkupSafe-1.1.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:2beec1e0de6924ea551859edb9e7679da6e4870d32cb766240ce17e0a0ba2014"}, + {file = "MarkupSafe-1.1.1-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:7fed13866cf14bba33e7176717346713881f56d9d2bcebab207f7a036f41b850"}, + {file = "MarkupSafe-1.1.1-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:6f1e273a344928347c1290119b493a1f0303c52f5a5eae5f16d74f48c15d4a85"}, + {file = "MarkupSafe-1.1.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:feb7b34d6325451ef96bc0e36e1a6c0c1c64bc1fbec4b854f4529e51887b1621"}, + {file = "MarkupSafe-1.1.1-cp39-cp39-win32.whl", hash = "sha256:22c178a091fc6630d0d045bdb5992d2dfe14e3259760e713c490da5323866c39"}, + {file = "MarkupSafe-1.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:b7d644ddb4dbd407d31ffb699f1d140bc35478da613b441c582aeb7c43838dd8"}, {file = "MarkupSafe-1.1.1.tar.gz", hash = "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b"}, ] matplotlib = [ @@ -4469,18 +4488,26 @@ pyyaml = [ {file = "PyYAML-5.4.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:bb4191dfc9306777bc594117aee052446b3fa88737cd13b7188d0e7aa8162185"}, {file = "PyYAML-5.4.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:6c78645d400265a062508ae399b60b8c167bf003db364ecb26dcab2bda048253"}, {file = "PyYAML-5.4.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:4e0583d24c881e14342eaf4ec5fbc97f934b999a6828693a99157fde912540cc"}, + {file = "PyYAML-5.4.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:72a01f726a9c7851ca9bfad6fd09ca4e090a023c00945ea05ba1638c09dc3347"}, + {file = "PyYAML-5.4.1-cp36-cp36m-manylinux2014_s390x.whl", hash = "sha256:895f61ef02e8fed38159bb70f7e100e00f471eae2bc838cd0f4ebb21e28f8541"}, {file = "PyYAML-5.4.1-cp36-cp36m-win32.whl", hash = "sha256:3bd0e463264cf257d1ffd2e40223b197271046d09dadf73a0fe82b9c1fc385a5"}, {file = "PyYAML-5.4.1-cp36-cp36m-win_amd64.whl", hash = "sha256:e4fac90784481d221a8e4b1162afa7c47ed953be40d31ab4629ae917510051df"}, {file = "PyYAML-5.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5accb17103e43963b80e6f837831f38d314a0495500067cb25afab2e8d7a4018"}, {file = "PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:e1d4970ea66be07ae37a3c2e48b5ec63f7ba6804bdddfdbd3cfd954d25a82e63"}, + {file = "PyYAML-5.4.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:cb333c16912324fd5f769fff6bc5de372e9e7a202247b48870bc251ed40239aa"}, + {file = "PyYAML-5.4.1-cp37-cp37m-manylinux2014_s390x.whl", hash = "sha256:fe69978f3f768926cfa37b867e3843918e012cf83f680806599ddce33c2c68b0"}, {file = "PyYAML-5.4.1-cp37-cp37m-win32.whl", hash = "sha256:dd5de0646207f053eb0d6c74ae45ba98c3395a571a2891858e87df7c9b9bd51b"}, {file = "PyYAML-5.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:08682f6b72c722394747bddaf0aa62277e02557c0fd1c42cb853016a38f8dedf"}, {file = "PyYAML-5.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d2d9808ea7b4af864f35ea216be506ecec180628aced0704e34aca0b040ffe46"}, {file = "PyYAML-5.4.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:8c1be557ee92a20f184922c7b6424e8ab6691788e6d86137c5d93c1a6ec1b8fb"}, + {file = "PyYAML-5.4.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:fd7f6999a8070df521b6384004ef42833b9bd62cfee11a09bda1079b4b704247"}, + {file = "PyYAML-5.4.1-cp38-cp38-manylinux2014_s390x.whl", hash = "sha256:bfb51918d4ff3d77c1c856a9699f8492c612cde32fd3bcd344af9be34999bfdc"}, {file = "PyYAML-5.4.1-cp38-cp38-win32.whl", hash = "sha256:fa5ae20527d8e831e8230cbffd9f8fe952815b2b7dae6ffec25318803a7528fc"}, {file = "PyYAML-5.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:0f5f5786c0e09baddcd8b4b45f20a7b5d61a7e7e99846e3c799b05c7c53fa696"}, {file = "PyYAML-5.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:294db365efa064d00b8d1ef65d8ea2c3426ac366c0c4368d930bf1c5fb497f77"}, {file = "PyYAML-5.4.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:74c1485f7707cf707a7aef42ef6322b8f97921bd89be2ab6317fd782c2d53183"}, + {file = "PyYAML-5.4.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d483ad4e639292c90170eb6f7783ad19490e7a8defb3e46f97dfe4bacae89122"}, + {file = "PyYAML-5.4.1-cp39-cp39-manylinux2014_s390x.whl", hash = "sha256:fdc842473cd33f45ff6bce46aea678a54e3d21f1b61a7750ce3c498eedfe25d6"}, {file = "PyYAML-5.4.1-cp39-cp39-win32.whl", hash = "sha256:49d4cdd9065b9b6e206d0595fee27a96b5dd22618e7520c33204a4a3239d5b10"}, {file = "PyYAML-5.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:c20cfa2d49991c8b4147af39859b167664f2ad4561704ee74c1de03318e898db"}, {file = "PyYAML-5.4.1.tar.gz", hash = "sha256:607774cbba28732bfa802b54baa7484215f530991055bb562efbed5b2f20a45e"}, @@ -4567,22 +4594,29 @@ rsa = [ {file = "ruamel.yaml.clib-0.2.2-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:73b3d43e04cc4b228fa6fa5d796409ece6fcb53a6c270eb2048109cbcbc3b9c2"}, {file = "ruamel.yaml.clib-0.2.2-cp35-cp35m-macosx_10_6_intel.whl", hash = "sha256:53b9dd1abd70e257a6e32f934ebc482dac5edb8c93e23deb663eac724c30b026"}, {file = "ruamel.yaml.clib-0.2.2-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:839dd72545ef7ba78fd2aa1a5dd07b33696adf3e68fae7f31327161c1093001b"}, + {file = "ruamel.yaml.clib-0.2.2-cp35-cp35m-manylinux2014_aarch64.whl", hash = "sha256:1236df55e0f73cd138c0eca074ee086136c3f16a97c2ac719032c050f7e0622f"}, {file = "ruamel.yaml.clib-0.2.2-cp35-cp35m-win32.whl", hash = "sha256:b1e981fe1aff1fd11627f531524826a4dcc1f26c726235a52fcb62ded27d150f"}, {file = "ruamel.yaml.clib-0.2.2-cp35-cp35m-win_amd64.whl", hash = "sha256:4e52c96ca66de04be42ea2278012a2342d89f5e82b4512fb6fb7134e377e2e62"}, {file = "ruamel.yaml.clib-0.2.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a873e4d4954f865dcb60bdc4914af7eaae48fb56b60ed6daa1d6251c72f5337c"}, {file = "ruamel.yaml.clib-0.2.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:ab845f1f51f7eb750a78937be9f79baea4a42c7960f5a94dde34e69f3cce1988"}, + {file = "ruamel.yaml.clib-0.2.2-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:2fd336a5c6415c82e2deb40d08c222087febe0aebe520f4d21910629018ab0f3"}, {file = "ruamel.yaml.clib-0.2.2-cp36-cp36m-win32.whl", hash = "sha256:e9f7d1d8c26a6a12c23421061f9022bb62704e38211fe375c645485f38df34a2"}, {file = "ruamel.yaml.clib-0.2.2-cp36-cp36m-win_amd64.whl", hash = "sha256:2602e91bd5c1b874d6f93d3086f9830f3e907c543c7672cf293a97c3fabdcd91"}, {file = "ruamel.yaml.clib-0.2.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:44c7b0498c39f27795224438f1a6be6c5352f82cb887bc33d962c3a3acc00df6"}, {file = "ruamel.yaml.clib-0.2.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:8e8fd0a22c9d92af3a34f91e8a2594eeb35cba90ab643c5e0e643567dc8be43e"}, + {file = "ruamel.yaml.clib-0.2.2-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:75f0ee6839532e52a3a53f80ce64925ed4aed697dd3fa890c4c918f3304bd4f4"}, {file = "ruamel.yaml.clib-0.2.2-cp37-cp37m-win32.whl", hash = "sha256:464e66a04e740d754170be5e740657a3b3b6d2bcc567f0c3437879a6e6087ff6"}, {file = "ruamel.yaml.clib-0.2.2-cp37-cp37m-win_amd64.whl", hash = "sha256:52ae5739e4b5d6317b52f5b040b1b6639e8af68a5b8fd606a8b08658fbd0cab5"}, {file = "ruamel.yaml.clib-0.2.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:4df5019e7783d14b79217ad9c56edf1ba7485d614ad5a385d1b3c768635c81c0"}, {file = "ruamel.yaml.clib-0.2.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:5254af7d8bdf4d5484c089f929cb7f5bafa59b4f01d4f48adda4be41e6d29f99"}, + {file = "ruamel.yaml.clib-0.2.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:8be05be57dc5c7b4a0b24edcaa2f7275866d9c907725226cdde46da09367d923"}, {file = "ruamel.yaml.clib-0.2.2-cp38-cp38-win32.whl", hash = "sha256:74161d827407f4db9072011adcfb825b5258a5ccb3d2cd518dd6c9edea9e30f1"}, {file = "ruamel.yaml.clib-0.2.2-cp38-cp38-win_amd64.whl", hash = "sha256:058a1cc3df2a8aecc12f983a48bda99315cebf55a3b3a5463e37bb599b05727b"}, {file = "ruamel.yaml.clib-0.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c6ac7e45367b1317e56f1461719c853fd6825226f45b835df7436bb04031fd8a"}, {file = "ruamel.yaml.clib-0.2.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:b4b0d31f2052b3f9f9b5327024dc629a253a83d8649d4734ca7f35b60ec3e9e5"}, + {file = "ruamel.yaml.clib-0.2.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:1f8c0a4577c0e6c99d208de5c4d3fd8aceed9574bb154d7a2b21c16bb924154c"}, + {file = "ruamel.yaml.clib-0.2.2-cp39-cp39-win32.whl", hash = "sha256:46d6d20815064e8bb023ea8628cfb7402c0f0e83de2c2227a88097e239a7dffd"}, + {file = "ruamel.yaml.clib-0.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:6c0a5dc52fc74eb87c67374a4e554d4761fd42a4d01390b7e868b30d21f4b8bb"}, {file = "ruamel.yaml.clib-0.2.2.tar.gz", hash = "sha256:2d24bd98af676f4990c4d715bcdc2a60b19c56a3fb3a763164d2d8ca0e806ba7"}, ] s3transfer = [ @@ -4610,21 +4644,29 @@ sanic-plugins-framework = [ scikit-learn = [ {file = "scikit-learn-0.24.1.tar.gz", hash = "sha256:a0334a1802e64d656022c3bfab56a73fbd6bf4b1298343f3688af2151810bbdf"}, {file = "scikit_learn-0.24.1-cp36-cp36m-macosx_10_13_x86_64.whl", hash = "sha256:9bed8a1ef133c8e2f13966a542cb8125eac7f4b67dcd234197c827ba9c7dd3e0"}, + {file = "scikit_learn-0.24.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:a36e159a0521e13bbe15ca8c8d038b3a1dd4c7dad18d276d76992e03b92cf643"}, + {file = "scikit_learn-0.24.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:c658432d8a20e95398f6bb95ff9731ce9dfa343fdf21eea7ec6a7edfacd4b4d9"}, {file = "scikit_learn-0.24.1-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:9dfa564ef27e8e674aa1cc74378416d580ac4ede1136c13dd555a87996e13422"}, {file = "scikit_learn-0.24.1-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:9c6097b6a9b2bafc5e0f31f659e6ab5e131383209c30c9e978c5b8abdac5ed2a"}, {file = "scikit_learn-0.24.1-cp36-cp36m-win32.whl", hash = "sha256:7b04691eb2f41d2c68dbda8d1bd3cb4ef421bdc43aaa56aeb6c762224552dfb6"}, {file = "scikit_learn-0.24.1-cp36-cp36m-win_amd64.whl", hash = "sha256:1adf483e91007a87171d7ce58c34b058eb5dab01b5fee6052f15841778a8ecd8"}, {file = "scikit_learn-0.24.1-cp37-cp37m-macosx_10_13_x86_64.whl", hash = "sha256:ddb52d088889f5596bc4d1de981f2eca106b58243b6679e4782f3ba5096fd645"}, + {file = "scikit_learn-0.24.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:a29460499c1e62b7a830bb57ca42e615375a6ab1bcad053cd25b493588348ea8"}, + {file = "scikit_learn-0.24.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:0567a2d29ad08af98653300c623bd8477b448fe66ced7198bef4ed195925f082"}, {file = "scikit_learn-0.24.1-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:99349d77f54e11f962d608d94dfda08f0c9e5720d97132233ebdf35be2858b2d"}, {file = "scikit_learn-0.24.1-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:83b21ff053b1ff1c018a2d24db6dd3ea339b1acfbaa4d9c881731f43748d8b3b"}, {file = "scikit_learn-0.24.1-cp37-cp37m-win32.whl", hash = "sha256:c3deb3b19dd9806acf00cf0d400e84562c227723013c33abefbbc3cf906596e9"}, {file = "scikit_learn-0.24.1-cp37-cp37m-win_amd64.whl", hash = "sha256:d54dbaadeb1425b7d6a66bf44bee2bb2b899fe3e8850b8e94cfb9c904dcb46d0"}, {file = "scikit_learn-0.24.1-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:3c4f07f47c04e81b134424d53c3f5e16dfd7f494e44fd7584ba9ce9de2c5e6c1"}, + {file = "scikit_learn-0.24.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:c13ebac42236b1c46397162471ea1c46af68413000e28b9309f8c05722c65a09"}, + {file = "scikit_learn-0.24.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:4ddd2b6f7449a5d539ff754fa92d75da22de261fd8fdcfb3596799fadf255101"}, {file = "scikit_learn-0.24.1-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:826b92bf45b8ad80444814e5f4ac032156dd481e48d7da33d611f8fe96d5f08b"}, {file = "scikit_learn-0.24.1-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:259ec35201e82e2db1ae2496f229e63f46d7f1695ae68eef9350b00dc74ba52f"}, {file = "scikit_learn-0.24.1-cp38-cp38-win32.whl", hash = "sha256:8772b99d683be8f67fcc04789032f1b949022a0e6880ee7b75a7ec97dbbb5d0b"}, {file = "scikit_learn-0.24.1-cp38-cp38-win_amd64.whl", hash = "sha256:ed9d65594948678827f4ff0e7ae23344e2f2b4cabbca057ccaed3118fdc392ca"}, {file = "scikit_learn-0.24.1-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:8aa1b3ac46b80eaa552b637eeadbbce3be5931e4b5002b964698e33a1b589e1e"}, + {file = "scikit_learn-0.24.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:c7f4eb77504ac586d8ac1bde1b0c04b504487210f95297235311a0ab7edd7e38"}, + {file = "scikit_learn-0.24.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:087dfede39efb06ab30618f9ab55a0397f29c38d63cd0ab88d12b500b7d65fd7"}, {file = "scikit_learn-0.24.1-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:895dbf2030aa7337649e36a83a007df3c9811396b4e2fa672a851160f36ce90c"}, {file = "scikit_learn-0.24.1-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:9a24d1ccec2a34d4cd3f2a1f86409f3f5954cc23d4d2270ba0d03cf018aa4780"}, {file = "scikit_learn-0.24.1-cp39-cp39-win32.whl", hash = "sha256:fab31f48282ebf54dd69f6663cd2d9800096bad1bb67bbc9c9ac84eb77b41972"}, diff --git a/pyproject.toml b/pyproject.toml index 5986fd71f1c1..fca427802846 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -150,6 +150,7 @@ pydoc-markdown = "^3.5.0" pytest-timeout = "^1.4.2" mypy = "^0.790" bandit = "^1.6.3" +typing-extensions = "^3.7.4" [tool.poetry.extras] spacy = [ "spacy",] From 3a783223d1392121aaada7511f17bd956e1aa27b Mon Sep 17 00:00:00 2001 From: m-vdb Date: Wed, 10 Mar 2021 09:41:58 +0100 Subject: [PATCH 28/42] raise TypeError if ActionRevertFallbackEvents cannot revert to last event --- rasa/core/actions/action.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/rasa/core/actions/action.py b/rasa/core/actions/action.py index 4ff6fb9fe2b7..9495af765919 100644 --- a/rasa/core/actions/action.py +++ b/rasa/core/actions/action.py @@ -782,7 +782,7 @@ def _revert_affirmation_events(tracker: "DialogueStateTracker") -> List[Event]: last_user_event = tracker.get_last_event_for(UserUttered) if not last_user_event: - return revert_events + raise TypeError("Cannot find last event to revert to.") last_user_event = copy.deepcopy(last_user_event) last_user_event.parse_data["intent"]["confidence"] = 1.0 @@ -802,6 +802,9 @@ def _revert_single_affirmation_events() -> List[Event]: def _revert_successful_rephrasing(tracker) -> List[Event]: last_user_event = tracker.get_last_event_for(UserUttered) + if not last_user_event: + raise TypeError("Cannot find last event to revert to.") + last_user_event = copy.deepcopy(last_user_event) return _revert_rephrasing_events() + [last_user_event] From 61765692acb47461f8ebf397ac41bff97577cef5 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Wed, 10 Mar 2021 09:46:09 +0100 Subject: [PATCH 29/42] raise TypeError if TwoStageFallbackAction cannot issue clarification --- rasa/core/actions/two_stage_fallback.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/rasa/core/actions/two_stage_fallback.py b/rasa/core/actions/two_stage_fallback.py index db1889b782ff..5cc54be902eb 100644 --- a/rasa/core/actions/two_stage_fallback.py +++ b/rasa/core/actions/two_stage_fallback.py @@ -1,6 +1,6 @@ import copy import time -from typing import List, Text, Optional, cast +from typing import List, Text, Optional from rasa.core.actions import action from rasa.core.actions.loops import LoopAction @@ -171,7 +171,14 @@ def _second_affirmation_failed(tracker: DialogueStateTracker) -> bool: def _message_clarification(tracker: DialogueStateTracker) -> List[Event]: - clarification = copy.deepcopy(cast(Event, tracker.latest_message)) + latest_message = tracker.latest_message + if not latest_message: + raise TypeError( + "Cannot issue message clarification because " + "latest message is not on tracker." + ) + + clarification = copy.deepcopy(latest_message) clarification.parse_data["intent"]["confidence"] = 1.0 clarification.timestamp = time.time() return [ActionExecuted(ACTION_LISTEN_NAME), clarification] From 0116796e7fcd1be73b59a6146914a8bc6f94981e Mon Sep 17 00:00:00 2001 From: m-vdb Date: Thu, 11 Mar 2021 08:35:03 +0100 Subject: [PATCH 30/42] use consistent error key in /model/test/intents --- rasa/server.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/rasa/server.py b/rasa/server.py index 1602055644ec..6e9c08e69dc0 100644 --- a/rasa/server.py +++ b/rasa/server.py @@ -178,7 +178,9 @@ def decorated(request: Request, *args: Any, **kwargs: Any) -> HTTPResponse: return decorator -def requires_auth(app: Sanic, token: Optional[Text] = None) -> Callable[["SanicView"], "SanicView"]: +def requires_auth( + app: Sanic, token: Optional[Text] = None +) -> Callable[["SanicView"], "SanicView"]: """Wraps a request handler with token authentication.""" def decorator(f: "SanicView") -> "SanicView": @@ -1198,10 +1200,9 @@ async def _evaluate_model_using_test_set( if nlu_model is None: raise ErrorResponse( - HTTPStatus.CONFLICT, "TestingError", "Missing NLU model directory.", + HTTPStatus.CONFLICT, "Conflict", "Missing NLU model directory.", ) - return await run_evaluation( data_path, nlu_model, disable_plotting=True, report_as_dict=True ) From 33ca6b9912914f04d12b23c9e97832e4240f4ae3 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Thu, 11 Mar 2021 08:41:12 +0100 Subject: [PATCH 31/42] fix remaining type issues --- rasa/core/test.py | 4 ++-- rasa/utils/tensorflow/models.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/rasa/core/test.py b/rasa/core/test.py index e33ce8591e32..e32c3eeb954c 100644 --- a/rasa/core/test.py +++ b/rasa/core/test.py @@ -297,7 +297,7 @@ def __init__(self, event: UserUttered, eval_store: EvaluationStore) -> None: self.predicted_entities = eval_store.entity_predictions try: - intent = {"name": eval_store.intent_targets[0]} + intent: Dict[Text, Optional[Text]] = {"name": eval_store.intent_targets[0]} except LookupError: intent = {"name": None} @@ -440,7 +440,7 @@ def _collect_user_uttered_predictions( if intent_gold: user_uttered_eval_store.add_to_store(intent_targets=[intent_gold]) - if predicted_intent: + if predicted_base_intent: user_uttered_eval_store.add_to_store(intent_predictions=[predicted_base_intent]) entity_gold = event.entities diff --git a/rasa/utils/tensorflow/models.py b/rasa/utils/tensorflow/models.py index 7605d3cf6852..2fffcc8b8ea5 100644 --- a/rasa/utils/tensorflow/models.py +++ b/rasa/utils/tensorflow/models.py @@ -385,7 +385,8 @@ def _convert_dense_features( if number_of_dimensions > 1 and ( batch[idx].shape is None or batch[idx].shape[-1] is None ): - shape = [None] * (number_of_dimensions - 1) + [feature_dimension] + shape: List[Optional[int]] = [None] * (number_of_dimensions - 1) + shape.append(feature_dimension) batch[idx].set_shape(shape) return batch[idx], idx + 1 From 2b6aed300419fa12060bed561ff12bfd1e2da1ea Mon Sep 17 00:00:00 2001 From: m-vdb Date: Thu, 11 Mar 2021 08:44:19 +0100 Subject: [PATCH 32/42] bump mypy --- poetry.lock | 40 ++++++++++++++++++++++++---------------- pyproject.toml | 2 +- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/poetry.lock b/poetry.lock index f18db8b60016..e31fbc3dc726 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1437,7 +1437,7 @@ python-versions = "*" [[package]] name = "mypy" -version = "0.790" +version = "0.812" description = "Optional static typing for Python" category = "dev" optional = false @@ -3139,7 +3139,7 @@ transformers = ["transformers"] [metadata] lock-version = "1.1" python-versions = ">=3.6,<3.9" -content-hash = "aadc96af024f6d8e73fce91f00b0a2228417f4af3e898388cc05732f5df0b4c0" +content-hash = "ce36740b1ada645cdf75ec35eecd3eca833f18b5e359ac6b0c6ba583992f1402" [metadata.files] absl-py = [ @@ -3999,20 +3999,28 @@ murmurhash = [ {file = "murmurhash-1.0.5.tar.gz", hash = "sha256:98ec9d727bd998a35385abd56b062cf0cca216725ea7ec5068604ab566f7e97f"}, ] mypy = [ - {file = "mypy-0.790-cp35-cp35m-macosx_10_6_x86_64.whl", hash = "sha256:bd03b3cf666bff8d710d633d1c56ab7facbdc204d567715cb3b9f85c6e94f669"}, - {file = "mypy-0.790-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:2170492030f6faa537647d29945786d297e4862765f0b4ac5930ff62e300d802"}, - {file = "mypy-0.790-cp35-cp35m-win_amd64.whl", hash = "sha256:e86bdace26c5fe9cf8cb735e7cedfe7850ad92b327ac5d797c656717d2ca66de"}, - {file = "mypy-0.790-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:e97e9c13d67fbe524be17e4d8025d51a7dca38f90de2e462243ab8ed8a9178d1"}, - {file = "mypy-0.790-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:0d34d6b122597d48a36d6c59e35341f410d4abfa771d96d04ae2c468dd201abc"}, - {file = "mypy-0.790-cp36-cp36m-win_amd64.whl", hash = "sha256:72060bf64f290fb629bd4a67c707a66fd88ca26e413a91384b18db3876e57ed7"}, - {file = "mypy-0.790-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:eea260feb1830a627fb526d22fbb426b750d9f5a47b624e8d5e7e004359b219c"}, - {file = "mypy-0.790-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:c614194e01c85bb2e551c421397e49afb2872c88b5830e3554f0519f9fb1c178"}, - {file = "mypy-0.790-cp37-cp37m-win_amd64.whl", hash = "sha256:0a0d102247c16ce93c97066443d11e2d36e6cc2a32d8ccc1f705268970479324"}, - {file = "mypy-0.790-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cf4e7bf7f1214826cf7333627cb2547c0db7e3078723227820d0a2490f117a01"}, - {file = "mypy-0.790-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:af4e9ff1834e565f1baa74ccf7ae2564ae38c8df2a85b057af1dbbc958eb6666"}, - {file = "mypy-0.790-cp38-cp38-win_amd64.whl", hash = "sha256:da56dedcd7cd502ccd3c5dddc656cb36113dd793ad466e894574125945653cea"}, - {file = "mypy-0.790-py3-none-any.whl", hash = "sha256:2842d4fbd1b12ab422346376aad03ff5d0805b706102e475e962370f874a5122"}, - {file = "mypy-0.790.tar.gz", hash = "sha256:2b21ba45ad9ef2e2eb88ce4aeadd0112d0f5026418324176fd494a6824b74975"}, + {file = "mypy-0.812-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:a26f8ec704e5a7423c8824d425086705e381b4f1dfdef6e3a1edab7ba174ec49"}, + {file = "mypy-0.812-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:28fb5479c494b1bab244620685e2eb3c3f988d71fd5d64cc753195e8ed53df7c"}, + {file = "mypy-0.812-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:9743c91088d396c1a5a3c9978354b61b0382b4e3c440ce83cf77994a43e8c521"}, + {file = "mypy-0.812-cp35-cp35m-win_amd64.whl", hash = "sha256:d7da2e1d5f558c37d6e8c1246f1aec1e7349e4913d8fb3cb289a35de573fe2eb"}, + {file = "mypy-0.812-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4eec37370483331d13514c3f55f446fc5248d6373e7029a29ecb7b7494851e7a"}, + {file = "mypy-0.812-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:d65cc1df038ef55a99e617431f0553cd77763869eebdf9042403e16089fe746c"}, + {file = "mypy-0.812-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:61a3d5b97955422964be6b3baf05ff2ce7f26f52c85dd88db11d5e03e146a3a6"}, + {file = "mypy-0.812-cp36-cp36m-win_amd64.whl", hash = "sha256:25adde9b862f8f9aac9d2d11971f226bd4c8fbaa89fb76bdadb267ef22d10064"}, + {file = "mypy-0.812-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:552a815579aa1e995f39fd05dde6cd378e191b063f031f2acfe73ce9fb7f9e56"}, + {file = "mypy-0.812-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:499c798053cdebcaa916eef8cd733e5584b5909f789de856b482cd7d069bdad8"}, + {file = "mypy-0.812-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:5873888fff1c7cf5b71efbe80e0e73153fe9212fafdf8e44adfe4c20ec9f82d7"}, + {file = "mypy-0.812-cp37-cp37m-win_amd64.whl", hash = "sha256:9f94aac67a2045ec719ffe6111df543bac7874cee01f41928f6969756e030564"}, + {file = "mypy-0.812-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d23e0ea196702d918b60c8288561e722bf437d82cb7ef2edcd98cfa38905d506"}, + {file = "mypy-0.812-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:674e822aa665b9fd75130c6c5f5ed9564a38c6cea6a6432ce47eafb68ee578c5"}, + {file = "mypy-0.812-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:abf7e0c3cf117c44d9285cc6128856106183938c68fd4944763003decdcfeb66"}, + {file = "mypy-0.812-cp38-cp38-win_amd64.whl", hash = "sha256:0d0a87c0e7e3a9becdfbe936c981d32e5ee0ccda3e0f07e1ef2c3d1a817cf73e"}, + {file = "mypy-0.812-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7ce3175801d0ae5fdfa79b4f0cfed08807af4d075b402b7e294e6aa72af9aa2a"}, + {file = "mypy-0.812-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:b09669bcda124e83708f34a94606e01b614fa71931d356c1f1a5297ba11f110a"}, + {file = "mypy-0.812-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:33f159443db0829d16f0a8d83d94df3109bb6dd801975fe86bacb9bf71628e97"}, + {file = "mypy-0.812-cp39-cp39-win_amd64.whl", hash = "sha256:3f2aca7f68580dc2508289c729bd49ee929a436208d2b2b6aab15745a70a57df"}, + {file = "mypy-0.812-py3-none-any.whl", hash = "sha256:2f9b3407c58347a452fc0736861593e105139b905cca7d097e413453a1d650b4"}, + {file = "mypy-0.812.tar.gz", hash = "sha256:cd07039aa5df222037005b08fbbfd69b3ab0b0bd7a07d7906de75ae52c4e3119"}, ] mypy-extensions = [ {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, diff --git a/pyproject.toml b/pyproject.toml index d6353b80f1f3..6cfb186e73ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,7 +148,7 @@ toml = "^0.10.0" pep440-version-utils = "^0.3.0" pydoc-markdown = "^3.5.0" pytest-timeout = "^1.4.2" -mypy = "^0.790" +mypy = "^0.812" bandit = "^1.6.3" typing-extensions = "^3.7.4" From a2eb2242e33a4e127c0b75955f8b80631f1ceb45 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Thu, 11 Mar 2021 09:06:59 +0100 Subject: [PATCH 33/42] fix `rasa train` command and tests --- rasa/cli/train.py | 12 ++++++------ tests/cli/test_rasa_interactive.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/rasa/cli/train.py b/rasa/cli/train.py index c9bbc8a716b2..ebd049ad1dd1 100644 --- a/rasa/cli/train.py +++ b/rasa/cli/train.py @@ -7,7 +7,6 @@ import rasa.cli.arguments.train as train_arguments import rasa.cli.utils -import rasa.train import rasa.utils.common from rasa.core.train import do_compare_training from rasa.shared.utils.cli import print_error @@ -19,6 +18,7 @@ DEFAULT_DOMAIN_PATH, DEFAULT_DATA_PATH, ) +from rasa.train import train as train_all, train_core, train_nlu def add_subparser( @@ -57,13 +57,13 @@ def add_subparser( ) train_nlu_parser.set_defaults(func=run_nlu_training) - train_parser.set_defaults(func=lambda args: train(args, can_exit=True)) + train_parser.set_defaults(func=lambda args: run_training(args, can_exit=True)) train_arguments.set_train_core_arguments(train_core_parser) train_arguments.set_train_nlu_arguments(train_nlu_parser) -def train(args: argparse.Namespace, can_exit: bool = False) -> Optional[Text]: +def run_training(args: argparse.Namespace, can_exit: bool = False) -> Optional[Text]: """Trains a model. Args: @@ -87,7 +87,7 @@ def train(args: argparse.Namespace, can_exit: bool = False) -> Optional[Text]: for f in args.data ] - training_result = rasa.train.train( + training_result = train_all( domain=domain, config=config, training_files=training_files, @@ -148,7 +148,7 @@ def run_core_training( config = _get_valid_config(args.config, CONFIG_MANDATORY_KEYS_CORE) - return rasa.train.train_core( + return train_core( domain=args.domain, config=config, stories=story_file, @@ -190,7 +190,7 @@ def run_nlu_training( args.domain, "domain", DEFAULT_DOMAIN_PATH, none_is_valid=True ) - return rasa.train.train_nlu( + return train_nlu( config=config, nlu_data=nlu_data, output=output, diff --git a/tests/cli/test_rasa_interactive.py b/tests/cli/test_rasa_interactive.py index 08127755416c..da62b6d1cc09 100644 --- a/tests/cli/test_rasa_interactive.py +++ b/tests/cli/test_rasa_interactive.py @@ -66,10 +66,10 @@ def test_pass_arguments_to_rasa_train( # Mock actual training mock = Mock(return_value=TrainingResult(code=0)) - monkeypatch.setattr(train, "train", mock.method) + monkeypatch.setattr(train, "train_all", mock.method) # If the `Namespace` object does not have all required fields this will throw - train.train(args) + train.run_training(args) # Assert `train` was actually called mock.method.assert_called_once() From c413f21c07b5b31d97a9b56e673c9b3f011ba4e5 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Thu, 11 Mar 2021 09:20:05 +0100 Subject: [PATCH 34/42] fix docstring lint --- rasa/cli/train.py | 2 -- rasa/core/agent.py | 1 - rasa/core/channels/rasa_chat.py | 2 +- rasa/core/test.py | 2 +- rasa/nlu/classifiers/diet_classifier.py | 1 - rasa/nlu/model.py | 3 ++- rasa/shared/nlu/training_data/lookup_tables_parser.py | 4 +++- 7 files changed, 7 insertions(+), 8 deletions(-) diff --git a/rasa/cli/train.py b/rasa/cli/train.py index ebd049ad1dd1..81121d6dafae 100644 --- a/rasa/cli/train.py +++ b/rasa/cli/train.py @@ -129,7 +129,6 @@ def run_core_training( Returns: Path to a trained model or `None` if training was not successful. """ - output = train_path or args.out args.domain = rasa.cli.utils.get_validated_path( @@ -177,7 +176,6 @@ def run_nlu_training( Returns: Path to a trained model or `None` if training was not successful. """ - output = train_path or args.out config = _get_valid_config(args.config, CONFIG_MANDATORY_KEYS_NLU) diff --git a/rasa/core/agent.py b/rasa/core/agent.py index f0c30f961f50..fa3fe0d17e00 100644 --- a/rasa/core/agent.py +++ b/rasa/core/agent.py @@ -523,7 +523,6 @@ async def handle_message( **kwargs, ) -> Optional[List[Dict[Text, Any]]]: """Handle a single message.""" - if not self.is_ready(): logger.info("Ignoring message as there is no agent to handle it.") return None diff --git a/rasa/core/channels/rasa_chat.py b/rasa/core/channels/rasa_chat.py index 888aa822408b..5caeb000a23c 100644 --- a/rasa/core/channels/rasa_chat.py +++ b/rasa/core/channels/rasa_chat.py @@ -35,6 +35,7 @@ def from_credentials(cls, credentials: Optional[Dict[Text, Any]]) -> InputChanne return cls(credentials.get("url")) def __init__(self, url: Optional[Text]) -> None: + """Initialise the channel with attributes.""" self.base_url = url self.jwt_key: Optional[Text] = None self.jwt_algorithm = None @@ -98,7 +99,6 @@ async def _decode_bearer_token(self, bearer_token: Text) -> Optional[Dict]: async def _extract_sender(self, req: Request) -> Optional[Text]: """Fetch user from the Rasa X Admin API.""" - jwt_payload = None if req.headers.get("Authorization"): jwt_payload = await self._decode_bearer_token(req.headers["Authorization"]) diff --git a/rasa/core/test.py b/rasa/core/test.py index e32c3eeb954c..b5cf90f20c51 100644 --- a/rasa/core/test.py +++ b/rasa/core/test.py @@ -288,7 +288,7 @@ class WronglyClassifiedUserUtterance(UserUttered): type_name = "wrong_utterance" def __init__(self, event: UserUttered, eval_store: EvaluationStore) -> None: - + """Set `predicted_intent` and `predicted_entities` attributes.""" try: self.predicted_intent = eval_store.intent_predictions[0] except LookupError: diff --git a/rasa/nlu/classifiers/diet_classifier.py b/rasa/nlu/classifiers/diet_classifier.py index b04130b77d45..8c7c11f5bb88 100644 --- a/rasa/nlu/classifiers/diet_classifier.py +++ b/rasa/nlu/classifiers/diet_classifier.py @@ -881,7 +881,6 @@ def _predict_label( self, predict_out: Optional[Dict[Text, tf.Tensor]] ) -> Tuple[Dict[Text, Any], List[Dict[Text, Any]]]: """Predicts the intent of the provided message.""" - label: Dict[Text, Any] = {"name": None, "id": None, "confidence": 0.0} label_ranking = [] diff --git a/rasa/nlu/model.py b/rasa/nlu/model.py index b34f44301060..ff3ef8363de2 100644 --- a/rasa/nlu/model.py +++ b/rasa/nlu/model.py @@ -85,10 +85,11 @@ def load(model_dir: Text): ) def __init__(self, metadata: Dict[Text, Any]): - + """Set `metadata` attribute.""" self.metadata = metadata def get(self, property_name: Text, default: Any = None) -> Any: + """Proxy function to get property on `metadata` attribute.""" return self.metadata.get(property_name, default) @property diff --git a/rasa/shared/nlu/training_data/lookup_tables_parser.py b/rasa/shared/nlu/training_data/lookup_tables_parser.py index 9f2b6969c42f..f023dba0742e 100644 --- a/rasa/shared/nlu/training_data/lookup_tables_parser.py +++ b/rasa/shared/nlu/training_data/lookup_tables_parser.py @@ -6,7 +6,9 @@ def add_item_to_lookup_tables( item: Text, existing_lookup_tables: List[Dict[Text, Union[Text, List[Text]]]], ) -> None: - """Takes a list of lookup table dictionaries. Finds the one associated + """Add an item to a list of existing lookup tables. + + Takes a list of lookup table dictionaries. Finds the one associated with the current lookup, then adds the item to the list. Args: From 39859b4da0cb7222ebc4f85a7117313326796f93 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Fri, 12 Mar 2021 08:58:37 +0100 Subject: [PATCH 35/42] data in `rasa test` need to have same lenghts in predictions and targets (+better types) --- rasa/core/test.py | 56 +++++++++++++++++++++++------------------------ 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/rasa/core/test.py b/rasa/core/test.py index 6d7bc821f765..57117ddf844b 100644 --- a/rasa/core/test.py +++ b/rasa/core/test.py @@ -3,7 +3,7 @@ import warnings import typing from collections import defaultdict, namedtuple -from typing import Any, Dict, List, Optional, Text, Tuple, cast +from typing import Any, Dict, List, Optional, Text, Tuple from rasa import telemetry from rasa.core.policies.policy import PolicyPrediction @@ -74,6 +74,9 @@ ], ) +PredictionList = List[Optional[Text]] +EntityPredictionList = List[Dict[Text, Any]] + class WrongPredictionException(RasaException, ValueError): """Raised if a wrong prediction is encountered.""" @@ -84,12 +87,12 @@ class EvaluationStore: def __init__( self, - action_predictions: Optional[List[Text]] = None, - action_targets: Optional[List[Text]] = None, - intent_predictions: Optional[List[Text]] = None, - intent_targets: Optional[List[Text]] = None, - entity_predictions: Optional[List[Dict[Text, Any]]] = None, - entity_targets: Optional[List[Dict[Text, Any]]] = None, + action_predictions: Optional[PredictionList] = None, + action_targets: Optional[PredictionList] = None, + intent_predictions: Optional[PredictionList] = None, + intent_targets: Optional[PredictionList] = None, + entity_predictions: Optional[EntityPredictionList] = None, + entity_targets: Optional[EntityPredictionList] = None, ) -> None: self.action_predictions = action_predictions or [] self.action_targets = action_targets or [] @@ -100,12 +103,12 @@ def __init__( def add_to_store( self, - action_predictions: Optional[List[Text]] = None, - action_targets: Optional[List[Text]] = None, - intent_predictions: Optional[List[Text]] = None, - intent_targets: Optional[List[Text]] = None, - entity_predictions: Optional[List[Dict[Text, Any]]] = None, - entity_targets: Optional[List[Dict[Text, Any]]] = None, + action_predictions: Optional[PredictionList] = None, + action_targets: Optional[PredictionList] = None, + intent_predictions: Optional[PredictionList] = None, + intent_targets: Optional[PredictionList] = None, + entity_predictions: Optional[EntityPredictionList] = None, + entity_targets: Optional[EntityPredictionList] = None, ) -> None: """Add items or lists of items to the store""" @@ -136,8 +139,8 @@ def has_prediction_target_mismatch(self) -> bool: @staticmethod def _compare_entities( - entity_predictions: List[Dict[Text, Any]], - entity_targets: List[Dict[Text, Any]], + entity_predictions: EntityPredictionList, + entity_targets: EntityPredictionList, i_pred: int, i_target: int, ) -> int: @@ -175,7 +178,7 @@ def _compare_entities( def _generate_entity_training_data(entity: Dict[Text, Any]) -> Text: return TrainingDataWriter.generate_entity(entity.get("text"), entity) - def serialise(self) -> Tuple[List[Text], List[Text]]: + def serialise(self) -> Tuple[PredictionList, PredictionList]: """Turn targets and predictions to lists of equal size for sklearn.""" texts = sorted( list( @@ -305,10 +308,7 @@ def __init__(self, event: UserUttered, eval_store: EvaluationStore) -> None: self.predicted_entities = eval_store.entity_predictions - try: - intent: Dict[Text, Optional[Text]] = {"name": eval_store.intent_targets[0]} - except LookupError: - intent = {"name": None} + intent = {"name": eval_store.intent_targets[0]} super().__init__( event.text, @@ -364,7 +364,7 @@ async def _create_data_generator( def _clean_entity_results( text: Text, entity_results: List[Dict[Text, Any]] -) -> List[Dict[Text, Any]]: +) -> EntityPredictionList: """Extract only the token variables from an entity dict.""" cleaned_entities = [] @@ -447,10 +447,9 @@ def _collect_user_uttered_predictions( if intent_gold != predicted_base_intent: predicted_base_intent = _get_full_retrieval_intent(predicted) - if intent_gold: - user_uttered_eval_store.add_to_store(intent_targets=[intent_gold]) - if predicted_base_intent: - user_uttered_eval_store.add_to_store(intent_predictions=[predicted_base_intent]) + user_uttered_eval_store.add_to_store( + intent_targets=[intent_gold], intent_predictions=[predicted_base_intent] + ) entity_gold = event.entities predicted_entities = predicted.get(ENTITIES) @@ -533,8 +532,7 @@ def _collect_action_executed_predictions( gold_action_name = event.action_name gold_action_text = event.action_text - # FIXME: mypy doesn't pick up typing guard in `ActionExecuted.__init__` - gold = cast(Text, gold_action_name or gold_action_text) + gold = gold_action_name or gold_action_text policy_entity_result = None @@ -952,7 +950,9 @@ def _log_evaluation_table( def _plot_story_evaluation( - targets: List[Text], predictions: List[Text], output_directory: Optional[Text] + targets: PredictionList, + predictions: PredictionList, + output_directory: Optional[Text], ) -> None: """Plot a confusion matrix of story evaluation.""" from sklearn.metrics import confusion_matrix From 3db8c33d327a92613a4158cbc4502468e9a8b53f Mon Sep 17 00:00:00 2001 From: m-vdb Date: Fri, 12 Mar 2021 09:24:56 +0100 Subject: [PATCH 36/42] add_item_to_lookup_tables() raises if trying to add an item to unloaded file --- .../nlu/training_data/lookup_tables_parser.py | 6 +++++- .../test_lookup_tables_parser.py | 19 ++++++++++++++++--- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/rasa/shared/nlu/training_data/lookup_tables_parser.py b/rasa/shared/nlu/training_data/lookup_tables_parser.py index f023dba0742e..6fdf25f15588 100644 --- a/rasa/shared/nlu/training_data/lookup_tables_parser.py +++ b/rasa/shared/nlu/training_data/lookup_tables_parser.py @@ -15,6 +15,10 @@ def add_item_to_lookup_tables( title: Name of the lookup item. item: The lookup item. existing_lookup_tables: Existing lookup items that will be extended. + + Raises: + TypeError: in case we're trying to add a lookup table element to a file. + This is an internal error that is indicative of a parsing error. """ matches = [table for table in existing_lookup_tables if table["name"] == title] if not matches: @@ -22,5 +26,5 @@ def add_item_to_lookup_tables( else: elements = matches[0]["elements"] if not isinstance(elements, list): - elements = matches[0]["elements"] = [elements] + raise TypeError("Cannot add a lookup table element to an unloaded file.") elements.append(item) diff --git a/tests/shared/nlu/training_data/test_lookup_tables_parser.py b/tests/shared/nlu/training_data/test_lookup_tables_parser.py index 6bbc49bc2007..5d8cb8d3a213 100644 --- a/tests/shared/nlu/training_data/test_lookup_tables_parser.py +++ b/tests/shared/nlu/training_data/test_lookup_tables_parser.py @@ -1,3 +1,5 @@ +import pytest + import rasa.shared.nlu.training_data.lookup_tables_parser as lookup_tables_parser @@ -5,11 +7,22 @@ def test_add_item_to_lookup_tables(): lookup_item_title = "additional_currencies" lookup_examples = ["Peso", "Euro", "Dollar"] - result = [] + lookup_tables = [] for example in lookup_examples: lookup_tables_parser.add_item_to_lookup_tables( - lookup_item_title, example, result + lookup_item_title, example, lookup_tables ) - assert result == [{"name": lookup_item_title, "elements": lookup_examples}] + assert lookup_tables == [{"name": lookup_item_title, "elements": lookup_examples}] + + +def test_add_item_to_lookup_tables_unloaded_file(): + lookup_item_title = "additional_currencies" + + lookup_tables = [{"name": lookup_item_title, "elements": "lookup.txt"}] + + with pytest.raises(TypeError): + lookup_tables_parser.add_item_to_lookup_tables( + lookup_item_title, "Pound", lookup_tables + ) From 34784fce94086127c1d732f2cd91a4980a23d557 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Fri, 12 Mar 2021 09:25:04 +0100 Subject: [PATCH 37/42] fix docstring lint in rasa.core.test --- rasa/core/test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rasa/core/test.py b/rasa/core/test.py index 57117ddf844b..efc60300fd50 100644 --- a/rasa/core/test.py +++ b/rasa/core/test.py @@ -94,6 +94,7 @@ def __init__( entity_predictions: Optional[EntityPredictionList] = None, entity_targets: Optional[EntityPredictionList] = None, ) -> None: + """Initialize store attributes.""" self.action_predictions = action_predictions or [] self.action_targets = action_targets or [] self.intent_predictions = intent_predictions or [] @@ -110,8 +111,7 @@ def add_to_store( entity_predictions: Optional[EntityPredictionList] = None, entity_targets: Optional[EntityPredictionList] = None, ) -> None: - """Add items or lists of items to the store""" - + """Add items or lists of items to the store.""" self.action_predictions.extend(action_predictions or []) self.action_targets.extend(action_targets or []) self.intent_targets.extend(intent_targets or []) @@ -120,7 +120,7 @@ def add_to_store( self.entity_targets.extend(entity_targets or []) def merge_store(self, other: "EvaluationStore") -> None: - """Add the contents of other to self""" + """Add the contents of other to self.""" self.add_to_store( action_predictions=other.action_predictions, action_targets=other.action_targets, From f89b953866d0c85a1228fca6f7852329df5cf522 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Fri, 12 Mar 2021 14:00:42 +0100 Subject: [PATCH 38/42] fix wrong function call in test_forms.py --- tests/core/actions/test_forms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/core/actions/test_forms.py b/tests/core/actions/test_forms.py index d59914318263..2efccbbf45ed 100644 --- a/tests/core/actions/test_forms.py +++ b/tests/core/actions/test_forms.py @@ -859,7 +859,7 @@ async def test_trigger_slot_mapping_applies( ], ) - slot_values = form.extract_other_slots(tracker, domain, "some_slot") + slot_values = form.extract_other_slots(tracker, domain) assert slot_values == {slot_filled_by_trigger_mapping: expected_value} @@ -906,7 +906,7 @@ async def test_trigger_slot_mapping_does_not_apply(trigger_slot_mapping: Dict): ], ) - slot_values = form.extract_other_slots(tracker, domain, "some_slot") + slot_values = form.extract_other_slots(tracker, domain) assert slot_values == {} @@ -1246,7 +1246,7 @@ def test_extract_other_slots_with_entity( ], ) - slot_values = form.extract_other_slots(tracker, domain, "some_slot") + slot_values = form.extract_other_slots(tracker, domain) # check that the value was extracted for non requested slot assert slot_values == expected_slot_values From 277b8a39caee4c8937f57b6e502ef87bd0bafb1c Mon Sep 17 00:00:00 2001 From: Maxime Vdb Date: Mon, 15 Mar 2021 17:40:29 +0100 Subject: [PATCH 39/42] Update type definition of TrackerActiveLoop Co-authored-by: Tobias Wochinger --- rasa/shared/core/trackers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/shared/core/trackers.py b/rasa/shared/core/trackers.py index 51ff0379b25c..0edaa2aef250 100644 --- a/rasa/shared/core/trackers.py +++ b/rasa/shared/core/trackers.py @@ -73,7 +73,7 @@ TrackerActiveLoop = TypedDict( "TrackerActiveLoop", { - LOOP_NAME: Text, + LOOP_NAME: Optional[Text], LOOP_INTERRUPTED: bool, LOOP_REJECTED: bool, TRIGGER_MESSAGE: Dict, From f124d6282c67d9fec1b34467e813cb7cf56c4688 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Mon, 15 Mar 2021 18:22:42 +0100 Subject: [PATCH 40/42] keep local imports --- rasa/cli/train.py | 7 ++++++- tests/cli/test_rasa_interactive.py | 3 ++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/rasa/cli/train.py b/rasa/cli/train.py index 81121d6dafae..7f8cb4f8f2b8 100644 --- a/rasa/cli/train.py +++ b/rasa/cli/train.py @@ -18,7 +18,6 @@ DEFAULT_DOMAIN_PATH, DEFAULT_DATA_PATH, ) -from rasa.train import train as train_all, train_core, train_nlu def add_subparser( @@ -74,6 +73,8 @@ def run_training(args: argparse.Namespace, can_exit: bool = False) -> Optional[T Returns: Path to a trained model or `None` if training was not successful. """ + from rasa import train as train_all + domain = rasa.cli.utils.get_validated_path( args.domain, "domain", DEFAULT_DOMAIN_PATH, none_is_valid=True ) @@ -129,6 +130,8 @@ def run_core_training( Returns: Path to a trained model or `None` if training was not successful. """ + from rasa.train import train_core + output = train_path or args.out args.domain = rasa.cli.utils.get_validated_path( @@ -176,6 +179,8 @@ def run_nlu_training( Returns: Path to a trained model or `None` if training was not successful. """ + from rasa.train import train_nlu + output = train_path or args.out config = _get_valid_config(args.config, CONFIG_MANDATORY_KEYS_NLU) diff --git a/tests/cli/test_rasa_interactive.py b/tests/cli/test_rasa_interactive.py index da62b6d1cc09..b49d39c130c0 100644 --- a/tests/cli/test_rasa_interactive.py +++ b/tests/cli/test_rasa_interactive.py @@ -6,6 +6,7 @@ from _pytest.monkeypatch import MonkeyPatch from _pytest.pytester import RunResult +import rasa from rasa.core.train import do_interactive_learning from rasa.core.training import interactive as interactive_learning from rasa.cli import interactive, train @@ -66,7 +67,7 @@ def test_pass_arguments_to_rasa_train( # Mock actual training mock = Mock(return_value=TrainingResult(code=0)) - monkeypatch.setattr(train, "train_all", mock.method) + monkeypatch.setattr(rasa, "train", mock.method) # If the `Namespace` object does not have all required fields this will throw train.run_training(args) From f9ac2de4fadab13e60ecd1604f037675c1d91033 Mon Sep 17 00:00:00 2001 From: m-vdb Date: Mon, 15 Mar 2021 18:29:47 +0100 Subject: [PATCH 41/42] raise new ChannelConfigError when rasa_chat.py channel is not configured correctly --- rasa/core/channels/rasa_chat.py | 3 ++- rasa/core/exceptions.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/rasa/core/channels/rasa_chat.py b/rasa/core/channels/rasa_chat.py index 5caeb000a23c..3ca5a4444241 100644 --- a/rasa/core/channels/rasa_chat.py +++ b/rasa/core/channels/rasa_chat.py @@ -11,6 +11,7 @@ from rasa.core.channels.channel import InputChannel from rasa.core.channels.rest import RestInput from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT +from rasa.core.exceptions import ChannelConfigError from sanic.request import Request logger = logging.getLogger(__name__) @@ -74,7 +75,7 @@ async def _fetch_public_key(self) -> None: def _decode_jwt(self, bearer_token: Text) -> Dict: if self.jwt_key is None: - raise TypeError( + raise ChannelConfigError( "JWT public key is `None`. This is likely caused " "by an error when retrieving the public key from Rasa X." ) diff --git a/rasa/core/exceptions.py b/rasa/core/exceptions.py index 57b404a39c3a..f6e780ad52dc 100644 --- a/rasa/core/exceptions.py +++ b/rasa/core/exceptions.py @@ -29,3 +29,7 @@ class AgentNotReady(RasaCoreException): def __init__(self, message: Text) -> None: self.message = message super(AgentNotReady, self).__init__() + + +class ChannelConfigError(RasaCoreException): + """Raised if a channel is not configured correctly.""" From fa94e460b4946bca1797a04b89e1165c1893515c Mon Sep 17 00:00:00 2001 From: m-vdb Date: Tue, 16 Mar 2021 10:43:13 +0100 Subject: [PATCH 42/42] fix missing docstrings after merges --- rasa/core/exceptions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rasa/core/exceptions.py b/rasa/core/exceptions.py index f6e780ad52dc..6511235f1b0a 100644 --- a/rasa/core/exceptions.py +++ b/rasa/core/exceptions.py @@ -11,6 +11,7 @@ class UnsupportedDialogueModelError(RasaCoreException): """ def __init__(self, message: Text, model_version: Optional[Text] = None) -> None: + """Initialize message and model_version attributes.""" self.message = message self.model_version = model_version super(UnsupportedDialogueModelError, self).__init__() @@ -27,6 +28,7 @@ class AgentNotReady(RasaCoreException): will be thrown.""" def __init__(self, message: Text) -> None: + """Initialize message attribute.""" self.message = message super(AgentNotReady, self).__init__()