Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removes redundant function definitions from rpc and server #1734

Merged
merged 4 commits into from
Jun 28, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 3 additions & 91 deletions luigi/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
rpc.py implements the client side of it, server.py implements the server side.
See :doc:`/central_scheduler` for more info.
"""

import json
import logging
import socket
Expand All @@ -30,8 +29,7 @@
from luigi.six.moves.urllib.error import URLError

from luigi import configuration
from luigi.scheduler import PENDING, Scheduler

from luigi.scheduler import Scheduler, RPC_METHODS

HAS_UNIX_SOCKET = True
HAS_REQUESTS = True
Expand Down Expand Up @@ -149,92 +147,6 @@ def _request(self, url, data, log_exceptions=True, attempts=3, allow_null=True):
return response
raise RPCError("Received null response from remote scheduler %r" % self._url)

def ping(self, worker):
# just one attempt, keep-alive thread will keep trying anyway
self._request('/api/ping', {'worker': worker}, attempts=1)

def add_task(self, worker, task_id, status=PENDING, runnable=True,
deps=None, new_deps=None, expl=None, resources=None, priority=0,
family='', module=None, params=None, assistant=False,
tracking_url=None):
self._request('/api/add_task', {
'task_id': task_id,
'worker': worker,
'status': status,
'runnable': runnable,
'deps': deps,
'new_deps': new_deps,
'expl': expl,
'resources': resources,
'priority': priority,
'family': family,
'module': module,
'params': params,
'assistant': assistant,
'tracking_url': tracking_url,
})

def get_work(self, worker, host=None, assistant=False, current_tasks=None):
return self._request(
'/api/get_work',
{
'worker': worker,
'host': host,
'assistant': assistant,
'current_tasks': current_tasks,
},
allow_null=False,
)

def graph(self):
return self._request('/api/graph', {})

def dep_graph(self, task_id, include_done=True):
return self._request('/api/dep_graph', {'task_id': task_id, 'include_done': include_done})

def inverse_dep_graph(self, task_id, include_done=True):
return self._request('/api/inverse_dep_graph', {
'task_id': task_id, 'include_done': include_done})

def task_list(self, status, upstream_status, search=None):
return self._request('/api/task_list', {
'search': search,
'status': status,
'upstream_status': upstream_status,
})

def worker_list(self):
return self._request('/api/worker_list', {})

def resource_list(self):
return self._request('/api/resource_list', {})

def task_search(self, task_str):
return self._request('/api/task_search', {'task_str': task_str})

def fetch_error(self, task_id):
return self._request('/api/fetch_error', {'task_id': task_id})

def add_worker(self, worker, info):
return self._request('/api/add_worker', {'worker': worker, 'info': info})

def disable_worker(self, worker):
return self._request('/api/disable_worker', {'worker': worker})

def update_resources(self, **resources):
return self._request('/api/update_resources', resources)

def prune(self):
return self._request('/api/prune', {})

def re_enable_task(self, task_id):
return self._request('/api/re_enable_task', {'task_id': task_id})

def set_task_status_message(self, task_id, status_message):
self._request('/api/set_task_status_message', {
'task_id': task_id,
'status_message': status_message
})

def get_task_status_message(self, task_id):
return self._request('/api/get_task_status_message', {'task_id': task_id})
for method_name, method in RPC_METHODS.items():
setattr(RemoteScheduler, method_name, method)
62 changes: 57 additions & 5 deletions luigi/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
"""

import collections
import inspect

try:
import cPickle as pickle
except ImportError:
Expand Down Expand Up @@ -77,6 +79,36 @@ class Scheduler(object):

TASK_FAMILY_RE = re.compile(r'([^(_]+)[(_]')

RPC_METHODS = {}


def rpc_method(**request_args):
def _rpc_method(fn):
# If request args are passed, return this function again for use as
# the decorator function with the request args attached.
fn_args = inspect.getargspec(fn)

assert not fn_args.varargs
assert fn_args.args[0] == 'self'
all_args = fn_args.args[1:]
defaults = dict(zip(reversed(all_args), reversed(fn_args.defaults or ())))
required_args = frozenset(arg for arg in all_args if arg not in defaults)
fn_name = fn.__name__

@functools.wraps(fn)
def rpc_func(self, *args, **kwargs):
actual_args = defaults.copy()
actual_args.update(dict(zip(all_args, args)))
actual_args.update(kwargs)
if not all(arg in actual_args for arg in required_args):
raise TypeError('{} takes {} arguments ({} given)'.format(
fn_name, len(all_args), len(actual_args)))
return self._request('/api/{}'.format(fn_name), actual_args, **request_args)

RPC_METHODS[fn_name] = rpc_func
return fn
return _rpc_method


class scheduler(Config):
# TODO(erikbern): the config_path is needed for backwards compatilibity. We
Expand Down Expand Up @@ -521,6 +553,7 @@ def load(self):
def dump(self):
self._state.dump()

@rpc_method()
def prune(self):
logger.info("Starting pruning of task graph")
self._prune_workers()
Expand Down Expand Up @@ -575,18 +608,20 @@ def _update_priority(self, task, prio, worker):
if t is not None and prio > t.priority:
self._update_priority(t, prio, worker)

@rpc_method()
def add_task(self, task_id=None, status=PENDING, runnable=True,
deps=None, new_deps=None, expl=None, resources=None,
priority=0, family='', module=None, params=None,
assistant=False, tracking_url=None, **kwargs):
assistant=False, tracking_url=None, worker=None, **kwargs):
"""
* add task identified by task_id if it doesn't exist
* if deps is not None, update dependency list
* update status of task
* add additional workers/stakeholders
* update priority when needed
"""
worker_id = kwargs['worker']
assert worker is not None
worker_id = worker
worker_enabled = self.update(worker_id)

