diff --git a/core/dbt/contracts/rpc.py b/core/dbt/contracts/rpc.py index aa7165a2cc2..385c4bf2ad7 100644 --- a/core/dbt/contracts/rpc.py +++ b/core/dbt/contracts/rpc.py @@ -21,6 +21,7 @@ TaskTags = Optional[Dict[str, Any]] +TaskID = uuid.UUID # Inputs @@ -70,6 +71,56 @@ class RPCNoParameters(RPCParameters): pass +@dataclass +class KillParameters(RPCParameters): + task_id: TaskID + + +@dataclass +class PollParameters(RPCParameters): + request_token: TaskID + logs: bool = False + logs_start: int = 0 + + +@dataclass +class PSParameters(RPCParameters): + active: bool = True + completed: bool = False + + +@dataclass +class StatusParameters(RPCParameters): + pass + + +@dataclass +class GCSettings(JsonSchemaMixin): + # start evicting the longest-ago-ended tasks here + maxsize: int + # start evicting all tasks before now - auto_reap_age when we have this + # many tasks in the table + reapsize: int + # a positive timedelta indicating how far back we should go + auto_reap_age: timedelta + + +@dataclass +class GCParameters(RPCParameters): + """The gc endpoint takes three arguments, any of which may be present: + + - task_ids: An optional list of task ID UUIDs to try to GC + - before: If provided, should be a datetime string. All tasks that finished + before that datetime will be GCed + - settings: If provided, should be a GCSettings object in JSON form. It + will be applied to the task manager before GC starts. By default the + existing gc settings remain. + """ + task_ids: Optional[List[TaskID]] + before: Optional[datetime] + settings: Optional[GCSettings] + + # Outputs @dataclass @@ -133,12 +184,13 @@ class GCResultState(StrEnum): @dataclass -class GCResultSet(JsonSchemaMixin): - deleted: List[uuid.UUID] = field(default_factory=list) - missing: List[uuid.UUID] = field(default_factory=list) - running: List[uuid.UUID] = field(default_factory=list) +class GCResult(RemoteResult): + logs: List[LogMessage] = field(default_factory=list) + deleted: List[TaskID] = field(default_factory=list) + missing: List[TaskID] = field(default_factory=list) + running: List[TaskID] = field(default_factory=list) - def add_result(self, task_id: uuid.UUID, status: GCResultState): + def add_result(self, task_id: TaskID, status: GCResultState): if status == GCResultState.Missing: self.missing.append(task_id) elif status == GCResultState.Running: @@ -150,18 +202,6 @@ def add_result(self, task_id: uuid.UUID, status: GCResultState): f'Got invalid status in add_result: {status}' ) - -@dataclass -class GCSettings(JsonSchemaMixin): - # start evicting the longest-ago-ended tasks here - maxsize: int - # start evicting all tasks before now - auto_reap_age when we have this - # many tasks in the table - reapsize: int - # a positive timedelta indicating how far back we should go - auto_reap_age: timedelta - - # Task management types @@ -220,7 +260,7 @@ def finished(self) -> bool: @dataclass class TaskRow(JsonSchemaMixin): - task_id: uuid.UUID + task_id: TaskID request_id: Union[str, int] request_source: str method: str @@ -233,7 +273,7 @@ class TaskRow(JsonSchemaMixin): @dataclass -class PSResult(JsonSchemaMixin): +class PSResult(RemoteResult): rows: List[TaskRow] @@ -245,8 +285,9 @@ class KillResultStatus(StrEnum): @dataclass -class KillResult(JsonSchemaMixin): - status: KillResultStatus +class KillResult(RemoteResult): + status: KillResultStatus = KillResultStatus.Missing + logs: List[LogMessage] = field(default_factory=list) # this is kind of carefuly structured: BlocksManifestTasks is implied by @@ -256,13 +297,14 @@ class RemoteMethodFlags(enum.Flag): BlocksManifestTasks = 1 RequiresConfigReloadBefore = 3 RequiresManifestReloadAfter = 5 + Builtin = 8 # Polling types @dataclass -class PollResult(JsonSchemaMixin): +class PollResult(RemoteResult): tags: TaskTags = None status: TaskHandlerState = TaskHandlerState.NotStarted @@ -416,9 +458,9 @@ class ManifestStatus(StrEnum): @dataclass -class LastParse(JsonSchemaMixin): - status: ManifestStatus +class LastParse(RemoteResult): + status: ManifestStatus = ManifestStatus.Init + logs: List[LogMessage] = field(default_factory=list) error: Optional[Dict[str, Any]] = None - logs: Optional[List[Dict[str, Any]]] = None timestamp: datetime = field(default_factory=datetime.utcnow) pid: int = field(default_factory=os.getpid) diff --git a/core/dbt/rpc/builtins.py b/core/dbt/rpc/builtins.py new file mode 100644 index 00000000000..b9e97659a9f --- /dev/null +++ b/core/dbt/rpc/builtins.py @@ -0,0 +1,227 @@ +import os +import signal +from datetime import datetime +from typing import Type, Union, Any, List + +import dbt.exceptions +from dbt.contracts.rpc import ( + TaskTags, + StatusParameters, + LastParse, + GCParameters, + GCResult, + KillParameters, + KillResult, + KillResultStatus, + PSParameters, + TaskRow, + PSResult, + RemoteExecutionResult, + RemoteRunResult, + RemoteCompileResult, + RemoteCatalogResults, + RemoteEmptyResult, + PollParameters, + PollResult, + PollInProgressResult, + PollKilledResult, + PollExecuteCompleteResult, + PollRunCompleteResult, + PollCompileCompleteResult, + PollCatalogCompleteResult, + PollRemoteEmptyCompleteResult, + TaskHandlerState, +) +from dbt.logger import LogMessage +from dbt.rpc.error import dbt_error, RPCException +from dbt.rpc.method import RemoteBuiltinMethod +from dbt.rpc.task_handler import RequestTaskHandler + + +class GC(RemoteBuiltinMethod[GCParameters, GCResult]): + METHOD_NAME = 'gc' + + def set_args(self, params: GCParameters): + super().set_args(params) + + def handle_request(self) -> GCResult: + if self.params is None: + raise dbt.exceptions.InternalException('GC: params not set') + return self.task_manager.gc_safe( + task_ids=self.params.task_ids, + before=self.params.before, + settings=self.params.settings, + ) + + +class Kill(RemoteBuiltinMethod[KillParameters, KillResult]): + METHOD_NAME = 'kill' + + def set_args(self, params: KillParameters): + super().set_args(params) + + def handle_request(self) -> KillResult: + if self.params is None: + raise dbt.exceptions.InternalException('Kill: params not set') + result = KillResult() + task: RequestTaskHandler + try: + task = self.task_manager.get_request(self.params.task_id) + except dbt.exceptions.UnknownAsyncIDException: + # nothing to do! + return result + + result.status = KillResultStatus.NotStarted + + if task.process is None: + return result + pid = task.process.pid + if pid is None: + return result + + if task.process.is_alive(): + result.status = KillResultStatus.Killed + task.ended = datetime.utcnow() + os.kill(pid, signal.SIGINT) + task.state = TaskHandlerState.Killed + else: + result.status = KillResultStatus.Finished + # the status must be "Completed" + + return result + + +class Status(RemoteBuiltinMethod[StatusParameters, LastParse]): + METHOD_NAME = 'status' + + def set_args(self, params: StatusParameters): + super().set_args(params) + + def handle_request(self) -> LastParse: + return self.task_manager.last_parse + + +class PS(RemoteBuiltinMethod[PSParameters, PSResult]): + METHOD_NAME = 'ps' + + def set_args(self, params: PSParameters): + super().set_args(params) + + def keep(self, row: TaskRow): + if self.params is None: + raise dbt.exceptions.InternalException('PS: params not set') + if row.state.finished and self.params.completed: + return True + elif not row.state.finished and self.params.active: + return True + else: + return False + + def handle_request(self) -> PSResult: + rows = [ + row for row in self.task_manager.task_table() if self.keep(row) + ] + rows.sort(key=lambda r: (r.state, r.start, r.method)) + result = PSResult(rows=rows, logs=[]) + return result + + +def poll_complete( + status: TaskHandlerState, result: Any, tags: TaskTags +) -> PollResult: + if status not in (TaskHandlerState.Success, TaskHandlerState.Failed): + raise dbt.exceptions.InternalException( + 'got invalid result status in poll_complete: {}'.format(status) + ) + + cls: Type[Union[ + PollExecuteCompleteResult, + PollRunCompleteResult, + PollCompileCompleteResult, + PollCatalogCompleteResult, + PollRemoteEmptyCompleteResult, + ]] + + if isinstance(result, RemoteExecutionResult): + cls = PollExecuteCompleteResult + # order matters here, as RemoteRunResult subclasses RemoteCompileResult + elif isinstance(result, RemoteRunResult): + cls = PollRunCompleteResult + elif isinstance(result, RemoteCompileResult): + cls = PollCompileCompleteResult + elif isinstance(result, RemoteCatalogResults): + cls = PollCatalogCompleteResult + elif isinstance(result, RemoteEmptyResult): + cls = PollRemoteEmptyCompleteResult + else: + raise dbt.exceptions.InternalException( + 'got invalid result in poll_complete: {}'.format(result) + ) + return cls.from_result(status, result, tags) + + +class Poll(RemoteBuiltinMethod[PollParameters, PollResult]): + METHOD_NAME = 'poll' + + def set_args(self, params: PollParameters): + super().set_args(params) + + def handle_request(self) -> PollResult: + if self.params is None: + raise dbt.exceptions.InternalException('Poll: params not set') + task_id = self.params.request_token + task = self.task_manager.get_request(task_id) + + task_logs: List[LogMessage] = [] + if self.params.logs: + task_logs = task.logs[self.params.logs_start:] + + # Get a state and store it locally so we ignore updates to state, + # otherwise things will get confusing. States should always be + # "forward-compatible" so if the state has transitioned to error/result + # but we aren't there yet, the logs will still be valid. + state = task.state + if state <= TaskHandlerState.Running: + return PollInProgressResult( + status=state, + tags=task.tags, + logs=task_logs, + ) + elif state == TaskHandlerState.Error: + err = task.error + if err is None: + exc = dbt.exceptions.InternalException( + f'At end of task {task_id}, error state but error is None' + ) + raise RPCException.from_error( + dbt_error(exc, logs=[l.to_dict() for l in task_logs]) + ) + # the exception has logs already attached from the child, don't + # overwrite those + raise err + elif state in (TaskHandlerState.Success, TaskHandlerState.Failed): + + if task.result is None: + exc = dbt.exceptions.InternalException( + f'At end of task {task_id}, state={state} but result is ' + 'None' + ) + raise RPCException.from_error( + dbt_error(exc, logs=[l.to_dict() for l in task_logs]) + ) + return poll_complete( + status=state, + result=task.result, + tags=task.tags, + ) + elif state == TaskHandlerState.Killed: + return PollKilledResult( + status=state, tags=task.tags, logs=task_logs + ) + else: + exc = dbt.exceptions.InternalException( + f'Got unknown value state={state} for task {task_id}' + ) + raise RPCException.from_error( + dbt_error(exc, logs=[l.to_dict() for l in task_logs]) + ) diff --git a/core/dbt/rpc/gc.py b/core/dbt/rpc/gc.py new file mode 100644 index 00000000000..291549dcdf8 --- /dev/null +++ b/core/dbt/rpc/gc.py @@ -0,0 +1,161 @@ +import multiprocessing +import operator +from datetime import datetime, timedelta +from typing import ( + MutableMapping, Optional, List, Iterable, Tuple, Union, +) +from typing_extensions import Protocol + +import dbt.exceptions +import dbt.flags +from dbt.contracts.rpc import ( + GCSettings, + GCResultState, + GCResult, + TaskHandlerState, + TaskID, + TaskTags, +) + +# import this to make sure our timedelta encoder is registered +from dbt import helper_types # noqa + + +class Collectible(Protocol): + started: Optional[datetime] + ended: Optional[datetime] + state: TaskHandlerState + task_id: TaskID + process: Optional[multiprocessing.Process] + + @property + def request_id(self) -> Union[str, int]: + pass + + @property + def request_source(self) -> str: + pass + + @property + def timeout(self) -> Optional[float]: + pass + + @property + def method(self) -> str: + pass + + @property + def tags(self) -> Optional[TaskTags]: + pass + + +class GarbageCollector: + def __init__( + self, + active_tasks: MutableMapping[TaskID, Collectible], + settings: Optional[GCSettings] = None, + ) -> None: + self.active_tasks: MutableMapping[TaskID, Collectible] = active_tasks + self.settings: GCSettings + + if settings is None: + self.settings = GCSettings( + maxsize=1000, reapsize=500, auto_reap_age=timedelta(days=30) + ) + else: + self.settings = settings + + def _remove_task_if_finished(self, task_id: TaskID) -> GCResultState: + """Remove the task if it was finished. Raises a KeyError if the entry + is removed during operation (so hold the lock). + """ + if task_id not in self.active_tasks: + return GCResultState.Missing + + task = self.active_tasks[task_id] + if not task.state.finished: + return GCResultState.Running + + del self.active_tasks[task_id] + return GCResultState.Deleted + + def _get_before_list(self, when: datetime) -> List[TaskID]: + removals: List[TaskID] = [] + for task in self.active_tasks.values(): + if not task.state.finished: + continue + elif task.ended is None: + continue + elif task.ended < when: + removals.append(task.task_id) + + return removals + + def _get_oldest_ended_list(self, num: int) -> List[TaskID]: + candidates: List[Tuple[datetime, TaskID]] = [] + for task in self.active_tasks.values(): + if not task.state.finished: + continue + elif task.ended is None: + continue + else: + candidates.append((task.ended, task.task_id)) + candidates.sort(key=operator.itemgetter(0)) + return [task_id for _, task_id in candidates[:num]] + + def collect_task_id( + self, result: GCResult, task_id: TaskID + ) -> None: + """To collect a task ID, we just delete it from the tasks dict. + + You must hold the lock, as this mutates `tasks`. + """ + try: + status = self._remove_task_if_finished(task_id) + except KeyError: + # someone was mutating tasks while we had the lock, that's + # not right! + raise dbt.exceptions.InternalException( + 'Got a KeyError for task uuid={} during gc' + .format(task_id) + ) + + return result.add_result(task_id=task_id, status=status) + + def collect_multiple_task_ids( + self, task_ids: Iterable[TaskID] + ) -> GCResult: + result = GCResult() + for task_id in task_ids: + self.collect_task_id(result, task_id) + return result + + def collect_as_required(self) -> None: + to_remove: List[TaskID] = [] + num_tasks = len(self.active_tasks) + if num_tasks > self.settings.maxsize: + num = self.settings.maxsize - num_tasks + to_remove = self._get_oldest_ended_list(num) + elif num_tasks > self.settings.reapsize: + before = datetime.utcnow() - self.settings.auto_reap_age + to_remove = self._get_before_list(before) + + if to_remove: + self.collect_multiple_task_ids(to_remove) + + def collect_selected( + self, + task_ids: Optional[List[TaskID]] = None, + before: Optional[datetime] = None, + settings: Optional[GCSettings] = None, + ) -> GCResult: + to_gc = set() + + if task_ids is not None: + to_gc.update(task_ids) + if settings: + self.settings = settings + # we need the lock for this! + if before is not None: + to_gc.update(self._get_before_list(before)) + return self.collect_multiple_task_ids(to_gc) diff --git a/core/dbt/rpc/method.py b/core/dbt/rpc/method.py index c9e0d77d07c..fb1c6d83d59 100644 --- a/core/dbt/rpc/method.py +++ b/core/dbt/rpc/method.py @@ -1,7 +1,8 @@ import inspect from abc import abstractmethod -from typing import List, Optional, Type, TypeVar, Generic -from typing import Any # noqa +from typing import List, Optional, Type, TypeVar, Generic, Dict, Any + +from hologram import JsonSchemaMixin, ValidationError from dbt.contracts.rpc import RPCParameters, RemoteResult, RemoteMethodFlags from dbt.exceptions import NotImplementedException, InternalException @@ -93,24 +94,55 @@ def __init__(self, args, config, manifest): self.manifest = manifest -class TaskList(List[Type[RemoteMethod]]): +class RemoteBuiltinMethod(RemoteMethod[Parameters, Result]): + def __init__(self, task_manager): + self.task_manager = task_manager + super().__init__(task_manager.args, task_manager.config) + self.params: Optional[Parameters] = None + + def set_args(self, params: Parameters): + self.params = params + + def __call__(self, **kwargs: Dict[str, Any]) -> JsonSchemaMixin: + try: + params = self.get_parameters().from_dict(kwargs) + except ValidationError as exc: + raise TypeError(exc) from exc + self.set_args(params) + return self.handle_request() + + +class TaskTypes(Dict[str, Type[RemoteMethod]]): def __init__( - self, - tasks: Optional[List[Type[RemoteMethod]]] = None - ): + self, tasks: Optional[List[Type[RemoteMethod]]] = None + ) -> None: task_list: List[Type[RemoteMethod]] if tasks is None: task_list = RemoteMethod.recursive_subclasses(named_only=True) else: task_list = tasks - return super().__init__(task_list) - - def manifest(self) -> List[Type[RemoteManifestMethod]]: - return [ - t for t in self if issubclass(t, RemoteManifestMethod) - ] - - def non_manifest(self) -> List[Type[RemoteMethod]]: - return [ - t for t in self if not issubclass(t, RemoteManifestMethod) - ] + super().__init__( + (t.METHOD_NAME, t) for t in task_list + if t.METHOD_NAME is not None + ) + + def manifest(self) -> Dict[str, Type[RemoteManifestMethod]]: + return { + k: t for k, t in self.items() + if issubclass(t, RemoteManifestMethod) + } + + def builtin(self) -> Dict[str, Type[RemoteBuiltinMethod]]: + return { + k: t for k, t in self.items() + if issubclass(t, RemoteBuiltinMethod) + } + + def non_manifest(self) -> Dict[str, Type[RemoteMethod]]: + return { + k: t for k, t in self.items() + if ( + not issubclass(t, RemoteManifestMethod) and + not issubclass(t, RemoteBuiltinMethod) + ) + } diff --git a/core/dbt/rpc/response_manager.py b/core/dbt/rpc/response_manager.py index a53a101564f..7bf9ae746f9 100644 --- a/core/dbt/rpc/response_manager.py +++ b/core/dbt/rpc/response_manager.py @@ -50,6 +50,9 @@ def __getitem__(self, key) -> Callable[..., Dict[str, Any]]: ) if handler is None: raise KeyError(key) + if callable(handler): + # either an error or a builtin + return handler elif isinstance(handler, RemoteMethod): # the handler must be a task. Wrap it in a task handler so it can # go async @@ -57,7 +60,10 @@ def __getitem__(self, key) -> Callable[..., Dict[str, Any]]: self.manager, handler, self.http_request, self.json_rpc_request ) else: - return handler + raise dbt.exceptions.InternalException( + f'Got an invalid handler from get_handler. Expected None, ' + f'callable, or RemoteMethod, got {handler}' + ) class ResponseManager(JSONRPCResponseManager): diff --git a/core/dbt/rpc/task_handler.py b/core/dbt/rpc/task_handler.py index a45b963d069..fd9536bb12a 100644 --- a/core/dbt/rpc/task_handler.py +++ b/core/dbt/rpc/task_handler.py @@ -17,7 +17,7 @@ cleanup_connections, load_plugin, register_adapter, ) from dbt.contracts.rpc import ( - RPCParameters, RemoteResult, TaskHandlerState, RemoteMethodFlags, + RPCParameters, RemoteResult, TaskHandlerState, RemoteMethodFlags, TaskTags, ) from dbt.logger import ( GLOBAL_LOGGER as logger, list_handler, LogMessage, OutputHandler, @@ -28,6 +28,7 @@ RPCException, timeout_error, ) +from dbt.rpc.gc import Collectible from dbt.rpc.logger import ( QueueSubscriber, QueueLogHandler, @@ -125,11 +126,11 @@ def set_parsing(self): pass def set_compile_exception( - self, exc: Exception, logs: List[Dict[str, Any]] + self, exc: Exception, logs: List[LogMessage] ): pass - def set_ready(self, logs: List[Dict[str, Any]]): + def set_ready(self, logs: List[LogMessage]): pass def add_request(self, request: 'RequestTaskHandler') -> Dict[str, Any]: @@ -153,16 +154,10 @@ def set_parse_state_with( try: yield except Exception as exc: - log_dicts = [r.to_dict() for r in logs()] - manager.set_compile_exception(exc, logs=log_dicts) - # re-raise to ensure any exception handlers above trigger. We might be - # in an API call that set the parse state, in which case we don't want - # to swallow the exception - it also should report its failure to the - # task manager. + manager.set_compile_exception(exc, logs=logs()) raise else: - log_dicts = [r.to_dict() for r in logs()] - manager.set_ready(log_dicts) + manager.set_ready(logs=logs()) @contextmanager @@ -258,7 +253,7 @@ def handle_teardown(self): pass -class RequestTaskHandler(threading.Thread): +class RequestTaskHandler(threading.Thread, Collectible): """Handler for the single task triggered by a given jsonrpc request.""" def __init__( self, @@ -324,7 +319,7 @@ def timeout(self) -> Optional[float]: return float(self.task_params.timeout) @property - def tags(self) -> Optional[Dict[str, Any]]: + def tags(self) -> Optional[TaskTags]: if self.task_params is None: return None return self.task_params.task_tags @@ -470,6 +465,10 @@ def handle(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: 'Task params set to None!' ) + if RemoteMethodFlags.Builtin in flags: + # bypass the queue, logging, etc: Straight to the method + return self.task.handle_request() + self.subscriber = QueueSubscriber(dbt.flags.MP_CONTEXT.Queue()) self.process = BootstrapProcess(self.task, self.subscriber.queue) @@ -486,7 +485,7 @@ def handle(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: self.start() return {'request_token': str(self.task_id)} - def __call__(self, **kwargs) -> Dict[str, Any]: + def __call__(self, **kwargs: Dict[str, Any]) -> Dict[str, Any]: # __call__ happens deep inside jsonrpc's framework self.manager.add_request(self) return self.handle(kwargs) diff --git a/core/dbt/rpc/task_manager.py b/core/dbt/rpc/task_manager.py index 29f1da8a170..9b50b4bbb6e 100644 --- a/core/dbt/rpc/task_manager.py +++ b/core/dbt/rpc/task_manager.py @@ -1,52 +1,35 @@ -import operator -import os -import signal import threading import uuid -from dataclasses import dataclass -from datetime import datetime, timedelta -from functools import wraps +from datetime import datetime from typing import ( - Any, Dict, Optional, List, Union, Set, Callable, Iterable, Tuple, Type, + Any, Dict, Optional, List, Union, Set, Callable, Type, MutableMapping ) -from hologram import JsonSchemaMixin, ValidationError import dbt.exceptions import dbt.flags from dbt.contracts.graph.manifest import Manifest from dbt.contracts.rpc import ( - TaskTags, LastParse, ManifestStatus, GCSettings, - KillResult, - KillResultStatus, - GCResultState, - GCResultSet, + GCResult, TaskRow, - PSResult, - RemoteExecutionResult, - RemoteRunResult, - RemoteCompileResult, - RemoteCatalogResults, - RemoteEmptyResult, - PollResult, - PollInProgressResult, - PollKilledResult, - PollExecuteCompleteResult, - PollRunCompleteResult, - PollCompileCompleteResult, - PollCatalogCompleteResult, - PollRemoteEmptyCompleteResult, + TaskHandlerState, + TaskID, ) from dbt.logger import LogMessage, list_handler from dbt.perf_utils import get_full_manifest -from dbt.rpc.error import dbt_error, RPCException +from dbt.rpc.error import dbt_error +from dbt.rpc.gc import GarbageCollector, Collectible from dbt.rpc.task_handler import ( - TaskHandlerState, RequestTaskHandler, set_parse_state_with + RequestTaskHandler, set_parse_state_with ) -from dbt.rpc.method import RemoteMethod, RemoteManifestMethod, TaskList +from dbt.rpc.method import ( + RemoteMethod, RemoteManifestMethod, RemoteBuiltinMethod, TaskTypes, +) +# pick up our builtin methods +import dbt.rpc.builtins # noqa # import this to make sure our timedelta encoder is registered @@ -57,7 +40,12 @@ SINGLE_THREADED_WEBSERVER = env_set_truthy('DBT_SINGLE_THREADED_WEBSERVER') -def _assert_started(task_handler: RequestTaskHandler) -> datetime: +WrappedHandler = Callable[..., Dict[str, Any]] + +TaskDict = MutableMapping[uuid.UUID, RequestTaskHandler] + + +def _assert_started(task_handler: Collectible) -> datetime: if task_handler.started is None: raise dbt.exceptions.InternalException( 'task handler started but start time is not set' @@ -65,7 +53,7 @@ def _assert_started(task_handler: RequestTaskHandler) -> datetime: return task_handler.started -def _assert_ended(task_handler: RequestTaskHandler) -> datetime: +def _assert_ended(task_handler: Collectible) -> datetime: if task_handler.ended is None: raise dbt.exceptions.InternalException( 'task handler finished but end time is not set' @@ -73,7 +61,7 @@ def _assert_ended(task_handler: RequestTaskHandler) -> datetime: return task_handler.ended -def make_task(task_handler: RequestTaskHandler, now_time: datetime) -> TaskRow: +def make_task(task_handler: Collectible, now_time: datetime) -> TaskRow: # get information about the task in a way that should not provide any # conflicting information. Calculate elapsed time based on `now_time` state = task_handler.state @@ -106,20 +94,24 @@ def make_task(task_handler: RequestTaskHandler, now_time: datetime) -> TaskRow: ) -UnmanagedHandler = Callable[..., JsonSchemaMixin] -WrappedHandler = Callable[..., Dict[str, Any]] +class UnconditionalError: + def __init__(self, exception: dbt.exceptions.Exception): + self.exception = dbt_error(exception) + def __call__(self, *args, **kwargs): + raise self.exception -def _wrap_builtin(func: UnmanagedHandler) -> WrappedHandler: - @wraps(func) - def inner(*args, **kwargs): - return func(*args, **kwargs).to_dict(omit_none=False) - return inner +class ParseError(UnconditionalError): + def __init__(self, parse_error): + exception = dbt.exceptions.RPCLoadException(parse_error) + super().__init__(exception) -class Reserved: - # a dummy class - pass + +class CurrentlyCompiling(UnconditionalError): + def __init__(self): + exception = dbt.exceptions.RPCCompiling('compile in progress') + super().__init__(exception) class ManifestReloader(threading.Thread): @@ -141,62 +133,21 @@ def run(self) -> None: pass -@dataclass -class _GCArguments(JsonSchemaMixin): - """An argument validation helper""" - task_ids: Optional[List[uuid.UUID]] - before: Optional[datetime] - settings: Optional[GCSettings] - - -def poll_complete( - status: TaskHandlerState, result: Any, tags: TaskTags -) -> PollResult: - if status not in (TaskHandlerState.Success, TaskHandlerState.Failed): - raise dbt.exceptions.InternalException( - 'got invalid result status in poll_complete: {}'.format(status) - ) - - cls: Type[Union[ - PollExecuteCompleteResult, - PollRunCompleteResult, - PollCompileCompleteResult, - PollCatalogCompleteResult, - PollRemoteEmptyCompleteResult, - ]] - - if isinstance(result, RemoteExecutionResult): - cls = PollExecuteCompleteResult - # order matters here, as RemoteRunResult subclasses RemoteCompileResult - elif isinstance(result, RemoteRunResult): - cls = PollRunCompleteResult - elif isinstance(result, RemoteCompileResult): - cls = PollCompileCompleteResult - elif isinstance(result, RemoteCatalogResults): - cls = PollCatalogCompleteResult - elif isinstance(result, RemoteEmptyResult): - cls = PollRemoteEmptyCompleteResult - else: - raise dbt.exceptions.InternalException( - 'got invalid result in poll_complete: {}'.format(result) - ) - return cls.from_result(status, result, tags) - - class TaskManager: - def __init__(self, args, config, task_types: TaskList) -> None: + def __init__(self, args, config, task_types: TaskTypes) -> None: self.args = args self.config = config - self._task_types: TaskList = task_types - self.active_tasks: Dict[uuid.UUID, RequestTaskHandler] = {} - self._rpc_task_map: Dict[str, Union[Reserved, RemoteMethod]] = {} - self._builtins: Dict[str, UnmanagedHandler] = {} + self.manifest: Optional[Manifest] = None + self._task_types: TaskTypes = task_types + self.active_tasks: MutableMapping[uuid.UUID, Collectible] = {} + self.gc = GarbageCollector(active_tasks=self.active_tasks) self.last_parse: LastParse = LastParse(status=ManifestStatus.Init) self._lock: dbt.flags.MP_CONTEXT.Lock = dbt.flags.MP_CONTEXT.Lock() - self._gc_settings: GCSettings = GCSettings( - maxsize=1000, reapsize=500, auto_reap_age=timedelta(days=30) - ) self._reloader: Optional[ManifestReloader] = None + self.reload_manifest() + + def single_threaded(self): + return SINGLE_THREADED_WEBSERVER or self.args.single_threaded def _reload_task_manager_thread(self, reloader: ManifestReloader): """This function can only be running once at a time, as it runs in the @@ -212,7 +163,7 @@ def _reload_task_manager_fg(self, reloader: ManifestReloader): # just reload directly reloader.reload_manifest() - def reload_manifest_tasks(self) -> bool: + def reload_manifest(self) -> bool: """Reload the manifest using a manifest reloader. Returns False if the reload was not started because it was already running. """ @@ -229,77 +180,60 @@ def reload_manifest_tasks(self) -> bool: self._reload_task_manager_thread(reloader) return True - def single_threaded(self): - return SINGLE_THREADED_WEBSERVER or self.args.single_threaded - - def reload_non_manifest_tasks(self): + def reload_builtin_tasks(self): # reload all the non-manifest tasks because the config changed. # manifest tasks are still blocked so we can ignore them - for task_cls in self._task_types.non_manifest(): - self.add_basic_task_handler(task_cls) + for task_cls in self._task_types.builtin(): + self.add_builtin_task_handler(task_cls) def reload_config(self): config = self.config.from_args(self.args) self.config = config - # reload all the non-manifest tasks because the config changed. - # manifest tasks are still blocked so we can ignore them - self.reload_non_manifest_tasks() return config def add_request(self, request_handler: RequestTaskHandler): self.active_tasks[request_handler.task_id] = request_handler - def reserve_handler(self, task: Type[RemoteMethod]) -> None: - """Reserved tasks will return a status indicating that the manifest is - compiling. - """ - if task.METHOD_NAME is None: - raise dbt.exceptions.InternalException( - f'Cannot add task {task} as it has no method name' - ) - self._rpc_task_map[task.METHOD_NAME] = Reserved() - - def _check_task_handler(self, task_type: Type[RemoteMethod]) -> None: - if task_type.METHOD_NAME is None: - raise dbt.exceptions.InternalException( - 'Task {} has no method name, cannot add it'.format(task_type) - ) - method = task_type.METHOD_NAME - if method not in self._rpc_task_map: - # this is weird, but hey whatever - return - other_task = self._rpc_task_map[method] - if isinstance(other_task, Reserved) or type(other_task) is task_type: - return - raise dbt.exceptions.InternalException( - 'Got two tasks with the same method name! {0} and {1} both ' - 'have a method name of {0.METHOD_NAME}, but RPC method names ' - 'should be unique'.format(task_type, other_task) - ) - - def add_manifest_task_handler( - self, task: Type[RemoteManifestMethod], manifest: Manifest - ) -> None: - self._check_task_handler(task) - assert task.METHOD_NAME is not None - self._rpc_task_map[task.METHOD_NAME] = task( - self.args, self.config, manifest - ) - - def add_basic_task_handler(self, task: Type[RemoteMethod]) -> None: - if issubclass(task, RemoteManifestMethod): - raise dbt.exceptions.InternalException( - f'Task {task} requires a manifest, cannot add it as a basic ' - f'handler' - ) + def get_request(self, task_id: TaskID) -> Collectible: + try: + return self.active_tasks[task_id] + except KeyError: + # We don't recognize that ID. + raise dbt.exceptions.UnknownAsyncIDException(task_id) from None - self._check_task_handler(task) - assert task.METHOD_NAME is not None - self._rpc_task_map[task.METHOD_NAME] = task(self.args, self.config) + def _get_manifest_callable( + self, task: Type[RemoteManifestMethod] + ) -> Union[UnconditionalError, RemoteManifestMethod]: + status = self.last_parse.status + if status == ManifestStatus.Compiling: + return CurrentlyCompiling() + elif status == ManifestStatus.Error: + return ParseError(self.last_parse.error) + else: + if self.manifest is None: + raise dbt.exceptions.InternalException( + f'Manifest should not be None if the last parse status is ' + f'{status}' + ) + return task(self.args, self.config, self.manifest) - def rpc_task(self, method_name: str) -> Union[Reserved, RemoteMethod]: + def rpc_task( + self, method_name: str + ) -> Union[UnconditionalError, RemoteMethod]: with self._lock: - return self._rpc_task_map[method_name] + task = self._task_types[method_name] + if issubclass(task, RemoteBuiltinMethod): + return task(self) + elif issubclass(task, RemoteManifestMethod): + return self._get_manifest_callable(task) + elif issubclass(task, RemoteMethod): + return task(self.args, self.config) + else: + raise dbt.exceptions.InternalException( + f'Got a task with an invalid type! {task} with method ' + f'name {method_name} has a type of {task.__class__}, ' + f'should be a RemoteMethod' + ) def ready(self) -> bool: with self._lock: @@ -310,20 +244,12 @@ def set_parsing(self) -> bool: if self.last_parse.status == ManifestStatus.Compiling: return False self.last_parse = LastParse(status=ManifestStatus.Compiling) - for task in self._task_types.manifest(): - # reserve any tasks that are invalid - self.reserve_handler(task) return True def parse_manifest(self) -> None: - manifest = get_full_manifest(self.config) + self.manifest = get_full_manifest(self.config) - for task_cls in self._task_types.manifest(): - self.add_manifest_task_handler( - task_cls, manifest - ) - - def set_compile_exception(self, exc, logs=List[Dict[str, Any]]) -> None: + def set_compile_exception(self, exc, logs=List[LogMessage]) -> None: assert self.last_parse.status == ManifestStatus.Compiling, \ f'invalid state {self.last_parse.status}' self.last_parse = LastParse( @@ -332,7 +258,7 @@ def set_compile_exception(self, exc, logs=List[Dict[str, Any]]) -> None: logs=logs ) - def set_ready(self, logs=List[Dict[str, Any]]) -> None: + def set_ready(self, logs=List[LogMessage]) -> None: assert self.last_parse.status == ManifestStatus.Compiling, \ f'invalid state {self.last_parse.status}' self.last_parse = LastParse( @@ -340,154 +266,9 @@ def set_ready(self, logs=List[Dict[str, Any]]) -> None: logs=logs ) - def process_status(self) -> LastParse: + def methods(self) -> Set[str]: with self._lock: - last_compile = self.last_parse - return last_compile - - def process_ps( - self, - active: bool = True, - completed: bool = False, - ) -> PSResult: - rows = [] - now = datetime.utcnow() - with self._lock: - for task in self.active_tasks.values(): - row = make_task(task, now) - if row.state.finished and completed: - rows.append(row) - elif not row.state.finished and active: - rows.append(row) - - rows.sort(key=lambda r: (r.state, r.start, r.method)) - result = PSResult(rows=rows) - return result - - def process_kill(self, task_id: str) -> KillResult: - task_id_uuid = uuid.UUID(task_id) - - status = KillResultStatus.Missing - try: - task: RequestTaskHandler = self.active_tasks[task_id_uuid] - except KeyError: - # nothing to do! - return KillResult(status) - - status = KillResultStatus.NotStarted - - if task.process is None: - return KillResult(status) - pid = task.process.pid - if pid is None: - return KillResult(status) - - if task.process.is_alive(): - status = KillResultStatus.Killed - task.ended = datetime.utcnow() - os.kill(pid, signal.SIGINT) - task.state = TaskHandlerState.Killed - else: - status = KillResultStatus.Finished - # the status must be "Completed" - - return KillResult(status) - - def process_poll( - self, - request_token: str, - logs: bool = False, - logs_start: int = 0, - ) -> PollResult: - task_id = uuid.UUID(request_token) - try: - task: RequestTaskHandler = self.active_tasks[task_id] - except KeyError: - # We don't recognize that ID. - raise dbt.exceptions.UnknownAsyncIDException(task_id) from None - - task_logs: List[LogMessage] = [] - if logs: - task_logs = task.logs[logs_start:] - - # Get a state and store it locally so we ignore updates to state, - # otherwise things will get confusing. States should always be - # "forward-compatible" so if the state has transitioned to error/result - # but we aren't there yet, the logs will still be valid. - state = task.state - if state <= TaskHandlerState.Running: - return PollInProgressResult( - status=state, - tags=task.tags, - logs=task_logs, - ) - elif state == TaskHandlerState.Error: - err = task.error - if err is None: - exc = dbt.exceptions.InternalException( - f'At end of task {task_id}, error state but error is None' - ) - raise RPCException.from_error( - dbt_error(exc, logs=[l.to_dict() for l in task_logs]) - ) - # the exception has logs already attached from the child, don't - # overwrite those - raise err - elif state in (TaskHandlerState.Success, TaskHandlerState.Failed): - - if task.result is None: - exc = dbt.exceptions.InternalException( - f'At end of task {task_id}, state={state} but result is ' - 'None' - ) - raise RPCException.from_error( - dbt_error(exc, logs=[l.to_dict() for l in task_logs]) - ) - return poll_complete( - status=state, - result=task.result, - tags=task.tags, - ) - elif state == TaskHandlerState.Killed: - return PollKilledResult( - status=state, tags=task.tags, logs=task_logs - ) - else: - exc = dbt.exceptions.InternalException( - f'Got unknown value state={state} for task {task_id}' - ) - raise RPCException.from_error( - dbt_error(exc, logs=[l.to_dict() for l in task_logs]) - ) - - def _rpc_builtins(self) -> Dict[str, UnmanagedHandler]: - if self._builtins: - return self._builtins - - with self._lock: - if self._builtins: # handle a race - return self._builtins - - methods: Dict[str, UnmanagedHandler] = { - 'kill': self.process_kill, - 'ps': self.process_ps, - 'status': self.process_status, - 'poll': self.process_poll, - 'gc': self.process_gc, - } - - self._builtins.update(methods) - return self._builtins - - def methods(self, builtin=True) -> Set[str]: - all_methods: Set[str] = set() - if builtin: - all_methods.update(self._rpc_builtins()) - - with self._lock: - all_methods.update(self._rpc_task_map) - - return all_methods + return set(self._task_types) def currently_compiling(self, *args, **kwargs): """Raise an RPC exception to trigger the error handler.""" @@ -499,165 +280,37 @@ def compilation_error(self, *args, **kwargs): dbt.exceptions.RPCLoadException(self.last_parse.error) ) - def internal_error_for(self, msg) -> WrappedHandler: - def _error(*args, **kwargs): - raise dbt.exceptions.InternalException(msg) - return _wrap_builtin(_error) - def get_handler( self, method, http_request, json_rpc_request ) -> Optional[Union[WrappedHandler, RemoteMethod]]: # get_handler triggers a GC check. TODO: does this go somewhere else? self.gc_as_required() - # the dispatcher's keys are method names and its values are functions - # that implement the RPC calls - _builtins = self._rpc_builtins() - if method in _builtins: - return _wrap_builtin(_builtins[method]) - elif method not in self._rpc_task_map: + + if method not in self._task_types: return None task = self.rpc_task(method) - # If the task we got back was reserved, it must be a task that requires - # a manifest and we don't have one. So we had better have a state of - # compiling or error. - - if isinstance(task, Reserved): - status = self.last_parse.status - if status == ManifestStatus.Compiling: - return self.currently_compiling - elif status == ManifestStatus.Error: - return self.compilation_error - else: - # if we got here, there is an error in dbt :( - return self.internal_error_for( - f'Got a None task for {method}, state is {status}' - ) - else: - return task - - def _remove_task_if_finished(self, task_id: uuid.UUID) -> GCResultState: - """Remove the task if it was finished. Raises a KeyError if the entry - is removed during operation (so hold the lock). - """ - if task_id not in self.active_tasks: - return GCResultState.Missing - - task = self.active_tasks[task_id] - if not task.state.finished: - return GCResultState.Running - - del self.active_tasks[task_id] - return GCResultState.Deleted + return task - def _gc_task_id(self, result: GCResultSet, task_id: uuid.UUID) -> None: - """To 'gc' a task ID, we just delete it from the tasks dict. - - You must hold the lock, as this mutates `tasks`. - """ - try: - status = self._remove_task_if_finished(task_id) - except KeyError: - # someone was mutating tasks while we had the lock, that's - # not right! - raise dbt.exceptions.InternalException( - 'Got a KeyError for task uuid={} during gc' - .format(task_id) - ) + def task_table(self) -> List[TaskRow]: + rows: List[TaskRow] = [] + now = datetime.utcnow() + with self._lock: + for task in self.active_tasks.values(): + rows.append(make_task(task, now)) + return rows - return result.add_result(task_id=task_id, status=status) - - def _get_gc_before_list(self, when: datetime) -> List[uuid.UUID]: - removals: List[uuid.UUID] = [] - for task in self.active_tasks.values(): - if not task.state.finished: - continue - elif task.ended is None: - continue - elif task.ended < when: - removals.append(task.task_id) - - return removals - - def _get_oldest_ended_list(self, num: int) -> List[uuid.UUID]: - candidates: List[Tuple[datetime, uuid.UUID]] = [] - for task in self.active_tasks.values(): - if not task.state.finished: - continue - elif task.ended is None: - continue - else: - candidates.append((task.ended, task.task_id)) - candidates.sort(key=operator.itemgetter(0)) - return [task_id for _, task_id in candidates[:num]] - - def _gc_multiple_task_ids( - self, task_ids: Iterable[uuid.UUID] - ) -> GCResultSet: - result = GCResultSet() - for task_id in task_ids: - self._gc_task_id(result, task_id) - return result + def gc_as_required(self) -> None: + with self._lock: + return self.gc.collect_as_required() def gc_safe( self, task_ids: Optional[List[uuid.UUID]] = None, before: Optional[datetime] = None, - ) -> GCResultSet: - to_gc = set() - - if task_ids is not None: - to_gc.update(task_ids) - - with self._lock: - # we need the lock for this! - if before is not None: - to_gc.update(self._get_gc_before_list(before)) - return self._gc_multiple_task_ids(to_gc) - - def _gc_as_required_unsafe(self) -> None: - to_remove: List[uuid.UUID] = [] - num_tasks = len(self.active_tasks) - if num_tasks > self._gc_settings.maxsize: - num = self._gc_settings.maxsize - num_tasks - to_remove = self._get_oldest_ended_list(num) - elif num_tasks > self._gc_settings.reapsize: - before = datetime.utcnow() - self._gc_settings.auto_reap_age - to_remove = self._get_gc_before_list(before) - - if to_remove: - self._gc_multiple_task_ids(to_remove) - - def gc_as_required(self) -> None: + settings: Optional[GCSettings] = None, + ) -> GCResult: with self._lock: - return self._gc_as_required_unsafe() - - def process_gc( - self, - task_ids: Optional[List[str]] = None, - before: Optional[str] = None, - settings: Optional[Dict[str, Any]] = None, - ) -> GCResultSet: - """The gc endpoint takes three arguments, any of which may be present: - - - task_ids: An optional list of task ID UUIDs to try to GC - - before: If provided, should be a datetime string. All tasks that - finished before that datetime will be GCed - - settings: If provided, should be a GCSettings object in JSON form. - It will be applied to the task manager before GC starts. By default - the existing gc settings remain. - """ - try: - args = _GCArguments.from_dict({ - 'task_ids': task_ids, - 'before': before, - 'settings': settings, - }) - except ValidationError as exc: - # trigger the jsonrpc library to recognize the arguments as bad - raise TypeError('bad arguments: {}'.format(exc)) - - if args.settings: - self._gc_settings = args.settings - - return self.gc_safe(task_ids=args.task_ids, before=args.before) + return self.gc.collect_selected( + task_ids=task_ids, before=before, settings=settings, + ) diff --git a/core/dbt/task/rpc/server.py b/core/dbt/task/rpc/server.py index 8a3a963f43e..04b0eaeba3d 100644 --- a/core/dbt/task/rpc/server.py +++ b/core/dbt/task/rpc/server.py @@ -19,7 +19,7 @@ log_manager, ) from dbt.rpc.logger import ServerContext, HTTPRequest, RPCResponse -from dbt.rpc.method import TaskList +from dbt.rpc.method import TaskTypes from dbt.rpc.response_manager import ResponseManager from dbt.rpc.task_manager import TaskManager from dbt.task.base import ConfiguredTask @@ -73,17 +73,15 @@ def signhup_replace() -> Iterator[bool]: class RPCServerTask(ConfiguredTask): DEFAULT_LOG_FORMAT = 'json' - def __init__(self, args, config, tasks: Optional[TaskList] = None): + def __init__(self, args, config, tasks: Optional[TaskTypes] = None): if os.name == 'nt': raise RuntimeException( 'The dbt RPC server is not supported on windows' ) super().__init__(args, config) self.task_manager = TaskManager( - self.args, self.config, TaskList(tasks) + self.args, self.config, TaskTypes(tasks) ) - self.task_manager.reload_non_manifest_tasks() - self.task_manager.reload_manifest_tasks() signal.signal(signal.SIGHUP, self._sighup_handler) @classmethod @@ -100,10 +98,7 @@ def _sighup_handler(self, signum, frame): # a sighup handler is already active. return self.task_manager.reload_config() - self.task_manager.reload_manifest_tasks() - - def single_threaded(self): - return self.task_manager.single_threaded() + self.task_manager.reload_manifest() def run_forever(self): host = self.args.host