Skip to content

Commit

Permalink
Merge pull request #6511 from RasaHQ/switch-to-mypy-part1
Browse files Browse the repository at this point in the history
Switch to mypy part1
  • Loading branch information
Maxime Vdb authored Mar 16, 2021
2 parents bfa97e1 + fa94e46 commit f7490fe
Show file tree
Hide file tree
Showing 38 changed files with 339 additions and 167 deletions.
5 changes: 1 addition & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ types:
# FIXME: working our way towards removing these
# see https://github.com/RasaHQ/rasa/pull/6470
# the list below is sorted by the number of errors for each error code, in decreasing order
poetry run mypy rasa --disable-error-code arg-type \
MYPYPATH=./stubs poetry run mypy rasa --disable-error-code arg-type \
--disable-error-code assignment \
--disable-error-code var-annotated \
--disable-error-code return-value \
Expand All @@ -116,12 +116,9 @@ types:
--disable-error-code index \
--disable-error-code misc \
--disable-error-code return \
--disable-error-code call-arg \
--disable-error-code type-var \
--disable-error-code list-item \
--disable-error-code has-type \
--disable-error-code valid-type \
--disable-error-code dict-item \
--disable-error-code no-redef \
--disable-error-code func-returns-value

Expand Down
3 changes: 3 additions & 0 deletions changelog/6511.misc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
API endpoint `POST /model/test/intents` now returns HTTP 409 status
code in case it cannot find the NLU model directory, instead of an
HTTP 500 status.
81 changes: 65 additions & 16 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,9 @@ toml = "^0.10.0"
pep440-version-utils = "^0.3.0"
pydoc-markdown = "^3.5.0"
pytest-timeout = "^1.4.2"
mypy = "^0.790"
mypy = "^0.812"
bandit = "^1.6.3"
typing-extensions = "^3.7.4"

[tool.poetry.extras]
spacy = [ "spacy",]
Expand Down
6 changes: 5 additions & 1 deletion rasa/cli/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ def interactive(args: argparse.Namespace) -> None:
"data or a model containing core data."
)

