Skip to content

Commit

Permalink
Added a bunch of type definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
Joel Collins committed Nov 24, 2020
1 parent 69370c0 commit 3104d86
Show file tree
Hide file tree
Showing 21 changed files with 310 additions and 238 deletions.
34 changes: 33 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pylint = "^2.6.0"
sphinx = "^3.2.1"
sphinx-autoapi = "^1.4.0"
sphinx-rtd-theme = "^0.5.0"
mypy = "^0.790"

[tool.black]
exclude = '(\.eggs|\.git|\.venv)'
Expand Down
11 changes: 6 additions & 5 deletions src/labthings/actions/pool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import threading
from functools import wraps

from typing import Dict

from ..deque import Deque
from .thread import ActionThread
Expand Down Expand Up @@ -41,7 +42,7 @@ def spawn(self, action: str, function, *args, **kwargs):
self.start(thread)
return thread

def kill(self, timeout=5):
def kill(self, timeout: int = 5):
"""
:param timeout: (Default value = 5)
Expand Down Expand Up @@ -73,7 +74,7 @@ def states(self):
"""
return {str(t.id): t.state for t in self.threads}

def to_dict(self):
def to_dict(self) -> Dict[str, ActionThread]:
"""
Expand All @@ -84,8 +85,8 @@ def to_dict(self):
"""
return {str(t.id): t for t in self.threads}

def get(self, task_id):
return self.to_dict.get(task_id, None)
def get(self, task_id: str):
return self.to_dict().get(task_id, None)

def discard_id(self, task_id):
"""
Expand Down
76 changes: 38 additions & 38 deletions src/labthings/actions/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from flask import copy_current_request_context, has_request_context, request
from werkzeug.exceptions import BadRequest

from typing import Optional, Iterable, Dict, Any, Callable

from ..deque import LockableDeque
from ..utilities import TimeoutTracker

Expand All @@ -25,12 +27,12 @@ class ActionThread(threading.Thread):

def __init__(
self,
action,
target=None,
name=None,
args=None,
kwargs=None,
daemon=True,
action: str,
target: Optional[Callable] = None,
name: Optional[str] = None,
args: Optional[Iterable[Any]] = None,
kwargs: Optional[Dict[str, Any]] = None,
daemon: bool = True,
default_stop_timeout: int = 5,
log_len: int = 100,
):
Expand All @@ -39,34 +41,32 @@ def __init__(
group=None,
target=target,
name=name,
args=args,
kwargs=kwargs,
args=args or (),
kwargs=kwargs or {},
daemon=daemon,
)

# Safely populate missing arguments
args = args or ()
kwargs = kwargs or {}

# Action resource corresponding to this action object
self.action = action

# A UUID for the ActionThread (not the same as the threading.Thread ident)
self._ID = uuid.uuid4() # Task ID

# Event to track if the task has started
self.started = threading.Event()
self.started: threading.Event = threading.Event()
# Event to track if the user has requested stop
self.stopping = threading.Event()
self.default_stop_timeout = default_stop_timeout
self.stopping: threading.Event = threading.Event()
self.default_stop_timeout: int = default_stop_timeout

# Make _target, _args, and _kwargs available to the subclass
self._target = target
self._args = args
self._kwargs = kwargs
self._target: Optional[Callable] = target
self._args: Iterable[Any] = args or ()
self._kwargs: Dict[str, Any] = kwargs or {}

# Nice string representation of target function
self.target_string = f"{self._target}(args={self._args}, kwargs={self._kwargs})"
self.target_string: str = (
f"{self._target}(args={self._args}, kwargs={self._kwargs})"
)

# copy_current_request_context allows threads to access flask current_app
if has_request_context():
Expand All @@ -82,14 +82,14 @@ def __init__(

# Private state properties
self._status: str = "pending" # Task status
self._return_value = None # Return value
self._request_time = datetime.datetime.now()
self._start_time = None # Task start time
self._end_time = None # Task end time
self._return_value: Optional[Any] = None # Return value
self._request_time: datetime.datetime = datetime.datetime.now()
self._start_time: Optional[datetime.datetime] = None # Task start time
self._end_time: Optional[datetime.datetime] = None # Task end time

# Public state properties
self.progress: int = None # Percent progress of the task
self.data = {} # Dictionary of custom data added during the task
self.progress: Optional[int] = None # Percent progress of the task
self.data: dict = {} # Dictionary of custom data added during the task
self._log = LockableDeque(
None, log_len
) # The log will hold dictionary objects with log information
Expand All @@ -100,14 +100,14 @@ def __init__(
) # Lock obtained while self._target is running

@property
def id(self):
def id(self) -> uuid.UUID:
"""
UUID for the thread. Note this not the same as the native thread ident.
"""
return self._ID

@property
def output(self):
def output(self) -> Any:
"""
Return value of the Action function. If the Action is still running, returns None.
"""
Expand All @@ -119,7 +119,7 @@ def log(self):
return list(logdeque)

@property
def status(self):
def status(self) -> str:
"""
Current running status of the thread.
Expand All @@ -136,19 +136,19 @@ def status(self):
return self._status

@property
def dead(self):
def dead(self) -> bool:
"""
Has the thread finished, by any means (return, exception, termination).
"""
return not self.is_alive()

