Skip to content

Commit

Permalink
cleanup type error: [call-arg]
Browse files Browse the repository at this point in the history
  • Loading branch information
m-vdb committed Aug 31, 2020
1 parent 0da7e77 commit 3e4c55d
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 27 deletions.
3 changes: 1 addition & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ toml = "^0.10.0"
pep440-version-utils = "^0.3.0"
pydoc-markdown = "3.3.0.post1"
mypy = "^0.782"
typing-extensions = "^3.7.4"

[tool.poetry.extras]
spacy = [ "spacy",]
Expand Down
20 changes: 7 additions & 13 deletions rasa/cli/train.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import argparse
import asyncio
import os
from typing import List, Optional, Text, Dict
import rasa.cli.arguments.train as train_arguments

import rasa.train
import rasa.cli.arguments.train as train_arguments
from rasa.cli.utils import get_validated_path, missing_config_keys, print_error
from rasa.core.train import do_compare_training
from rasa.constants import (
DEFAULT_CONFIG_PATH,
DEFAULT_DATA_PATH,
Expand Down Expand Up @@ -52,8 +55,6 @@ def add_subparser(


def train(args: argparse.Namespace) -> Optional[Text]:
import rasa

domain = get_validated_path(
args.domain, "domain", DEFAULT_DOMAIN_PATH, none_is_valid=True
)
Expand All @@ -65,7 +66,7 @@ def train(args: argparse.Namespace) -> Optional[Text]:
for f in args.data
]

return rasa.train(
return rasa.train.train(
domain=domain,
config=config,
training_files=training_files,
Expand All @@ -81,9 +82,6 @@ def train(args: argparse.Namespace) -> Optional[Text]:
def train_core(
args: argparse.Namespace, train_path: Optional[Text] = None
) -> Optional[Text]:
from rasa.train import train_core
import asyncio

loop = asyncio.get_event_loop()
output = train_path or args.out

Expand All @@ -103,7 +101,7 @@ def train_core(

config = _get_valid_config(args.config, CONFIG_MANDATORY_KEYS_CORE)

return train_core(
return rasa.train.train_core(
domain=args.domain,
config=config,
stories=story_file,
Expand All @@ -113,8 +111,6 @@ def train_core(
additional_arguments=additional_arguments,
)
else:
from rasa.core.train import do_compare_training

loop.run_until_complete(
do_compare_training(args, story_file, additional_arguments)
)
Expand All @@ -123,16 +119,14 @@ def train_core(
def train_nlu(
args: argparse.Namespace, train_path: Optional[Text] = None
) -> Optional[Text]:
from rasa.train import train_nlu

output = train_path or args.out

config = _get_valid_config(args.config, CONFIG_MANDATORY_KEYS_NLU)
nlu_data = get_validated_path(
args.nlu, "nlu", DEFAULT_DATA_PATH, none_is_valid=True
)

return train_nlu(
return rasa.train.train_nlu(
config=config,
nlu_data=nlu_data,
output=output,
Expand Down
27 changes: 21 additions & 6 deletions rasa/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from functools import reduce, wraps
from inspect import isawaitable
from pathlib import Path
from typing import Any, Callable, List, Optional, Text, Union, Dict
from typing import Any, Callable, List, Optional, Text, Union, Dict, cast

from rasa.core.training.story_writer.yaml_story_writer import YAMLStoryWriter
from rasa.nlu.training_data.formats import RasaYAMLReader
Expand Down Expand Up @@ -52,8 +52,16 @@

if typing.TYPE_CHECKING:
from ssl import SSLContext
from typing_extensions import Protocol
from rasa.core.processor import MessageProcessor

class SanicView(Protocol):
def __call__(
self, request: Request, *args: Any, **kwargs: Any
) -> response.BaseHTTPResponse:
...


logger = logging.getLogger(__name__)

JSON_CONTENT_TYPE = "application/json"
Expand Down Expand Up @@ -123,7 +131,7 @@ def decorated(*args, **kwargs):
def requires_auth(app: Sanic, token: Optional[Text] = None) -> Callable[[Any], Any]:
"""Wraps a request handler with token authentication."""

def decorator(f: Callable[[Any, Any], Any]) -> Callable[[Any, Any], Any]:
def decorator(f: "SanicView") -> "SanicView":
def conversation_id_from_args(args: Any, kwargs: Any) -> Optional[Text]:
argnames = common_utils.arguments_of(f)

Expand Down Expand Up @@ -153,7 +161,9 @@ def sufficient_scope(request, *args: Any, **kwargs: Any) -> Optional[bool]:
return False

@wraps(f)
async def decorated(request: Request, *args: Any, **kwargs: Any) -> Any:
async def decorated(
request: Request, *args: Any, **kwargs: Any
) -> response.BaseHTTPResponse:

provided = request.args.get("token", None)

Expand Down Expand Up @@ -227,7 +237,7 @@ async def get_tracker(
_validate_tracker(tracker, conversation_id)

# `_validate_tracker` ensures we can't return `None` so `Optional` is not needed
return tracker # pytype: disable=bad-return-type
return cast(DialogueStateTracker, tracker)


def _validate_tracker(
Expand Down Expand Up @@ -631,7 +641,7 @@ async def execute_action(request: Request, conversation_id: Text):
tracker = await get_tracker(app.agent.create_processor(), conversation_id)
state = tracker.current_state(verbosity)

response_body = {"tracker": state}
response_body: Dict[Text, Any] = {"tracker": state}

if isinstance(output_channel, CollectingOutputChannel):
response_body["messages"] = output_channel.messages
Expand Down Expand Up @@ -685,7 +695,7 @@ async def trigger_intent(request: Request, conversation_id: Text) -> HTTPRespons

state = tracker.current_state(verbosity)

response_body = {"tracker": state}
response_body: Dict[Text, Any] = {"tracker": state}

if isinstance(output_channel, CollectingOutputChannel):
response_body["messages"] = output_channel.messages
Expand Down Expand Up @@ -870,6 +880,11 @@ async def evaluate_intents(request: Request) -> HTTPResponse:
model_directory = eval_agent.model_directory
_, nlu_model = model.get_model_subdirectories(model_directory)

if nlu_model is None:
raise ErrorResponse(
500, "TestingError", f"Missing NLU model directory.",
)

try:
evaluation = run_evaluation(data_path, nlu_model, disable_plotting=True)
return response.json(evaluation)
Expand Down
6 changes: 3 additions & 3 deletions rasa/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_core(
"to train a NLU model first, e.g. using `rasa train`."
)

from rasa.core.test import test
from rasa.core.test import test as core_test

kwargs = utils.minimal_kwargs(additional_arguments, test, ["stories", "agent"])

Expand All @@ -157,11 +157,11 @@ def test_core(
def _test_core(
stories: Optional[Text], agent: "Agent", output_directory: Text, **kwargs: Any
) -> None:
from rasa.core.test import test
from rasa.core.test import test as core_test

loop = asyncio.get_event_loop()
loop.run_until_complete(
test(stories, agent, out_directory=output_directory, **kwargs)
core_test(stories, agent, out_directory=output_directory, **kwargs)
)


Expand Down
3 changes: 0 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -398,9 +398,6 @@ ignore_errors = True
[mypy-rasa.run]
ignore_errors = True

[mypy-rasa.server]
ignore_errors = True

[mypy-rasa.test]
ignore_errors = True

Expand Down

0 comments on commit 3e4c55d

Please sign in to comment.