Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to mypy part1 #6511

Merged
merged 53 commits into from
Mar 16, 2021
Merged
Show file tree
Hide file tree
Changes from 51 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
1699906
cleanup type error: [dict-item]
m-vdb Aug 31, 2020
0da7e77
cleanup type error: [list-item]
m-vdb Aug 31, 2020
a797f7b
cleanup type error: [call-arg]
m-vdb Aug 31, 2020
4e0fa91
fix import errors
m-vdb Aug 31, 2020
1981141
fix tests after import/signature changes
m-vdb Sep 1, 2020
b4657ad
cleanup some type errors: [index]
m-vdb Sep 1, 2020
2d5d621
clear errors in rasa.nlu.model
m-vdb Sep 2, 2020
cdf7a89
Merge branch 'switch-to-mypy' into switch-to-mypy-part1
m-vdb Oct 12, 2020
31f728c
fix import issues in rasa.cli.train
m-vdb Oct 12, 2020
54b8b11
disable some error codes
m-vdb Oct 12, 2020
c3d98df
Merge branch 'switch-to-mypy' into switch-to-mypy-part1
m-vdb Oct 15, 2020
c55264f
update extract_requested_slot() docstring
m-vdb Oct 15, 2020
04b3ea4
more precise exception message
m-vdb Oct 15, 2020
39a4d79
use HTTPStatus constants in rasa.server
m-vdb Oct 15, 2020
33834e9
/model/test/intents returns the right status code in case of a conflict
m-vdb Oct 15, 2020
22744f1
Merge branch 'switch-to-mypy' into switch-to-mypy-part1
m-vdb Oct 16, 2020
f8a817a
rename methods to avoid obfuscation
m-vdb Oct 16, 2020
5eae1b4
add type hint for tracker.latest_message
m-vdb Oct 16, 2020
8a3ab2f
Merge branch 'master' into switch-to-mypy-part1
m-vdb Oct 19, 2020
2c67668
add precise type definition for DialogueStateTracker.active_loop
m-vdb Oct 19, 2020
c7d33ed
Merge branch 'master' into switch-to-mypy-part1
m-vdb Nov 11, 2020
1b9e44c
add FIXMEs for type issues for now
m-vdb Nov 11, 2020
cfe194f
fix black formatting
m-vdb Nov 11, 2020
ed9f7d3
fix TypedDict import
m-vdb Nov 11, 2020
cacfa0e
Merge branch 'master' into switch-to-mypy-part1
m-vdb Nov 11, 2020
297c1dc
simpler type annotation for SanicView
m-vdb Nov 18, 2020
89b5c6a
Better error message when raising TypeError for missing JWT
m-vdb Nov 18, 2020
f3c2bfb
raise exception when failing to decode JWT key in Rasa chat
m-vdb Nov 18, 2020
eb897f6
change type of Message.time to int
m-vdb Nov 18, 2020
0226600
remove useless import
m-vdb Nov 19, 2020
210d3a9
more precise type for DialogStateTracker.latest_message
m-vdb Nov 19, 2020
ab748f6
explicit cast() in StoryStep._or_string()
m-vdb Nov 19, 2020
80a5826
Merge branch 'main' into switch-to-mypy-part1
m-vdb Mar 10, 2021
de5b428
fix merge issue in pyproject.toml
m-vdb Mar 10, 2021
3a78322
raise TypeError if ActionRevertFallbackEvents cannot revert to last e…
m-vdb Mar 10, 2021
6176569
raise TypeError if TwoStageFallbackAction cannot issue clarification
m-vdb Mar 10, 2021
768068b
Merge branch 'main' into switch-to-mypy-part1
m-vdb Mar 11, 2021
0116796
use consistent error key in /model/test/intents
m-vdb Mar 11, 2021
33ca6b9
fix remaining type issues
m-vdb Mar 11, 2021
2b6aed3
bump mypy
m-vdb Mar 11, 2021
a2eb224
fix `rasa train` command and tests
m-vdb Mar 11, 2021
c413f21
fix docstring lint
m-vdb Mar 11, 2021
cbedef1
Merge branch 'main' into switch-to-mypy-part1
m-vdb Mar 12, 2021
39859b4
data in `rasa test` need to have same lenghts in predictions and targ…
m-vdb Mar 12, 2021
3db8c33
add_item_to_lookup_tables() raises if trying to add an item to unload…
m-vdb Mar 12, 2021
34784fc
fix docstring lint in rasa.core.test
m-vdb Mar 12, 2021
f89b953
fix wrong function call in test_forms.py
m-vdb Mar 12, 2021
277b8a3
Update type definition of TrackerActiveLoop
m-vdb Mar 15, 2021
b9705f4
Merge branch 'main' into switch-to-mypy-part1
m-vdb Mar 15, 2021
f124d62
keep local imports
m-vdb Mar 15, 2021
f9ac2de
raise new ChannelConfigError when rasa_chat.py channel is not configu…
m-vdb Mar 15, 2021
878547e
Merge branch 'main' into switch-to-mypy-part1
m-vdb Mar 16, 2021
fa94e46
fix missing docstrings after merges
m-vdb Mar 16, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ types:
# FIXME: working our way towards removing these
# see https://github.com/RasaHQ/rasa/pull/6470
# the list below is sorted by the number of errors for each error code, in decreasing order
poetry run mypy rasa --disable-error-code arg-type \
MYPYPATH=./stubs poetry run mypy rasa --disable-error-code arg-type \
--disable-error-code assignment \
--disable-error-code var-annotated \
--disable-error-code return-value \
Expand All @@ -116,12 +116,9 @@ types:
--disable-error-code index \
--disable-error-code misc \
--disable-error-code return \
--disable-error-code call-arg \
--disable-error-code type-var \
--disable-error-code list-item \
--disable-error-code has-type \
--disable-error-code valid-type \
--disable-error-code dict-item \
--disable-error-code no-redef \
--disable-error-code func-returns-value

