Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: async rpc (#1706) #1735

Merged
merged 10 commits into from
Sep 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/dbt/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _compile(self, source, filename):
If the value is 'write', also write the files to disk.
WARNING: This can write a ton of data if you aren't careful.
"""
macro_compile = os.environ.get('DBT_MACRO_DEBUGGING')
macro_compile = dbt.utils.env_set_truthy('DBT_MACRO_DEBUGGING')
if filename == '<template>' and macro_compile:
write = macro_compile == 'write'
filename = _linecache_inject(source, write)
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/config/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
DEFAULT_THREADS = 1
DEFAULT_PROFILES_DIR = os.path.join(os.path.expanduser('~'), '.dbt')
PROFILES_DIR = os.path.expanduser(
os.environ.get('DBT_PROFILES_DIR', DEFAULT_PROFILES_DIR)
os.getenv('DBT_PROFILES_DIR', DEFAULT_PROFILES_DIR)
Copy link
Contributor Author

@beckjake beckjake Sep 10, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed all os.environ.get -> os.getenv for grep-ability reasons when I was adding some secret magic threading-related flags to the RPC server.

)

INVALID_PROFILE_MESSAGE = """
Expand Down
25 changes: 25 additions & 0 deletions core/dbt/contracts/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dbt.contracts.graph.unparsed import Time, FreshnessStatus
from dbt.contracts.graph.parsed import ParsedSourceDefinition
from dbt.contracts.util import Writable
from dbt.logger import LogMessage
from hologram.helpers import StrEnum
from hologram import JsonSchemaMixin

Expand Down Expand Up @@ -77,6 +78,15 @@ class ExecutionResult(JsonSchemaMixin, Writable):
generated_at: datetime
elapsed_time: Real

def __len__(self):
    # Delegate length to the wrapped list of per-node results so the
    # result object itself can be used like a sequence.
    return len(self.results)

def __iter__(self):
    # Iterate directly over the underlying results list, so callers can
    # write `for r in result:` instead of `for r in result.results:`.
    return iter(self.results)

def __getitem__(self, idx):
    # Index into the underlying results list; supports whatever `idx`
    # forms a plain list supports (ints, slices).
    return self.results[idx]


# due to issues with typing.Union collapsing subclasses, this can't subclass
# PartialResult
Expand Down Expand Up @@ -137,6 +147,15 @@ def write(self, path, omit_none=True):
output = FreshnessRunOutput(meta=meta, sources=sources)
output.write(path, omit_none=omit_none)

def __len__(self):
    # Delegate length to the wrapped list of per-source results so the
    # result object itself can be used like a sequence.
    return len(self.results)

def __iter__(self):
    # Iterate directly over the underlying results list, so callers can
    # write `for r in result:` instead of `for r in result.results:`.
    return iter(self.results)

def __getitem__(self, idx):
    # Index into the underlying results list; supports whatever `idx`
    # forms a plain list supports (ints, slices).
    return self.results[idx]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes make results a lot like the old thing where we returned result.results, which meant I didn't have to update a million tests. Also, it just seemed like a nice thing.



def _copykeys(src, keys, **updates):
    """Copy the named attributes of ``src`` into a new dict.

    :param src: any object; each name in ``keys`` is read via getattr.
    :param keys: iterable of attribute names to copy.
    :param updates: keyword overrides applied on top of the copied
        values. (The original implementation accepted ``**updates`` but
        silently discarded it; callers passing no updates see identical
        behavior.)
    :return: a new dict mapping each key to its value.
    :raises AttributeError: if ``src`` lacks one of ``keys``.
    """
    copied = {k: getattr(src, k) for k in keys}
    copied.update(updates)
    return copied
Expand Down Expand Up @@ -183,12 +202,18 @@ class RemoteCompileResult(JsonSchemaMixin):
compiled_sql: str
node: CompileResultNode
timing: List[TimingInfo]
logs: List[LogMessage]

@property
def error(self):
    # A compile result never carries an error; always None.
    # NOTE(review): presumably mirrors an `.error` attribute on
    # error/failure result types so callers can check it uniformly —
    # confirm against the RPC response handling.
    return None


@dataclass
class RemoteExecutionResult(ExecutionResult):
    """An ExecutionResult for the RPC boundary: extends the base result
    with the log messages captured while the task ran.
    """
    logs: List[LogMessage]


@dataclass
class ResultTable(JsonSchemaMixin):
column_names: List[str]
Expand Down
11 changes: 11 additions & 0 deletions core/dbt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,17 @@ def __reduce__(self):
return (JSONValidationException, (self.typename, self.errors))


class UnknownAsyncIDException(Exception):
    """Raised when a request names an async task ID the RPC server has
    no record of.
    """
    # numeric error code surfaced to RPC clients
    CODE = 10012
    MESSAGE = 'RPC server got an unknown async ID'

    def __init__(self, task_id):
        # keep the offending ID around for the error message / callers
        self.task_id = task_id

    def __str__(self):
        return f'{self.MESSAGE}: {self.task_id}'


class AliasException(ValidationException):
    """Validation error related to aliasing.

    NOTE(review): semantics inferred from the name only — confirm
    against the sites that raise it.
    """
    pass

Expand Down
32 changes: 30 additions & 2 deletions core/dbt/helper_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# never name this package "types", or mypy will crash in ugly ways
from hologram import FieldEncoder, JsonSchemaMixin
from hologram import (
FieldEncoder, JsonSchemaMixin, JsonDict, ValidationError
)

from datetime import timedelta
from typing import NewType


Expand All @@ -12,4 +16,28 @@ def json_schema(self):
return {'type': 'integer', 'minimum': 0, 'maximum': 65535}


JsonSchemaMixin.register_field_encoders({Port: PortEncoder()})
class TimeDeltaFieldEncoder(FieldEncoder[timedelta]):
    """Maps timedelta values to and from a plain number of seconds."""

    def to_wire(self, value: timedelta) -> float:
        # serialize as total seconds (a float)
        return value.total_seconds()

    def to_python(self, value) -> timedelta:
        # Accept both already-decoded timedeltas and raw second counts.
        if not isinstance(value, timedelta):
            try:
                value = timedelta(seconds=value)
            except TypeError:
                # e.g. a string or None: not a valid number of seconds
                raise ValidationError(
                    'cannot encode {} into timedelta'.format(value)
                ) from None
        return value

    @property
    def json_schema(self) -> JsonDict:
        # on the wire this is just a JSON number
        return {'type': 'number'}


# Teach hologram how to (de)serialize these types in JSON schemas:
# Port as a bounded integer (see PortEncoder) and timedelta as a
# number of seconds (see TimeDeltaFieldEncoder).
JsonSchemaMixin.register_field_encoders({
    Port: PortEncoder(),
    timedelta: TimeDeltaFieldEncoder()
})
64 changes: 41 additions & 23 deletions core/dbt/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import Optional, List, ContextManager, Callable, Dict, Any
from typing import Optional, List, ContextManager, Callable, Dict, Any, Set

import colorama
import logbook
Expand All @@ -25,7 +25,7 @@
colorama.init(wrap=colorama_wrap)


if sys.platform == 'win32' and not os.environ.get('TERM'):
if sys.platform == 'win32' and not os.getenv('TERM'):
colorama_wrap = False
colorama_stdout = colorama.AnsiToWin32(sys.stdout).stream

Expand Down Expand Up @@ -157,6 +157,35 @@ def _redirect_std_logging():
logbook.compat.redirect_logging()


def _root_channel(record: logbook.LogRecord) -> str:
    """Return the first dotted component of a record's channel name
    (e.g. ``'dbt.cache'`` -> ``'dbt'``).
    """
    # partition('.')[0] == split('.')[0]: everything before the first
    # dot, or the whole string when there is no dot.
    return record.channel.partition('.')[0]


class Relevel(logbook.Processor):
    """Demote log records whose root channel is not on an allow-list.

    Records from allowed channels pass through untouched. For all other
    records, the original level is preserved in ``extra['old_level']``
    and the record is either silenced (dropped to NOTSET when below
    ``min_level``) or lowered to ``target_level``.
    """

    def __init__(
        self,
        allowed: List[str],
        min_level=logbook.WARNING,
        target_level=logbook.DEBUG,
    ) -> None:
        # set for O(1) channel membership checks in process()
        self.allowed: Set[str] = set(allowed)
        self.min_level = min_level
        self.target_level = target_level
        super().__init__()

    def process(self, record):
        # allow-listed channels are left completely alone
        if _root_channel(record) in self.allowed:
            return
        # remember what the level was before we touched it
        record.extra['old_level'] = record.level
        # below our threshold: silence entirely; otherwise demote to
        # the target level
        record.level = (
            logbook.NOTSET
            if record.level < self.min_level
            else self.target_level
        )


logger = logbook.Logger('dbt')
# provide this for the cache, disabled by default
CACHE_LOGGER = logbook.Logger('dbt.cache')
Expand Down Expand Up @@ -287,10 +316,12 @@ def __init__(self, stdout=colorama_stdout, stderr=sys.stderr):
self._null_handler = logbook.NullHandler()
self._output_handler = OutputHandler(self.stdout)
self._file_handler = DelayedFileHandler()
self._relevel_processor = Relevel(allowed=['dbt', 'werkzeug'])
super().__init__([
self._null_handler,
self._output_handler,
self._file_handler,
self._relevel_processor,
])

def disable(self):
Expand Down Expand Up @@ -388,28 +419,15 @@ def __init__(
lst = []
self.records: List[LogMessage] = lst

def emit(self, record: logbook.LogRecord):
    # Convert the record via format_logmessage (defined outside this
    # view, presumably on the handler base class — confirm) and buffer
    # it on self.records for later retrieval.
    as_dict = self.format_logmessage(record)
    self.records.append(as_dict)


class SuppressBelow(logbook.Handler):
def __init__(
self, channels, level=logbook.INFO, filter=None, bubble=False
) -> None:
self.channels = set(channels)
super().__init__(level, filter, bubble)

def should_handle(self, record):
channel = record.channel.split('.')[0]
if channel not in self.channels:
"""Only ever emit dbt-sourced log messages to the ListHandler."""
if _root_channel(record) != 'dbt':
return False
# if we were set to 'info' and record.level is warn/error, we don't
# want to 'handle' it (so a real logger will)
return self.level >= record.level
return super().should_handle(record)

def handle(self, record):
return True
def emit(self, record: logbook.LogRecord):
as_dict = self.format_logmessage(record)
self.records.append(as_dict)


# we still need to use logging to suppress these or pytest captures them
Expand All @@ -419,8 +437,8 @@ def handle(self, record):
logging.getLogger('google').setLevel(logging.INFO)
logging.getLogger('snowflake.connector').setLevel(logging.INFO)
logging.getLogger('parsedatetime').setLevel(logging.INFO)
# we never want to see werkzeug logs
logging.getLogger('werkzeug').setLevel(logging.CRITICAL)
# want to see werkzeug logs about errors
logging.getLogger('werkzeug').setLevel(logging.ERROR)


def list_handler(
Expand Down
39 changes: 38 additions & 1 deletion core/dbt/rpc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,43 @@
"""The `rpc` package handles most aspects of the actual execution of dbt's RPC
server (except for the server itself and the client tasks, which are defined in
the `task.remote` and `task.rpc_server` modules).

The general idea from a thread/process management perspective (ignoring the
--single-threaded flag!) is as follows:

- The RPC server runs a web server, in particular `werkzeug`, which manages a
thread pool.
- When a request comes in, werkzeug spins off a thread to manage the
request/response portion. dbt itself has basically no control over this
operation - from our viewpoint request/response cycles are fully
synchronous.
- synchronous requests are defined as methods in the `TaskManager` and handled
in the responding thread directly.
- Asynchronous requests (defined in `tasks.remote`) are kicked off wrapped in
`RequestTaskHandler`s, which manage a new process and a new thread.
- The process runs the actual dbt request, logging via a message queue
- eventually just before process exit, the process places an "error" or
"result" on the queue
- The thread monitors the queue, taking logs off the queue and adding them
to the `RequestTaskHandler`'s `logs` attribute.
- The thread also monitors the `is_alive` state of the process, in case
it is killed "unexpectedly" (including via `kill`)
- When the thread sees an error or result come over the queue, it join()s
the process.
- When the thread sees that the process has disappeared without placing
anything on the queue, it checks the queue one last time, and then acts
as if the queue received an 'Unexpected termination' error
- `kill` commands pointed at an asynchronous task kill the process and allow
the thread to handle cleanup and management
- When the RPC server receives a shutdown instruction, it:
- stops responding to requests
- `kills` all processes (triggering the end of all processes, right!?)
- exits (all remaining threads should die here!)
"""
from dbt.rpc.error import ( # noqa
dbt_error, server_error, invalid_params, RPCException
)
from dbt.rpc.task import RemoteCallable, RemoteCallableResult # noqa
from dbt.rpc.logger import RemoteCallableResult # noqa
from dbt.rpc.task import RemoteCallable # noqa
from dbt.rpc.task_manager import TaskManager # noqa
from dbt.rpc.response_manager import ResponseManager # noqa
15 changes: 12 additions & 3 deletions core/dbt/rpc/error.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from typing import List, Dict, Any, Optional

from jsonrpc.exceptions import JSONRPCDispatchException, JSONRPCInvalidParams

import dbt.exceptions


class RPCException(JSONRPCDispatchException):
def __init__(self, code=None, message=None, data=None, logs=None):
def __init__(
self,
code: Optional[int] = None,
message: Optional[str] = None,
data: Optional[Dict[str, Any]] = None,
logs: Optional[List[Dict[str, Any]]] = None,
) -> None:
if code is None:
code = -32000
if message is None:
Expand All @@ -13,7 +21,8 @@ def __init__(self, code=None, message=None, data=None, logs=None):
data = {}

super().__init__(code=code, message=message, data=data)
self.logs = logs
if logs is not None:
self.logs = logs

def __str__(self):
return (
Expand All @@ -22,7 +31,7 @@ def __str__(self):
)

@property
def logs(self):
def logs(self) -> List[Dict[str, Any]]:
return self.error.data.get('logs')

@logs.setter
Expand Down
Loading