From d67b5b0ee3dc3d0c693718bd9dfc63f8bd46bb27 Mon Sep 17 00:00:00 2001 From: Joel Collins Date: Thu, 9 Jul 2020 15:38:56 +0100 Subject: [PATCH] Removed default task pool --- src/labthings/find.py | 15 ++++++---- src/labthings/labthing.py | 7 +++-- src/labthings/representations.py | 5 ++-- src/labthings/tasks/__init__.py | 22 ++++---------- src/labthings/tasks/pool.py | 40 ------------------------- src/labthings/tasks/thread.py | 4 --- src/labthings/view/__init__.py | 19 ++++++------ tests/conftest.py | 10 +++++++ tests/test_default_views.py | 31 ++++++++++--------- tests/test_tasks_pool.py | 51 +++++++++++++++++--------------- 10 files changed, 84 insertions(+), 120 deletions(-) diff --git a/src/labthings/find.py b/src/labthings/find.py index ba668ea0..74eb45cf 100644 --- a/src/labthings/find.py +++ b/src/labthings/find.py @@ -2,12 +2,15 @@ from flask import current_app, url_for import weakref +from werkzeug.local import LocalProxy + from .names import EXTENSION_NAME __all__ = [ "current_app", "url_for", "current_labthing", + "current_thing", "registered_extensions", "registered_components", "find_component", @@ -23,11 +26,10 @@ def current_labthing(app=None): # We use _get_current_object so that Task threads can still # reach the Flask app object. Just using current_app returns # a wrapper, which breaks it's use in Task threads - if not app: - try: - app = current_app._get_current_object() # skipcq: PYL-W0212 - except RuntimeError: - return None + try: + app = current_app._get_current_object() # skipcq: PYL-W0212 + except RuntimeError: + return None ext = app.extensions.get(EXTENSION_NAME, None) if isinstance(ext, weakref.ref): return ext() @@ -107,3 +109,6 @@ def find_extension(extension_name, labthing_instance=None): return labthing_instance.extensions[extension_name] else: return None + + +current_thing = LocalProxy(current_labthing) diff --git a/src/labthings/labthing.py b/src/labthings/labthing.py index 44be8464..e58e0935 100644 --- a/src/labthings/labthing.py +++ b/src/labthings/labthing.py @@ -16,12 +16,13 @@ from .httperrorhandler import SerializedExceptionHandler from .logging import LabThingLogger from .json.encoder import LabThingsJSONEncoder +from .representations import DEFAULT_REPRESENTATIONS from .apispec import MarshmallowPlugin, rule_to_apispec_path from .td import ThingDescription from .sockets import Sockets from .event import Event -from .tasks import Pool, change_default_pool +from .tasks import Pool from .view.builder import property_of, action_from @@ -61,7 +62,6 @@ def __init__( self.extensions = {} self.actions = Pool() # Pool of greenlets for Actions - change_default_pool(self.actions) self.events = {} @@ -94,6 +94,9 @@ def __init__( self.log_handler = LabThingLogger() logging.getLogger().addHandler(self.log_handler) + # Representation formatter map + self.representations = DEFAULT_REPRESENTATIONS + # API Spec self.spec = APISpec( title=self.title, diff --git a/src/labthings/representations.py b/src/labthings/representations.py index 72279363..fa34605d 100644 --- a/src/labthings/representations.py +++ b/src/labthings/representations.py @@ -1,4 +1,5 @@ from flask import make_response, current_app +from collections import OrderedDict from .json.encoder import LabThingsJSONEncoder, encode_json from .utilities import PY3 @@ -24,6 +25,4 @@ def output_json(data, code, headers=None): return resp -DEFAULT_REPRESENTATIONS = [ - ("application/json", output_json), -] +DEFAULT_REPRESENTATIONS = OrderedDict({"application/json": output_json,}) diff --git a/src/labthings/tasks/__init__.py b/src/labthings/tasks/__init__.py index a6fbf30a..6e1647b5 100644 --- a/src/labthings/tasks/__init__.py +++ b/src/labthings/tasks/__init__.py @@ -1,31 +1,19 @@ __all__ = [ "Pool", - "taskify", - "tasks", - "to_dict", - "states", "current_task", "update_task_progress", - "cleanup", - "discard_id", "update_task_data", - "default_pool", - "change_default_pool", + "TaskKillException", "ThreadTerminationError", ] from .pool import ( Pool, - tasks, - to_dict, - states, current_task, update_task_progress, - cleanup, - discard_id, update_task_data, - taskify, - default_pool, - change_default_pool, ) -from .thread import ThreadTerminationError +from .thread import TaskKillException + +# Legacy alias +ThreadTerminationError = TaskKillException diff --git a/src/labthings/tasks/pool.py b/src/labthings/tasks/pool.py index d4668e95..fcec2ed5 100644 --- a/src/labthings/tasks/pool.py +++ b/src/labthings/tasks/pool.py @@ -112,43 +112,3 @@ def update_task_data(data: dict): current_task().update_data(data) else: logging.info("Cannot update task data of __main__ thread. Skipping.") - - -# Main "taskify" functions - - -def taskify(f): - """ - A decorator that wraps the passed in function - and surpresses exceptions should one occur - """ - global default_pool - - @wraps(f) - def wrapped(*args, **kwargs): - task = default_pool.spawn( - f, *args, **kwargs - ) # Append to parent object's task list - return task - - return wrapped - - -# Create our default, protected, module-level task pool -default_pool = Pool() - -tasks = default_pool.tasks -to_dict = default_pool.to_dict -states = default_pool.states -cleanup = default_pool.cleanup -discard_id = default_pool.discard_id - - -def change_default_pool(new_default_pool: Pool): - global default_pool, tasks, to_dict, states, cleanup, discard_id - default_pool = new_default_pool - tasks = new_default_pool.tasks - to_dict = new_default_pool.to_dict - states = new_default_pool.states - cleanup = new_default_pool.cleanup - discard_id = new_default_pool.discard_id diff --git a/src/labthings/tasks/thread.py b/src/labthings/tasks/thread.py index dfc6fa41..75916a53 100644 --- a/src/labthings/tasks/thread.py +++ b/src/labthings/tasks/thread.py @@ -10,10 +10,6 @@ _LOG = logging.getLogger(__name__) -class ThreadTerminationError(SystemExit): - """Sibling of SystemExit, but specific to thread termination.""" - - class TaskKillException(Exception): """Sibling of SystemExit, but specific to thread termination.""" diff --git a/src/labthings/view/__init__.py b/src/labthings/view/__init__.py index d2c3f42e..77ef2dbf 100644 --- a/src/labthings/view/__init__.py +++ b/src/labthings/view/__init__.py @@ -3,17 +3,15 @@ from werkzeug.wrappers import Response as ResponseBase from werkzeug.exceptions import BadRequest -from collections import OrderedDict - from .args import use_args from .marshalling import marshal_with from ..utilities import unpack, get_docstring, get_summary, merge from ..representations import DEFAULT_REPRESENTATIONS -from ..find import current_labthing +from ..find import current_labthing, current_thing from ..event import PropertyStatusEvent from ..schema import Schema, ActionSchema, build_action_schema -from ..tasks import default_pool +from ..tasks import Pool from ..deque import Deque, resize_deque from ..json.schemas import schema_to_json from .. import fields @@ -46,8 +44,9 @@ def __init__(self, *args, **kwargs): MethodView.__init__(self, *args, **kwargs) # Set the default representations - # TODO: Inherit from parent LabThing. See original flask_restful implementation - self.representations = OrderedDict(DEFAULT_REPRESENTATIONS) + self.representations = ( + current_thing.representations if current_thing else DEFAULT_REPRESENTATIONS + ) @classmethod def get_apispec(cls): @@ -152,6 +151,7 @@ class ActionView(View): # Internal _cls_tags = {"actions"} _deque = Deque() # Action queue + _emergency_pool = Pool() def get(self): queue_schema = build_action_schema(self.schema, self.args)(many=True) @@ -247,8 +247,9 @@ def dispatch_request(self, *args, **kwargs): if self.schema: meth = marshal_with(self.schema)(meth) + # Try to find a pool on the current LabThing, but fall back to Views emergency pool + pool = current_thing.actions if current_thing else self._emergency_pool # Make a task out of the views `post` method - pool = current_labthing().actions if current_labthing else default_pool task = pool.spawn(meth, *args, **kwargs) # Keep a copy of the raw, unmarshalled JSON input in the task @@ -380,8 +381,8 @@ def dispatch_request(self, *args, **kwargs): self, "__name__", "unknown" ) - if current_labthing(): - current_labthing().message( + if current_thing: + current_thing.message( PropertyStatusEvent(property_name), property_value, ) diff --git a/tests/conftest.py b/tests/conftest.py index 88619b53..70057bf9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ from apispec.ext.marshmallow import MarshmallowPlugin from labthings.server.labthing import LabThing from labthings.server.view import View +from labthings.tasks import Pool from werkzeug.test import EnvironBuilder from flask.testing import FlaskClient @@ -309,3 +310,12 @@ def _foo(*args, **kwargs): return FakeWebsocket(*args, **kwargs) return _foo + + +@pytest.fixture +def task_pool(): + """ + Return a task pool + """ + + return Pool() diff --git a/tests/test_default_views.py b/tests/test_default_views.py index 3600f6a2..1eeae893 100644 --- a/tests/test_default_views.py +++ b/tests/test_default_views.py @@ -1,4 +1,4 @@ -from labthings import tasks +from labthings.find import current_labthing import gevent @@ -18,56 +18,55 @@ def test_extensions(thing_client): assert c.get("/extensions").json == [] -def test_tasks_list(thing_client): +def test_actions_list(thing_client): def task_func(): pass - task_obj = tasks.taskify(task_func)() + task_obj = current_labthing().actions.spawn(task_func) with thing_client as c: - response = c.get("/tasks").json + response = c.get("/actions").json ids = [task.get("id") for task in response] assert str(task_obj.id) in ids -def test_task_representation(thing_client): +def test_action_representation(thing_client): def task_func(): pass - task_obj = tasks.taskify(task_func)() + task_obj = current_labthing().actions.spawn(task_func) task_id = str(task_obj.id) with thing_client as c: - response = c.get(f"/tasks/{task_id}").json + response = c.get(f"/actions/{task_id}").json assert response -def test_task_representation_missing(thing_client): +def test_action_representation_missing(thing_client): with thing_client as c: - assert c.get("/tasks/missing_id").status_code == 404 + assert c.get("/actions/missing_id").status_code == 404 -def test_task_kill(thing_client): +def test_action_kill(thing_client): def task_func(): while True: gevent.sleep(0) - task_obj = tasks.taskify(task_func)() + task_obj = current_labthing().actions.spawn(task_func) task_id = str(task_obj.id) # Wait for task to start task_obj.started_event.wait() - assert task_id in tasks.to_dict() + assert task_id in current_labthing().actions.to_dict() # Send a DELETE request to terminate the task with thing_client as c: - response = c.delete(f"/tasks/{task_id}") - print(response.json) + response = c.delete(f"/actions/{task_id}") assert response.status_code == 200 # Test task was terminated assert task_obj._status == "terminated" -def test_task_kill_missing(thing_client): +def test_action_kill_missing(thing_client): with thing_client as c: - assert c.delete("/tasks/missing_id").status_code == 404 + assert c.delete("/actions/missing_id").status_code == 404 diff --git a/tests/test_tasks_pool.py b/tests/test_tasks_pool.py index 1e5ea96e..6657382c 100644 --- a/tests/test_tasks_pool.py +++ b/tests/test_tasks_pool.py @@ -3,28 +3,28 @@ import gevent -def test_taskify_without_context(): +def test_spawn_without_context(task_pool): def task_func(): pass - task_obj = tasks.taskify(task_func)() + task_obj = task_pool.spawn(task_func) assert isinstance(task_obj, gevent.Greenlet) -def test_taskify_with_context(app_ctx): +def test_spawn_with_context(app_ctx, task_pool): def task_func(): pass with app_ctx.test_request_context(): - task_obj = tasks.taskify(task_func)() + task_obj = task_pool.spawn(task_func) assert isinstance(task_obj, gevent.Greenlet) -def test_update_task_data(): +def test_update_task_data(task_pool): def task_func(): tasks.update_task_data({"key": "value"}) - task_obj = tasks.taskify(task_func)() + task_obj = task_pool.spawn(task_func) task_obj.join() assert task_obj.data == {"key": "value"} @@ -34,11 +34,11 @@ def test_update_task_data_main_thread(): tasks.update_task_data({"key": "value"}) -def test_update_task_progress(): +def test_update_task_progress(task_pool): def task_func(): tasks.update_task_progress(100) - task_obj = tasks.taskify(task_func)() + task_obj = task_pool.spawn(task_func) task_obj.join() assert task_obj.progress == 100 @@ -48,42 +48,45 @@ def test_update_task_progress_main_thread(): tasks.update_task_progress(100) -def test_tasks_list(): - assert all(isinstance(task_obj, gevent.Greenlet) for task_obj in tasks.tasks()) +def test_tasks_list(task_pool): + assert all( + isinstance(task_obj, gevent.Greenlet) for task_obj in task_pool.greenlets + ) -def test_tasks_dict(): +def test_tasks_dict(task_pool): assert all( - isinstance(task_obj, gevent.Greenlet) for task_obj in tasks.to_dict().values() + isinstance(task_obj, gevent.Greenlet) + for task_obj in task_pool.to_dict().values() ) - assert all(k == str(t.id) for k, t in tasks.to_dict().items()) + assert all(k == str(t.id) for k, t in task_pool.to_dict().items()) -def test_discard_id(): +def test_discard_id(task_pool): def task_func(): pass - task_obj = tasks.taskify(task_func)() - assert str(task_obj.id) in tasks.to_dict() + task_obj = task_pool.spawn(task_func) + assert str(task_obj.id) in task_pool.to_dict() task_obj.join() - tasks.discard_id(task_obj.id) - assert not str(task_obj.id) in tasks.to_dict() + task_pool.discard_id(task_obj.id) + assert not str(task_obj.id) in task_pool.to_dict() -def test_cleanup_task(): +def test_cleanup_task(task_pool): import time def task_func(): pass # Make sure at least 1 tasks is around - tasks.taskify(task_func)() + task_pool.spawn(task_func) # Wait for all tasks to finish - gevent.joinall(tasks.tasks()) + gevent.joinall(task_pool.greenlets) - assert len(tasks.tasks()) > 0 - tasks.cleanup() - assert len(tasks.tasks()) == 0 + assert len(task_pool.greenlets) > 0 + task_pool.cleanup() + assert len(task_pool.greenlets) == 0