@property
def stopped(self):
def stopped(self) -> bool:
"""Has the thread been cancelled"""
return self.stopping.is_set()

@property
def cancelled(self):
def cancelled(self) -> bool:
"""Alias of `stopped`"""
return self.stopped

Expand All @@ -162,7 +162,7 @@ def update_progress(self, progress: int):
# Update progress of the task
self.progress = progress

def update_data(self, data: dict):
def update_data(self, data: Dict[Any, Any]):
"""
:param data: dict:
Expand All @@ -183,7 +183,7 @@ def run(self):
# an argument that has a member that points to the thread.
del self._target, self._args, self._kwargs

def _thread_proc(self, f):
def _thread_proc(self, f: Callable):
"""Wraps the target function to handle recording `status` and `return` to `state`.
Happens inside the task thread.
Expand Down Expand Up @@ -228,7 +228,7 @@ def wrapped(*args, **kwargs):

return wrapped

def get(self, block=True, timeout=None):
def get(self, block: bool = True, timeout: Optional[int] = None):
"""Start waiting for the task to finish before returning
:param block: (Default value = True)
Expand Down Expand Up @@ -275,7 +275,7 @@ def _async_raise(self, exc_type):
% (exc_type, self.name, self.ident, result)
)

def _is_thread_proc_running(self):
def _is_thread_proc_running(self) -> bool:
"""Test if thread funtion (_thread_proc) is running,
by attemtping to acquire the lock _thread_proc acquires at runtime.
Expand All @@ -285,13 +285,13 @@ def _is_thread_proc_running(self):
:rtype: bool
"""
could_acquire = self._running_lock.acquire(0)
could_acquire = self._running_lock.acquire(False)
if could_acquire:
self._running_lock.release()
return False
return True

def terminate(self, exception=ActionKilledException):
def terminate(self, exception=ActionKilledException) -> bool:
"""
:param exception: (Default value = ActionKilledException)
Expand All @@ -314,7 +314,7 @@ def terminate(self, exception=ActionKilledException):
self.progress = None
return True

def stop(self, timeout=None, exception=ActionKilledException):
def stop(self, timeout=None, exception=ActionKilledException) -> bool:
"""Sets the threads internal stopped event, waits for timeout seconds for the
thread to stop nicely, then forcefully kills the thread.
Expand Down
4 changes: 2 additions & 2 deletions src/labthings/apispec/plugins.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re

from apispec import BasePlugin
from apispec.ext.marshmallow import MarshmallowPlugin as _MarshmallowPlugin
from apispec import BasePlugin # type: ignore
from apispec.ext.marshmallow import MarshmallowPlugin as _MarshmallowPlugin # type: ignore
from apispec.ext.marshmallow import OpenAPIConverter
from flask.views import http_method_funcs

Expand Down
30 changes: 17 additions & 13 deletions src/labthings/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from flask import url_for

from typing import List, Dict, Callable

from .utilities import camel_to_snake, get_docstring, snake_to_spine
from .views.builder import static_from

Expand All @@ -27,27 +29,27 @@ def __init__(
static_url_path="/static",
static_folder=None,
):
self._views = (
self._views: dict = (
{}
) # Key: Full, Python-safe ID. Val: Original rule, and view class
self._rules = {} # Key: Original rule. Val: View class
self._meta = {} # Extra metadata to add to the extension description
self._rules: dict = {} # Key: Original rule. Val: View class
self._meta: dict = {} # Extra metadata to add to the extension description

self._on_registers = (
[]
) # List of dictionaries of functions to run on registration
self._on_registers: List[
Dict
] = [] # List of dictionaries of functions to run on registration

self._on_components = (
[]
) # List of dictionaries of functions to run as components are added
self._on_components: List[
Dict
] = [] # List of dictionaries of functions to run as components are added

self._cls = str(self) # String description of extension instance

self._name = name
self.description = description or get_docstring(self)
self.version = str(version)

self.methods = {}
self.methods: Dict[str, Callable] = {}

self.static_view_class = static_from(static_folder)
self.add_view(
Expand Down Expand Up @@ -169,7 +171,7 @@ def _name_uri_safe(self):
""" """
return snake_to_spine(self._name_python_safe)

def add_method(self, method, method_name):
def add_method(self, method: Callable, method_name: str):
"""
:param method:
Expand Down Expand Up @@ -240,15 +242,17 @@ def find_extensions_in_file(extension_path: str, module_name="extensions") -> li
sys.modules[spec.name] = mod

try:
spec.loader.exec_module(mod)
spec.loader.exec_module(mod) # type: ignore
except Exception: # skipcq: PYL-W0703
logging.error(
f"Exception in extension path {extension_path}: \n{traceback.format_exc()}"
)
return []
else:
if hasattr(mod, "__extensions__"):
return [getattr(mod, ext_name) for ext_name in mod.__extensions__]
return [
getattr(mod, ext_name) for ext_name in getattr(mod, "__extensions__")
]
else:
return find_instances_in_module(mod, BaseExtension)

Expand Down
1 change: 0 additions & 1 deletion src/labthings/find.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import weakref

from flask import current_app, url_for
Expand Down
Loading

0 comments on commit 3104d86

Please sign in to comment.