Expand Down
3 changes: 3 additions & 0 deletions changelog/6511.misc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
API endpoint `POST /model/test/intents` now returns HTTP 409 status
code in case it cannot find the NLU model directory, instead of an
HTTP 500 status.
81 changes: 65 additions & 16 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,9 @@ toml = "^0.10.0"
pep440-version-utils = "^0.3.0"
pydoc-markdown = "^3.5.0"
pytest-timeout = "^1.4.2"
mypy = "^0.790"
mypy = "^0.812"
bandit = "^1.6.3"
typing-extensions = "^3.7.4"
wochinge marked this conversation as resolved.
Show resolved Hide resolved

[tool.poetry.extras]
spacy = [ "spacy",]
Expand Down
6 changes: 5 additions & 1 deletion rasa/cli/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ def interactive(args: argparse.Namespace) -> None:
"data or a model containing core data."
)

zipped_model = train.train_core(args) if args.core_only else train.train(args)
zipped_model = (
train.run_core_training(args)
if args.core_only
else train.run_training(args)
)
if not zipped_model:
rasa.shared.utils.cli.print_error_and_exit(
"Could not train an initial model. Either pass paths "
Expand Down
21 changes: 10 additions & 11 deletions rasa/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import rasa.cli.arguments.train as train_arguments

import rasa.cli.utils
import rasa.utils.common
from rasa.core.train import do_compare_training
from rasa.shared.utils.cli import print_error
from rasa.shared.constants import (
CONFIG_MANDATORY_KEYS_CORE,
Expand All @@ -16,7 +18,6 @@
DEFAULT_DOMAIN_PATH,
DEFAULT_DATA_PATH,
)
import rasa.utils.common


def add_subparser(
Expand Down Expand Up @@ -45,23 +46,23 @@ def add_subparser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
help="Trains a Rasa Core model using your stories.",
)
train_core_parser.set_defaults(func=train_core)
train_core_parser.set_defaults(func=run_core_training)

train_nlu_parser = train_subparsers.add_parser(
"nlu",
parents=parents,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
help="Trains a Rasa NLU model using your NLU data.",
)
train_nlu_parser.set_defaults(func=train_nlu)
train_nlu_parser.set_defaults(func=run_nlu_training)

train_parser.set_defaults(func=lambda args: train(args, can_exit=True))
train_parser.set_defaults(func=lambda args: run_training(args, can_exit=True))

train_arguments.set_train_core_arguments(train_core_parser)
train_arguments.set_train_nlu_arguments(train_nlu_parser)


def train(args: argparse.Namespace, can_exit: bool = False) -> Optional[Text]:
def run_training(args: argparse.Namespace, can_exit: bool = False) -> Optional[Text]:
"""Trains a model.

Args:
Expand All @@ -72,7 +73,7 @@ def train(args: argparse.Namespace, can_exit: bool = False) -> Optional[Text]:
Returns:
Path to a trained model or `None` if training was not successful.
"""
import rasa
from rasa import train as train_all

domain = rasa.cli.utils.get_validated_path(
args.domain, "domain", DEFAULT_DOMAIN_PATH, none_is_valid=True
Expand All @@ -87,7 +88,7 @@ def train(args: argparse.Namespace, can_exit: bool = False) -> Optional[Text]:
for f in args.data
]

training_result = rasa.train(
training_result = train_all(
domain=domain,
config=config,
training_files=training_files,
Expand Down Expand Up @@ -117,7 +118,7 @@ def _model_for_finetuning(args: argparse.Namespace) -> Optional[Text]:
return args.finetune


def train_core(
def run_core_training(
args: argparse.Namespace, train_path: Optional[Text] = None
) -> Optional[Text]:
"""Trains a Rasa Core model only.
Expand Down Expand Up @@ -161,14 +162,12 @@ def train_core(
finetuning_epoch_fraction=args.epoch_fraction,
)
else:
from rasa.core.train import do_compare_training

rasa.utils.common.run_in_loop(
do_compare_training(args, story_file, additional_arguments)
)


def train_nlu(
def run_nlu_training(
args: argparse.Namespace, train_path: Optional[Text] = None
) -> Optional[Text]:
"""Trains an NLU model.
Expand Down
14 changes: 10 additions & 4 deletions rasa/core/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,17 +778,20 @@ def has_user_affirmed(tracker: "DialogueStateTracker") -> bool:
def _revert_affirmation_events(tracker: "DialogueStateTracker") -> List[Event]:
revert_events = _revert_single_affirmation_events()

last_user_event = tracker.get_last_event_for(UserUttered)
last_user_event = copy.deepcopy(last_user_event)
last_user_event.parse_data["intent"]["confidence"] = 1.0

# User affirms the rephrased intent
rephrased_intent = tracker.last_executed_action_has(
name=ACTION_DEFAULT_ASK_REPHRASE_NAME, skip=1
)
if rephrased_intent:
revert_events += _revert_rephrasing_events()

last_user_event = tracker.get_last_event_for(UserUttered)
if not last_user_event:
raise TypeError("Cannot find last event to revert to.")

last_user_event = copy.deepcopy(last_user_event)
last_user_event.parse_data["intent"]["confidence"] = 1.0

return revert_events + [last_user_event]


Expand All @@ -804,6 +807,9 @@ def _revert_single_affirmation_events() -> List[Event]:

def _revert_successful_rephrasing(tracker) -> List[Event]:
last_user_event = tracker.get_last_event_for(UserUttered)
if not last_user_event:
raise TypeError("Cannot find last event to revert to.")

last_user_event = copy.deepcopy(last_user_event)
return _revert_rephrasing_events() + [last_user_event]

Expand Down
20 changes: 16 additions & 4 deletions rasa/core/actions/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,10 +333,20 @@ def get_slot_to_fill(self, tracker: "DialogueStateTracker") -> Optional[str]:
)

def extract_requested_slot(
self, tracker: "DialogueStateTracker", domain: Domain
self, tracker: "DialogueStateTracker", domain: Domain, slot_to_fill: Text,
m-vdb marked this conversation as resolved.
Show resolved Hide resolved
) -> Dict[Text, Any]:
"""Extracts the value of requested slot from a user input else return `None`."""
slot_to_fill = self.get_slot_to_fill(tracker)
"""Extract the value of requested slot from a user input else return `None`.

Args:
tracker: a DialogueStateTracker instance
domain: the current domain
slot_to_fill: the name of the slot to fill

Returns:
a dictionary with one key being the name of the slot to fill
and its value being the slot value, or an empty dictionary
if no slot value was found.
"""
logger.debug(f"Trying to extract requested slot '{slot_to_fill}' ...")

# get mapping for requested slot
Expand Down Expand Up @@ -477,7 +487,9 @@ async def validate(
# extract requested slot
slot_to_fill = self.get_slot_to_fill(tracker)
if slot_to_fill:
slot_values.update(self.extract_requested_slot(tracker, domain))
slot_values.update(
self.extract_requested_slot(tracker, domain, slot_to_fill)
)

validation_events = await self.validate_slots(
slot_values, tracker, domain, output_channel, nlg
Expand Down
9 changes: 8 additions & 1 deletion rasa/core/actions/two_stage_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,14 @@ def _second_affirmation_failed(tracker: DialogueStateTracker) -> bool:


def _message_clarification(tracker: DialogueStateTracker) -> List[Event]:
clarification = copy.deepcopy(tracker.latest_message)
latest_message = tracker.latest_message
if not latest_message:
raise TypeError(
"Cannot issue message clarification because "
"latest message is not on tracker."
)

clarification = copy.deepcopy(latest_message)
clarification.parse_data["intent"]["confidence"] = 1.0
clarification.timestamp = time.time()
return [ActionExecuted(ACTION_LISTEN_NAME), clarification]
7 changes: 2 additions & 5 deletions rasa/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,12 +523,9 @@ async def handle_message(
**kwargs,
) -> Optional[List[Dict[Text, Any]]]:
"""Handle a single message."""

def noop(_: Any) -> None:
logger.info("Ignoring message as there is no agent to handle it.")

if not self.is_ready():
return noop(message)
logger.info("Ignoring message as there is no agent to handle it.")
return None

processor = self.create_processor(message_preprocessor)

Expand Down
2 changes: 1 addition & 1 deletion rasa/core/channels/hangouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ async def _persist_message(self, message: Dict) -> None:
new_messages = self._combine_cards(text_card, message)

elif msg_new == "text":
new_messages = {"text": message.get("text")}
new_messages = {"text": message["text"]}
else:
new_messages = message

Expand Down
17 changes: 11 additions & 6 deletions rasa/core/channels/rasa_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import logging
from sanic.exceptions import abort
import jwt
import jwt.exceptions

from rasa.core import constants
from rasa.core.channels.channel import InputChannel
from rasa.core.channels.rest import RestInput
from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
from rasa.core.exceptions import ChannelConfigError
from sanic.request import Request

logger = logging.getLogger(__name__)
Expand All @@ -34,8 +36,9 @@ def from_credentials(cls, credentials: Optional[Dict[Text, Any]]) -> InputChanne
return cls(credentials.get("url"))

def __init__(self, url: Optional[Text]) -> None:
"""Initialise the channel with attributes."""
self.base_url = url
self.jwt_key = None
self.jwt_key: Optional[Text] = None
self.jwt_algorithm = None

async def _fetch_public_key(self) -> None:
Expand Down Expand Up @@ -71,6 +74,12 @@ async def _fetch_public_key(self) -> None:
)

def _decode_jwt(self, bearer_token: Text) -> Dict:
if self.jwt_key is None:
raise ChannelConfigError(
"JWT public key is `None`. This is likely caused "
"by an error when retrieving the public key from Rasa X."
)

authorization_header_value = bearer_token.replace(
constants.BEARER_TOKEN_PREFIX, ""
)
Expand All @@ -82,19 +91,15 @@ async def _decode_bearer_token(self, bearer_token: Text) -> Optional[Dict]:
if self.jwt_key is None:
await self._fetch_public_key()

# noinspection PyBroadException
try:
return self._decode_jwt(bearer_token)
except jwt.exceptions.InvalidSignatureError:
except jwt.InvalidSignatureError:
logger.error("JWT public key invalid, fetching new one.")
await self._fetch_public_key()
return self._decode_jwt(bearer_token)
except Exception:
logger.exception("Failed to decode bearer token.")

async def _extract_sender(self, req: Request) -> Optional[Text]:
"""Fetch user from the Rasa X Admin API."""

jwt_payload = None
if req.headers.get("Authorization"):
jwt_payload = await self._decode_bearer_token(req.headers["Authorization"])
Expand Down
4 changes: 4 additions & 0 deletions rasa/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,7 @@ class AgentNotReady(RasaCoreException):
def __init__(self, message: Text) -> None:
self.message = message
super(AgentNotReady, self).__init__()


class ChannelConfigError(RasaCoreException):
"""Raised if a channel is not configured correctly."""
7 changes: 5 additions & 2 deletions rasa/core/featurizers/single_state_featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,15 @@ def convert_to_dict(feature_states: List[Text]) -> Dict[Text, int]:
def _state_features_for_attribute(
self, sub_state: SubState, attribute: Text
) -> Dict[Text, int]:
# FIXME: the code below is not type-safe, but fixing it
# would require more refactoring, for instance using
wochinge marked this conversation as resolved.
Show resolved Hide resolved
# data classes in our states
if attribute in {INTENT, ACTION_NAME}:
return {sub_state[attribute]: 1}
return {sub_state[attribute]: 1} # type: ignore[dict-item]
elif attribute == ENTITIES:
return {entity: 1 for entity in sub_state.get(ENTITIES, [])}
elif attribute == ACTIVE_LOOP:
return {sub_state["name"]: 1}
return {sub_state["name"]: 1} # type: ignore[dict-item]
elif attribute == SLOTS:
return {
f"{slot_name}_{i}": value
Expand Down
3 changes: 1 addition & 2 deletions rasa/core/policies/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,7 @@ def load(cls, path: Union[Text, Path], **kwargs: Any) -> "Policy":
)
return cls()

@staticmethod
def _default_predictions(domain: Domain) -> List[float]:
def _default_predictions(self, domain: Domain) -> List[float]:
"""Creates a list of zeros.

Args:
Expand Down
8 changes: 4 additions & 4 deletions rasa/core/policies/rule_policy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import List, Dict, Text, Optional, Any, Set, TYPE_CHECKING, Tuple
from typing import Any, List, Dict, Text, Optional, Set, Tuple, TYPE_CHECKING, Union

from tqdm import tqdm
import numpy as np
Expand Down Expand Up @@ -582,9 +582,7 @@ def train(

# only consider original trackers (no augmented ones)
training_trackers = [
t
for t in training_trackers
if not hasattr(t, "is_augmented") or not t.is_augmented
t for t in training_trackers if not getattr(t, "is_augmented", False)
]
# only use trackers from rule-based training data
rule_trackers = [t for t in training_trackers if t.is_rule_tracker]
Expand Down Expand Up @@ -772,6 +770,8 @@ def _find_action_from_loop_happy_path(
)
return ACTION_LISTEN_NAME

return None

def _find_action_from_rules(
self,
tracker: DialogueStateTracker,
Expand Down
Loading