zipped_model = train.train_core(args) if args.core_only else train.train(args)
zipped_model = (
train.run_core_training(args)
if args.core_only
else train.run_training(args)
)
if not zipped_model:
rasa.shared.utils.cli.print_error_and_exit(
"Could not train an initial model. Either pass paths "
Expand Down
21 changes: 10 additions & 11 deletions rasa/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import rasa.cli.arguments.train as train_arguments

import rasa.cli.utils
import rasa.utils.common
from rasa.core.train import do_compare_training
from rasa.shared.utils.cli import print_error
from rasa.shared.constants import (
CONFIG_MANDATORY_KEYS_CORE,
Expand All @@ -16,7 +18,6 @@
DEFAULT_DOMAIN_PATH,
DEFAULT_DATA_PATH,
)
import rasa.utils.common


def add_subparser(
Expand Down Expand Up @@ -45,23 +46,23 @@ def add_subparser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
help="Trains a Rasa Core model using your stories.",
)
train_core_parser.set_defaults(func=train_core)
train_core_parser.set_defaults(func=run_core_training)

train_nlu_parser = train_subparsers.add_parser(
"nlu",
parents=parents,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
help="Trains a Rasa NLU model using your NLU data.",
)
train_nlu_parser.set_defaults(func=train_nlu)
train_nlu_parser.set_defaults(func=run_nlu_training)

train_parser.set_defaults(func=lambda args: train(args, can_exit=True))
train_parser.set_defaults(func=lambda args: run_training(args, can_exit=True))

train_arguments.set_train_core_arguments(train_core_parser)
train_arguments.set_train_nlu_arguments(train_nlu_parser)


def train(args: argparse.Namespace, can_exit: bool = False) -> Optional[Text]:
def run_training(args: argparse.Namespace, can_exit: bool = False) -> Optional[Text]:
"""Trains a model.
Args:
Expand All @@ -72,7 +73,7 @@ def train(args: argparse.Namespace, can_exit: bool = False) -> Optional[Text]:
Returns:
Path to a trained model or `None` if training was not successful.
"""
import rasa
from rasa import train as train_all

domain = rasa.cli.utils.get_validated_path(
args.domain, "domain", DEFAULT_DOMAIN_PATH, none_is_valid=True
Expand All @@ -87,7 +88,7 @@ def train(args: argparse.Namespace, can_exit: bool = False) -> Optional[Text]:
for f in args.data
]

training_result = rasa.train(
training_result = train_all(
domain=domain,
config=config,
training_files=training_files,
Expand Down Expand Up @@ -117,7 +118,7 @@ def _model_for_finetuning(args: argparse.Namespace) -> Optional[Text]:
return args.finetune


def train_core(
def run_core_training(
args: argparse.Namespace, train_path: Optional[Text] = None
) -> Optional[Text]:
"""Trains a Rasa Core model only.
Expand Down Expand Up @@ -161,14 +162,12 @@ def train_core(
finetuning_epoch_fraction=args.epoch_fraction,
)
else:
from rasa.core.train import do_compare_training

rasa.utils.common.run_in_loop(
do_compare_training(args, story_file, additional_arguments)
)


def train_nlu(
def run_nlu_training(
args: argparse.Namespace, train_path: Optional[Text] = None
) -> Optional[Text]:
"""Trains an NLU model.
Expand Down
14 changes: 10 additions & 4 deletions rasa/core/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,17 +778,20 @@ def has_user_affirmed(tracker: "DialogueStateTracker") -> bool:
def _revert_affirmation_events(tracker: "DialogueStateTracker") -> List[Event]:
revert_events = _revert_single_affirmation_events()

last_user_event = tracker.get_last_event_for(UserUttered)
last_user_event = copy.deepcopy(last_user_event)
last_user_event.parse_data["intent"]["confidence"] = 1.0

# User affirms the rephrased intent
rephrased_intent = tracker.last_executed_action_has(
name=ACTION_DEFAULT_ASK_REPHRASE_NAME, skip=1
)
if rephrased_intent:
revert_events += _revert_rephrasing_events()

last_user_event = tracker.get_last_event_for(UserUttered)
if not last_user_event:
raise TypeError("Cannot find last event to revert to.")

last_user_event = copy.deepcopy(last_user_event)
last_user_event.parse_data["intent"]["confidence"] = 1.0

return revert_events + [last_user_event]


Expand All @@ -804,6 +807,9 @@ def _revert_single_affirmation_events() -> List[Event]:

def _revert_successful_rephrasing(tracker) -> List[Event]:
last_user_event = tracker.get_last_event_for(UserUttered)
if not last_user_event:
raise TypeError("Cannot find last event to revert to.")

last_user_event = copy.deepcopy(last_user_event)
return _revert_rephrasing_events() + [last_user_event]

Expand Down
20 changes: 16 additions & 4 deletions rasa/core/actions/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,10 +333,20 @@ def get_slot_to_fill(self, tracker: "DialogueStateTracker") -> Optional[str]:
)

def extract_requested_slot(
self, tracker: "DialogueStateTracker", domain: Domain
self, tracker: "DialogueStateTracker", domain: Domain, slot_to_fill: Text,
) -> Dict[Text, Any]:
"""Extracts the value of requested slot from a user input else return `None`."""
slot_to_fill = self.get_slot_to_fill(tracker)
"""Extract the value of requested slot from a user input else return `None`.
Args:
tracker: a DialogueStateTracker instance
domain: the current domain
slot_to_fill: the name of the slot to fill
Returns:
a dictionary with one key being the name of the slot to fill
and its value being the slot value, or an empty dictionary
if no slot value was found.
"""
logger.debug(f"Trying to extract requested slot '{slot_to_fill}' ...")

# get mapping for requested slot
Expand Down Expand Up @@ -477,7 +487,9 @@ async def validate(
# extract requested slot
slot_to_fill = self.get_slot_to_fill(tracker)
if slot_to_fill:
slot_values.update(self.extract_requested_slot(tracker, domain))
slot_values.update(
self.extract_requested_slot(tracker, domain, slot_to_fill)
)

validation_events = await self.validate_slots(
slot_values, tracker, domain, output_channel, nlg
Expand Down
9 changes: 8 additions & 1 deletion rasa/core/actions/two_stage_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,14 @@ def _second_affirmation_failed(tracker: DialogueStateTracker) -> bool:


def _message_clarification(tracker: DialogueStateTracker) -> List[Event]:
clarification = copy.deepcopy(tracker.latest_message)
latest_message = tracker.latest_message
if not latest_message:
raise TypeError(
"Cannot issue message clarification because "
"latest message is not on tracker."
)

clarification = copy.deepcopy(latest_message)
clarification.parse_data["intent"]["confidence"] = 1.0
clarification.timestamp = time.time()
return [ActionExecuted(ACTION_LISTEN_NAME), clarification]
7 changes: 2 additions & 5 deletions rasa/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,12 +523,9 @@ async def handle_message(
**kwargs,
) -> Optional[List[Dict[Text, Any]]]:
"""Handle a single message."""

def noop(_: Any) -> None:
logger.info("Ignoring message as there is no agent to handle it.")

if not self.is_ready():
return noop(message)
logger.info("Ignoring message as there is no agent to handle it.")
return None

processor = self.create_processor(message_preprocessor)

Expand Down
2 changes: 1 addition & 1 deletion rasa/core/channels/hangouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ async def _persist_message(self, message: Dict) -> None:
new_messages = self._combine_cards(text_card, message)

elif msg_new == "text":
new_messages = {"text": message.get("text")}
new_messages = {"text": message["text"]}
else:
new_messages = message

Expand Down
17 changes: 11 additions & 6 deletions rasa/core/channels/rasa_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import logging
from sanic.exceptions import abort
import jwt
import jwt.exceptions

from rasa.core import constants
from rasa.core.channels.channel import InputChannel
from rasa.core.channels.rest import RestInput
from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
from rasa.core.exceptions import ChannelConfigError
from sanic.request import Request

logger = logging.getLogger(__name__)
Expand All @@ -34,8 +36,9 @@ def from_credentials(cls, credentials: Optional[Dict[Text, Any]]) -> InputChanne
return cls(credentials.get("url"))

def __init__(self, url: Optional[Text]) -> None:
"""Initialise the channel with attributes."""
self.base_url = url
self.jwt_key = None
self.jwt_key: Optional[Text] = None
self.jwt_algorithm = None

async def _fetch_public_key(self) -> None:
Expand Down Expand Up @@ -71,6 +74,12 @@ async def _fetch_public_key(self) -> None:
)

def _decode_jwt(self, bearer_token: Text) -> Dict:
if self.jwt_key is None:
raise ChannelConfigError(
"JWT public key is `None`. This is likely caused "
"by an error when retrieving the public key from Rasa X."
)

authorization_header_value = bearer_token.replace(
constants.BEARER_TOKEN_PREFIX, ""
)
Expand All @@ -82,19 +91,15 @@ async def _decode_bearer_token(self, bearer_token: Text) -> Optional[Dict]:
if self.jwt_key is None:
await self._fetch_public_key()

# noinspection PyBroadException
try:
return self._decode_jwt(bearer_token)
except jwt.exceptions.InvalidSignatureError:
except jwt.InvalidSignatureError:
logger.error("JWT public key invalid, fetching new one.")
await self._fetch_public_key()
return self._decode_jwt(bearer_token)
except Exception:
logger.exception("Failed to decode bearer token.")

async def _extract_sender(self, req: Request) -> Optional[Text]:
"""Fetch user from the Rasa X Admin API."""

jwt_payload = None
if req.headers.get("Authorization"):
jwt_payload = await self._decode_bearer_token(req.headers["Authorization"])
Expand Down
6 changes: 6 additions & 0 deletions rasa/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class UnsupportedDialogueModelError(RasaCoreException):
"""

def __init__(self, message: Text, model_version: Optional[Text] = None) -> None:
"""Initialize message and model_version attributes."""
self.message = message
self.model_version = model_version
super(UnsupportedDialogueModelError, self).__init__()
Expand All @@ -27,5 +28,10 @@ class AgentNotReady(RasaCoreException):
will be thrown."""

def __init__(self, message: Text) -> None:
"""Initialize message attribute."""
self.message = message
super(AgentNotReady, self).__init__()


class ChannelConfigError(RasaCoreException):
"""Raised if a channel is not configured correctly."""
7 changes: 5 additions & 2 deletions rasa/core/featurizers/single_state_featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,15 @@ def convert_to_dict(feature_states: List[Text]) -> Dict[Text, int]:
def _state_features_for_attribute(
self, sub_state: SubState, attribute: Text
) -> Dict[Text, int]:
# FIXME: the code below is not type-safe, but fixing it
# would require more refactoring, for instance using
# data classes in our states
if attribute in {INTENT, ACTION_NAME}:
return {sub_state[attribute]: 1}
return {sub_state[attribute]: 1} # type: ignore[dict-item]
elif attribute == ENTITIES:
return {entity: 1 for entity in sub_state.get(ENTITIES, [])}
elif attribute == ACTIVE_LOOP:
return {sub_state["name"]: 1}
return {sub_state["name"]: 1} # type: ignore[dict-item]
elif attribute == SLOTS:
return {
f"{slot_name}_{i}": value
Expand Down
3 changes: 1 addition & 2 deletions rasa/core/policies/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,7 @@ def load(cls, path: Union[Text, Path], **kwargs: Any) -> "Policy":
)
return cls()

@staticmethod
def _default_predictions(domain: Domain) -> List[float]:
def _default_predictions(self, domain: Domain) -> List[float]:
"""Creates a list of zeros.
Args:
Expand Down
8 changes: 4 additions & 4 deletions rasa/core/policies/rule_policy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import List, Dict, Text, Optional, Any, Set, TYPE_CHECKING, Tuple
from typing import Any, List, Dict, Text, Optional, Set, Tuple, TYPE_CHECKING, Union

from tqdm import tqdm
import numpy as np
Expand Down Expand Up @@ -582,9 +582,7 @@ def train(

# only consider original trackers (no augmented ones)
training_trackers = [
t
for t in training_trackers
if not hasattr(t, "is_augmented") or not t.is_augmented
t for t in training_trackers if not getattr(t, "is_augmented", False)
]
# only use trackers from rule-based training data
rule_trackers = [t for t in training_trackers if t.is_rule_tracker]
Expand Down Expand Up @@ -772,6 +770,8 @@ def _find_action_from_loop_happy_path(
)
return ACTION_LISTEN_NAME

return None

def _find_action_from_rules(
self,
tracker: DialogueStateTracker,
Expand Down
Loading

0 comments on commit f7490fe

Please sign in to comment.