diff --git a/luigi/rpc.py b/luigi/rpc.py
index 4a80b7a5db..c0ae508283 100644
--- a/luigi/rpc.py
+++ b/luigi/rpc.py
@@ -19,7 +19,6 @@
 rpc.py implements the client side of it, server.py implements the server side.
 See :doc:`/central_scheduler` for more info.
 """
-
 import json
 import logging
 import socket
@@ -30,8 +29,7 @@
 from luigi.six.moves.urllib.error import URLError
 
 from luigi import configuration
-from luigi.scheduler import PENDING, Scheduler
-
+from luigi.scheduler import Scheduler, RPC_METHODS
 
 HAS_UNIX_SOCKET = True
 HAS_REQUESTS = True
@@ -149,92 +147,6 @@ def _request(self, url, data, log_exceptions=True, attempts=3, allow_null=True):
             return response
         raise RPCError("Received null response from remote scheduler %r" % self._url)
 
-    def ping(self, worker):
-        # just one attempt, keep-alive thread will keep trying anyway
-        self._request('/api/ping', {'worker': worker}, attempts=1)
-
-    def add_task(self, worker, task_id, status=PENDING, runnable=True,
-                 deps=None, new_deps=None, expl=None, resources=None, priority=0,
-                 family='', module=None, params=None, assistant=False,
-                 tracking_url=None):
-        self._request('/api/add_task', {
-            'task_id': task_id,
-            'worker': worker,
-            'status': status,
-            'runnable': runnable,
-            'deps': deps,
-            'new_deps': new_deps,
-            'expl': expl,
-            'resources': resources,
-            'priority': priority,
-            'family': family,
-            'module': module,
-            'params': params,
-            'assistant': assistant,
-            'tracking_url': tracking_url,
-        })
-
-    def get_work(self, worker, host=None, assistant=False, current_tasks=None):
-        return self._request(
-            '/api/get_work',
-            {
-                'worker': worker,
-                'host': host,
-                'assistant': assistant,
-                'current_tasks': current_tasks,
-            },
-            allow_null=False,
-        )
-
-    def graph(self):
-        return self._request('/api/graph', {})
-
-    def dep_graph(self, task_id, include_done=True):
-        return self._request('/api/dep_graph', {'task_id': task_id, 'include_done': include_done})
-
-    def inverse_dep_graph(self, task_id, include_done=True):
-        return self._request('/api/inverse_dep_graph', {
-            'task_id': task_id, 'include_done': include_done})
-
-    def task_list(self, status, upstream_status, search=None):
-        return self._request('/api/task_list', {
-            'search': search,
-            'status': status,
-            'upstream_status': upstream_status,
-        })
-
-    def worker_list(self):
-        return self._request('/api/worker_list', {})
-
-    def resource_list(self):
-        return self._request('/api/resource_list', {})
-
-    def task_search(self, task_str):
-        return self._request('/api/task_search', {'task_str': task_str})
-
-    def fetch_error(self, task_id):
-        return self._request('/api/fetch_error', {'task_id': task_id})
-
-    def add_worker(self, worker, info):
-        return self._request('/api/add_worker', {'worker': worker, 'info': info})
-
-    def disable_worker(self, worker):
-        return self._request('/api/disable_worker', {'worker': worker})
-
-    def update_resources(self, **resources):
-        return self._request('/api/update_resources', resources)
-
-    def prune(self):
-        return self._request('/api/prune', {})
-
-    def re_enable_task(self, task_id):
-        return self._request('/api/re_enable_task', {'task_id': task_id})
-
-    def set_task_status_message(self, task_id, status_message):
-        self._request('/api/set_task_status_message', {
-            'task_id': task_id,
-            'status_message': status_message
-        })
-
-    def get_task_status_message(self, task_id):
-        return self._request('/api/get_task_status_message', {'task_id': task_id})
+
+for method_name, method in RPC_METHODS.items():
+    setattr(RemoteScheduler, method_name, method)
diff --git a/luigi/scheduler.py b/luigi/scheduler.py
index 00d4abc5f9..c1ad59911d 100644
--- a/luigi/scheduler.py
+++ b/luigi/scheduler.py
@@ -22,6 +22,8 @@
 """
 
 import collections
+import inspect
+
 try:
     import cPickle as pickle
 except ImportError:
@@ -77,6 +79,36 @@ class Scheduler(object):
 TASK_FAMILY_RE = re.compile(r'([^(_]+)[(_]')
 
 
+RPC_METHODS = {}
+
+
+def rpc_method(**request_args):
+    def _rpc_method(fn):
+        # If request args are passed, return this function again for use as
+        # the decorator function with the request args attached.
+        fn_args = inspect.getargspec(fn)
+
+        assert not fn_args.varargs
+        assert fn_args.args[0] == 'self'
+        all_args = fn_args.args[1:]
+        defaults = dict(zip(reversed(all_args), reversed(fn_args.defaults or ())))
+        required_args = frozenset(arg for arg in all_args if arg not in defaults)
+        fn_name = fn.__name__
+
+        @functools.wraps(fn)
+        def rpc_func(self, *args, **kwargs):
+            actual_args = defaults.copy()
+            actual_args.update(dict(zip(all_args, args)))
+            actual_args.update(kwargs)
+            if not all(arg in actual_args for arg in required_args):
+                raise TypeError('{} takes {} arguments ({} given)'.format(
+                    fn_name, len(all_args), len(actual_args)))
+            return self._request('/api/{}'.format(fn_name), actual_args, **request_args)
+
+        RPC_METHODS[fn_name] = rpc_func
+        return fn
+    return _rpc_method
+
 
 class scheduler(Config):
     # TODO(erikbern): the config_path is needed for backwards compatilibity. We
@@ -521,6 +553,7 @@ def load(self):
     def dump(self):
         self._state.dump()
 
+    @rpc_method()
     def prune(self):
         logger.info("Starting pruning of task graph")
         self._prune_workers()
@@ -575,10 +608,11 @@ def _update_priority(self, task, prio, worker):
         if t is not None and prio > t.priority:
             self._update_priority(t, prio, worker)
 
+    @rpc_method()
     def add_task(self, task_id=None, status=PENDING, runnable=True,
                  deps=None, new_deps=None, expl=None, resources=None,
                  priority=0, family='', module=None, params=None,
-                 assistant=False, tracking_url=None, **kwargs):
+                 assistant=False, tracking_url=None, worker=None, **kwargs):
         """
         * add task identified by task_id if it doesn't exist
         * if deps is not None, update dependency list
@@ -586,7 +620,8 @@ def add_task(self, task_id=None, status=PENDING, runnable=True,
         * add additional workers/stakeholders
         * update priority when needed
         """
-        worker_id = kwargs['worker']
+        assert worker is not None
+        worker_id = worker
         worker_enabled = self.update(worker_id)
 
         if worker_enabled:
@@ -655,12 +690,15 @@ def add_task(self, task_id=None, status=PENDING, runnable=True,
             self._state.get_worker(worker_id).tasks.add(task)
         task.runnable = runnable
 
+    @rpc_method()
     def add_worker(self, worker, info, **kwargs):
         self._state.get_worker(worker).add_info(info)
 
+    @rpc_method()
     def disable_worker(self, worker):
         self._state.disable_workers({worker})
 
+    @rpc_method()
     def update_resources(self, **resources):
         if self._resources is None:
             self._resources = {}
@@ -706,7 +744,8 @@ def _schedulable(self, task):
     def _retry_time(self, task, config):
         return time.time() + config.retry_delay
 
-    def get_work(self, host=None, assistant=False, current_tasks=None, **kwargs):
+    @rpc_method(allow_null=False)
+    def get_work(self, host=None, assistant=False, current_tasks=None, worker=None, **kwargs):
         # TODO: remove any expired nodes
 
         # Algo: iterate over all nodes, find the highest priority node no dependencies and available
@@ -723,7 +762,8 @@ def get_work(self, host=None, assistant=False, current_tasks=None, **kwargs):
         if self._config.prune_on_get_work:
             self.prune()
 
-        worker_id = kwargs['worker']
+        assert worker is not None
+        worker_id = worker
         # Return remaining tasks that have no FAILED descendants
         self.update(worker_id, {'host': host}, get_work=True)
         if assistant:
@@ -818,6 +858,7 @@ def get_work(self, host=None, assistant=False, current_tasks=None, **kwargs):
 
         return reply
 
+    @rpc_method(attempts=1)
     def ping(self, **kwargs):
         worker_id = kwargs['worker']
         self.update(worker_id)
@@ -873,6 +914,7 @@ def _serialize_task(self, task_id, include_deps=True, deps=None):
             ret['deps'] = list(task.deps if deps is None else deps)
         return ret
 
+    @rpc_method()
     def graph(self, **kwargs):
         self.prune()
         serialized = {}
@@ -948,12 +990,14 @@ def dep_func(t):
 
         return serialized
 
+    @rpc_method()
     def dep_graph(self, task_id, include_done=True, **kwargs):
         self.prune()
         if not self._state.has_task(task_id):
             return {}
         return self._traverse_graph(task_id, include_done=include_done)
 
+    @rpc_method()
     def inverse_dep_graph(self, task_id, include_done=True, **kwargs):
         self.prune()
         if not self._state.has_task(task_id):
@@ -965,7 +1009,8 @@ def inverse_dep_graph(self, task_id, include_done=True, **kwargs):
         return self._traverse_graph(
             task_id, dep_func=lambda t: inverse_graph[t.id], include_done=include_done)
 
-    def task_list(self, status, upstream_status, limit=True, search=None, **kwargs):
+    @rpc_method()
+    def task_list(self, status='', upstream_status='', limit=True, search=None, **kwargs):
         """
         Query for a subset of tasks by status.
         """
@@ -996,6 +1041,7 @@ def _first_task_display_name(self, worker):
         else:
             return task_id
 
+    @rpc_method()
     def worker_list(self, include_running=True, **kwargs):
         self.prune()
         workers = [
@@ -1027,6 +1073,7 @@ def worker_list(self, include_running=True, **kwargs):
             worker['running'] = tasks
         return workers
 
+    @rpc_method()
     def resource_list(self):
         """
         Resources usage info and their consumers (tasks).
@@ -1062,6 +1109,7 @@ def resources(self):
                 ret[resource]['used'] = 0
         return ret
 
+    @rpc_method()
     def task_search(self, task_str, **kwargs):
         """
         Query for a subset of tasks by task_id.
@@ -1077,6 +1125,7 @@ def task_search(self, task_str, **kwargs):
             result[task.status][task.id] = serialized
         return result
 
+    @rpc_method()
     def re_enable_task(self, task_id):
         serialized = {}
         task = self._state.get_task(task_id)
@@ -1085,6 +1134,7 @@ def re_enable_task(self, task_id):
             serialized = self._serialize_task(task_id)
         return serialized
 
+    @rpc_method()
     def fetch_error(self, task_id, **kwargs):
         if self._state.has_task(task_id):
             task = self._state.get_task(task_id)
@@ -1092,11 +1142,13 @@ def fetch_error(self, task_id, **kwargs):
         else:
             return {"taskId": task_id, "error": ""}
 
+    @rpc_method()
     def set_task_status_message(self, task_id, status_message):
         if self._state.has_task(task_id):
             task = self._state.get_task(task_id)
             task.status_message = status_message
 
+    @rpc_method()
     def get_task_status_message(self, task_id):
         if self._state.has_task(task_id):
             task = self._state.get_task(task_id)
diff --git a/luigi/server.py b/luigi/server.py
index ebcc234ecc..9c0c0ed0d4 100644
--- a/luigi/server.py
+++ b/luigi/server.py
@@ -53,8 +53,7 @@
 import tornado.netutil
 import tornado.web
 
-from luigi.scheduler import CentralPlannerScheduler
-
+from luigi.scheduler import CentralPlannerScheduler, RPC_METHODS
 
 logger = logging.getLogger("luigi.server")
 
@@ -68,26 +67,7 @@ def initialize(self, scheduler):
         self._scheduler = scheduler
 
     def get(self, method):
-        if method not in [
-            'add_task',
-            'add_worker',
-            'dep_graph',
-            'disable_worker',
-            'fetch_error',
-            'get_work',
-            'graph',
-            'inverse_dep_graph',
-            'ping',
-            'prune',
-            're_enable_task',
-            'resource_list',
-            'task_list',
-            'task_search',
-            'update_resources',
-            'worker_list',
-            'set_task_status_message',
-            'get_task_status_message',
-        ]:
+        if method not in RPC_METHODS:
             self.send_error(404)
             return
         payload = self.get_argument('data', default="{}")
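
Below is a minimal, self-contained sketch of the registration pattern this patch introduces: a decorator applied to the server-side scheduler methods builds a client-side proxy that posts the call's arguments to /api/<method name>, records the proxy in RPC_METHODS keyed by the method name, and the client class then picks up every registered proxy via setattr, which is what rpc.py and server.py rely on above. The names ToyScheduler, ToyClient and the echoing _request below are illustrative stand-ins, not luigi's real classes.

import functools
import inspect

RPC_METHODS = {}


def rpc_method(**request_args):
    def _rpc_method(fn):
        # Inspect the server-side signature once, at decoration time.
        all_args = inspect.getfullargspec(fn).args[1:]  # drop 'self'

        @functools.wraps(fn)
        def rpc_func(self, *args, **kwargs):
            # Client-side proxy: forward the call as an HTTP request payload.
            payload = dict(zip(all_args, args))
            payload.update(kwargs)
            return self._request('/api/{}'.format(fn.__name__), payload, **request_args)

        RPC_METHODS[fn.__name__] = rpc_func
        return fn  # the server keeps the original, undecorated method
    return _rpc_method


class ToyScheduler(object):
    # Server-side method; decorating it registers a matching client proxy.
    @rpc_method(attempts=1)
    def ping(self, worker=None):
        return 'pong from {}'.format(worker)


class ToyClient(object):
    def _request(self, url, data, **request_args):
        # Stand-in for RemoteScheduler._request: just echo what would be sent.
        return (url, data, request_args)


# Attach every registered proxy to the client class, as rpc.py now does
# for RemoteScheduler.
for method_name, method in RPC_METHODS.items():
    setattr(ToyClient, method_name, method)

print(ToyClient().ping(worker='w1'))

Running the sketch prints ('/api/ping', {'worker': 'w1'}, {'attempts': 1}): the endpoint, the JSON-serializable payload, and the per-method request options (here attempts=1, mirroring the decoration of ping in the patch) that the client's _request would receive. Because the server lists valid endpoints by checking RPC_METHODS, adding a new RPC only requires decorating the scheduler method; the client proxy and the server-side whitelist follow automatically.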