Skip to content

Commit

Permalink
Removed default task pool
Browse files Browse the repository at this point in the history
  • Loading branch information
Joel Collins committed Jul 9, 2020
1 parent 53ea2bf commit d67b5b0
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 120 deletions.
15 changes: 10 additions & 5 deletions src/labthings/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
from flask import current_app, url_for
import weakref

from werkzeug.local import LocalProxy

from .names import EXTENSION_NAME

__all__ = [
"current_app",
"url_for",
"current_labthing",
"current_thing",
"registered_extensions",
"registered_components",
"find_component",
Expand All @@ -23,11 +26,10 @@ def current_labthing(app=None):
# We use _get_current_object so that Task threads can still
# reach the Flask app object. Just using current_app returns
# a wrapper, which breaks it's use in Task threads
if not app:
try:
app = current_app._get_current_object() # skipcq: PYL-W0212
except RuntimeError:
return None
try:
app = current_app._get_current_object() # skipcq: PYL-W0212
except RuntimeError:
return None
ext = app.extensions.get(EXTENSION_NAME, None)
if isinstance(ext, weakref.ref):
return ext()
Expand Down Expand Up @@ -107,3 +109,6 @@ def find_extension(extension_name, labthing_instance=None):
return labthing_instance.extensions[extension_name]
else:
return None


current_thing = LocalProxy(current_labthing)
7 changes: 5 additions & 2 deletions src/labthings/labthing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
from .httperrorhandler import SerializedExceptionHandler
from .logging import LabThingLogger
from .json.encoder import LabThingsJSONEncoder
from .representations import DEFAULT_REPRESENTATIONS
from .apispec import MarshmallowPlugin, rule_to_apispec_path
from .td import ThingDescription
from .sockets import Sockets
from .event import Event

from .tasks import Pool, change_default_pool
from .tasks import Pool

from .view.builder import property_of, action_from

Expand Down Expand Up @@ -61,7 +62,6 @@ def __init__(
self.extensions = {}

self.actions = Pool() # Pool of greenlets for Actions
change_default_pool(self.actions)

self.events = {}

Expand Down Expand Up @@ -94,6 +94,9 @@ def __init__(
self.log_handler = LabThingLogger()
logging.getLogger().addHandler(self.log_handler)

# Representation formatter map
self.representations = DEFAULT_REPRESENTATIONS

# API Spec
self.spec = APISpec(
title=self.title,
Expand Down
5 changes: 2 additions & 3 deletions src/labthings/representations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from flask import make_response, current_app
from collections import OrderedDict

from .json.encoder import LabThingsJSONEncoder, encode_json
from .utilities import PY3
Expand All @@ -24,6 +25,4 @@ def output_json(data, code, headers=None):
return resp


DEFAULT_REPRESENTATIONS = [
("application/json", output_json),
]
DEFAULT_REPRESENTATIONS = OrderedDict({"application/json": output_json,})
22 changes: 5 additions & 17 deletions src/labthings/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,19 @@
__all__ = [
"Pool",
"taskify",
"tasks",
"to_dict",
"states",
"current_task",
"update_task_progress",
"cleanup",
"discard_id",
"update_task_data",
"default_pool",
"change_default_pool",
"TaskKillException",
"ThreadTerminationError",
]

from .pool import (
Pool,
tasks,
to_dict,
states,
current_task,
update_task_progress,
cleanup,
discard_id,
update_task_data,
taskify,
default_pool,
change_default_pool,
)
from .thread import ThreadTerminationError
from .thread import TaskKillException

# Legacy alias
ThreadTerminationError = TaskKillException
40 changes: 0 additions & 40 deletions src/labthings/tasks/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,43 +112,3 @@ def update_task_data(data: dict):
current_task().update_data(data)
else:
logging.info("Cannot update task data of __main__ thread. Skipping.")


# Main "taskify" functions


def taskify(f):
"""
A decorator that wraps the passed in function
and surpresses exceptions should one occur
"""
global default_pool

@wraps(f)
def wrapped(*args, **kwargs):
task = default_pool.spawn(
f, *args, **kwargs
) # Append to parent object's task list
return task

return wrapped


# Create our default, protected, module-level task pool
default_pool = Pool()

tasks = default_pool.tasks
to_dict = default_pool.to_dict
states = default_pool.states
cleanup = default_pool.cleanup
discard_id = default_pool.discard_id


def change_default_pool(new_default_pool: Pool):
global default_pool, tasks, to_dict, states, cleanup, discard_id
default_pool = new_default_pool
tasks = new_default_pool.tasks
to_dict = new_default_pool.to_dict
states = new_default_pool.states
cleanup = new_default_pool.cleanup
discard_id = new_default_pool.discard_id
4 changes: 0 additions & 4 deletions src/labthings/tasks/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@
_LOG = logging.getLogger(__name__)


class ThreadTerminationError(SystemExit):
"""Sibling of SystemExit, but specific to thread termination."""


class TaskKillException(Exception):
"""Sibling of SystemExit, but specific to thread termination."""

Expand Down
19 changes: 10 additions & 9 deletions src/labthings/view/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,15 @@
from werkzeug.wrappers import Response as ResponseBase
from werkzeug.exceptions import BadRequest

from collections import OrderedDict

from .args import use_args
from .marshalling import marshal_with

from ..utilities import unpack, get_docstring, get_summary, merge
from ..representations import DEFAULT_REPRESENTATIONS
from ..find import current_labthing
from ..find import current_labthing, current_thing
from ..event import PropertyStatusEvent
from ..schema import Schema, ActionSchema, build_action_schema
from ..tasks import default_pool
from ..tasks import Pool
from ..deque import Deque, resize_deque
from ..json.schemas import schema_to_json
from .. import fields
Expand Down Expand Up @@ -46,8 +44,9 @@ def __init__(self, *args, **kwargs):
MethodView.__init__(self, *args, **kwargs)

# Set the default representations
# TODO: Inherit from parent LabThing. See original flask_restful implementation
self.representations = OrderedDict(DEFAULT_REPRESENTATIONS)
self.representations = (
current_thing.representations if current_thing else DEFAULT_REPRESENTATIONS
)

@classmethod
def get_apispec(cls):
Expand Down Expand Up @@ -152,6 +151,7 @@ class ActionView(View):
# Internal
_cls_tags = {"actions"}
_deque = Deque() # Action queue
_emergency_pool = Pool()

def get(self):
queue_schema = build_action_schema(self.schema, self.args)(many=True)
Expand Down Expand Up @@ -247,8 +247,9 @@ def dispatch_request(self, *args, **kwargs):
if self.schema:
meth = marshal_with(self.schema)(meth)

# Try to find a pool on the current LabThing, but fall back to Views emergency pool
pool = current_thing.actions if current_thing else self._emergency_pool
# Make a task out of the views `post` method
pool = current_labthing().actions if current_labthing else default_pool
task = pool.spawn(meth, *args, **kwargs)

# Keep a copy of the raw, unmarshalled JSON input in the task
Expand Down Expand Up @@ -380,8 +381,8 @@ def dispatch_request(self, *args, **kwargs):
self, "__name__", "unknown"
)

if current_labthing():
current_labthing().message(
if current_thing:
current_thing.message(
PropertyStatusEvent(property_name), property_value,
)

Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from apispec.ext.marshmallow import MarshmallowPlugin
from labthings.server.labthing import LabThing
from labthings.server.view import View
from labthings.tasks import Pool

from werkzeug.test import EnvironBuilder
from flask.testing import FlaskClient
Expand Down Expand Up @@ -309,3 +310,12 @@ def _foo(*args, **kwargs):
return FakeWebsocket(*args, **kwargs)

return _foo


@pytest.fixture
def task_pool():
"""
Return a task pool
"""

return Pool()
31 changes: 15 additions & 16 deletions tests/test_default_views.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from labthings import tasks
from labthings.find import current_labthing

import gevent

Expand All @@ -18,56 +18,55 @@ def test_extensions(thing_client):
assert c.get("/extensions").json == []


def test_tasks_list(thing_client):
def test_actions_list(thing_client):
def task_func():
pass

task_obj = tasks.taskify(task_func)()
task_obj = current_labthing().actions.spawn(task_func)

with thing_client as c:
response = c.get("/tasks").json
response = c.get("/actions").json
ids = [task.get("id") for task in response]
assert str(task_obj.id) in ids


def test_task_representation(thing_client):
def test_action_representation(thing_client):
def task_func():
pass

task_obj = tasks.taskify(task_func)()
task_obj = current_labthing().actions.spawn(task_func)
task_id = str(task_obj.id)

with thing_client as c:
response = c.get(f"/tasks/{task_id}").json
response = c.get(f"/actions/{task_id}").json
assert response


def test_task_representation_missing(thing_client):
def test_action_representation_missing(thing_client):
with thing_client as c:
assert c.get("/tasks/missing_id").status_code == 404
assert c.get("/actions/missing_id").status_code == 404


def test_task_kill(thing_client):
def test_action_kill(thing_client):
def task_func():
while True:
gevent.sleep(0)

task_obj = tasks.taskify(task_func)()
task_obj = current_labthing().actions.spawn(task_func)
task_id = str(task_obj.id)

# Wait for task to start
task_obj.started_event.wait()
assert task_id in tasks.to_dict()
assert task_id in current_labthing().actions.to_dict()

# Send a DELETE request to terminate the task
with thing_client as c:
response = c.delete(f"/tasks/{task_id}")
print(response.json)
response = c.delete(f"/actions/{task_id}")
assert response.status_code == 200
# Test task was terminated
assert task_obj._status == "terminated"


def test_task_kill_missing(thing_client):
def test_action_kill_missing(thing_client):
with thing_client as c:
assert c.delete("/tasks/missing_id").status_code == 404
assert c.delete("/actions/missing_id").status_code == 404
Loading

0 comments on commit d67b5b0

Please sign in to comment.