if worker_enabled:
Expand Down Expand Up @@ -655,12 +690,15 @@ def add_task(self, task_id=None, status=PENDING, runnable=True,
self._state.get_worker(worker_id).tasks.add(task)
task.runnable = runnable

@rpc_method()
def add_worker(self, worker, info, **kwargs):
self._state.get_worker(worker).add_info(info)

@rpc_method()
def disable_worker(self, worker):
self._state.disable_workers({worker})

@rpc_method()
def update_resources(self, **resources):
if self._resources is None:
self._resources = {}
Expand Down Expand Up @@ -706,7 +744,8 @@ def _schedulable(self, task):
def _retry_time(self, task, config):
return time.time() + config.retry_delay

def get_work(self, host=None, assistant=False, current_tasks=None, **kwargs):
@rpc_method(allow_null=False)
def get_work(self, host=None, assistant=False, current_tasks=None, worker=None, **kwargs):
# TODO: remove any expired nodes

# Algo: iterate over all nodes, find the highest priority node no dependencies and available
Expand All @@ -723,7 +762,8 @@ def get_work(self, host=None, assistant=False, current_tasks=None, **kwargs):
if self._config.prune_on_get_work:
self.prune()

worker_id = kwargs['worker']
assert worker is not None
worker_id = worker
# Return remaining tasks that have no FAILED descendants
self.update(worker_id, {'host': host}, get_work=True)
if assistant:
Expand Down Expand Up @@ -818,6 +858,7 @@ def get_work(self, host=None, assistant=False, current_tasks=None, **kwargs):

return reply

@rpc_method(attempts=1)
def ping(self, **kwargs):
worker_id = kwargs['worker']
self.update(worker_id)
Expand Down Expand Up @@ -873,6 +914,7 @@ def _serialize_task(self, task_id, include_deps=True, deps=None):
ret['deps'] = list(task.deps if deps is None else deps)
return ret

@rpc_method()
def graph(self, **kwargs):
self.prune()
serialized = {}
Expand Down Expand Up @@ -948,12 +990,14 @@ def dep_func(t):

return serialized

@rpc_method()
def dep_graph(self, task_id, include_done=True, **kwargs):
self.prune()
if not self._state.has_task(task_id):
return {}
return self._traverse_graph(task_id, include_done=include_done)

@rpc_method()
def inverse_dep_graph(self, task_id, include_done=True, **kwargs):
self.prune()
if not self._state.has_task(task_id):
Expand All @@ -965,7 +1009,8 @@ def inverse_dep_graph(self, task_id, include_done=True, **kwargs):
return self._traverse_graph(
task_id, dep_func=lambda t: inverse_graph[t.id], include_done=include_done)

def task_list(self, status, upstream_status, limit=True, search=None, **kwargs):
@rpc_method()
def task_list(self, status='', upstream_status='', limit=True, search=None, **kwargs):
"""
Query for a subset of tasks by status.
"""
Expand Down Expand Up @@ -996,6 +1041,7 @@ def _first_task_display_name(self, worker):
else:
return task_id

@rpc_method()
def worker_list(self, include_running=True, **kwargs):
self.prune()
workers = [
Expand Down Expand Up @@ -1027,6 +1073,7 @@ def worker_list(self, include_running=True, **kwargs):
worker['running'] = tasks
return workers

@rpc_method()
def resource_list(self):
"""
Resources usage info and their consumers (tasks).
Expand Down Expand Up @@ -1062,6 +1109,7 @@ def resources(self):
ret[resource]['used'] = 0
return ret

@rpc_method()
def task_search(self, task_str, **kwargs):
"""
Query for a subset of tasks by task_id.
Expand All @@ -1077,6 +1125,7 @@ def task_search(self, task_str, **kwargs):
result[task.status][task.id] = serialized
return result

@rpc_method()
def re_enable_task(self, task_id):
serialized = {}
task = self._state.get_task(task_id)
Expand All @@ -1085,18 +1134,21 @@ def re_enable_task(self, task_id):
serialized = self._serialize_task(task_id)
return serialized

@rpc_method()
def fetch_error(self, task_id, **kwargs):
if self._state.has_task(task_id):
task = self._state.get_task(task_id)
return {"taskId": task_id, "error": task.expl, 'displayName': task.pretty_id}
else:
return {"taskId": task_id, "error": ""}

@rpc_method()
def set_task_status_message(self, task_id, status_message):
if self._state.has_task(task_id):
task = self._state.get_task(task_id)
task.status_message = status_message

@rpc_method()
def get_task_status_message(self, task_id):
if self._state.has_task(task_id):
task = self._state.get_task(task_id)
Expand Down
24 changes: 2 additions & 22 deletions luigi/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@
import tornado.netutil
import tornado.web

from luigi.scheduler import CentralPlannerScheduler

from luigi.scheduler import CentralPlannerScheduler, RPC_METHODS

logger = logging.getLogger("luigi.server")

Expand All @@ -68,26 +67,7 @@ def initialize(self, scheduler):
self._scheduler = scheduler

def get(self, method):
if method not in [
'add_task',
'add_worker',
'dep_graph',
'disable_worker',
'fetch_error',
'get_work',
'graph',
'inverse_dep_graph',
'ping',
'prune',
're_enable_task',
'resource_list',
'task_list',
'task_search',
'update_resources',
'worker_list',
'set_task_status_message',
'get_task_status_message',
]:
if method not in RPC_METHODS:
self.send_error(404)
return
payload = self.get_argument('data', default="{}")
Expand Down