diff --git a/core/dbt/contracts/rpc.py b/core/dbt/contracts/rpc.py index 39566fe0d93..2c7c6390df5 100644 --- a/core/dbt/contracts/rpc.py +++ b/core/dbt/contracts/rpc.py @@ -41,12 +41,14 @@ class RPCExecParameters(RPCParameters): @dataclass class RPCCompileParameters(RPCParameters): + threads: Optional[int] = None models: Union[None, str, List[str]] = None exclude: Union[None, str, List[str]] = None @dataclass class RPCSnapshotParameters(RPCParameters): + threads: Optional[int] = None select: Union[None, str, List[str]] = None exclude: Union[None, str, List[str]] = None @@ -59,6 +61,7 @@ class RPCTestParameters(RPCCompileParameters): @dataclass class RPCSeedParameters(RPCParameters): + threads: Optional[int] = None show: bool = False diff --git a/core/dbt/rpc/task_handler.py b/core/dbt/rpc/task_handler.py index adfd0e5e7eb..0caa6b8c6ac 100644 --- a/core/dbt/rpc/task_handler.py +++ b/core/dbt/rpc/task_handler.py @@ -86,6 +86,10 @@ def task_exec(self) -> None: handler = QueueLogHandler(self.queue) with handler.applicationbound(): self._spawn_setup() + # copy threads over into our credentials, if it exists and is set. + # some commands, like 'debug', won't have a threads value at all. + if getattr(self.task.args, 'threads', None) is not None: + self.task.config.threads = self.task.args.threads rpc_exception = None result = None try: diff --git a/core/dbt/task/rpc/project_commands.py b/core/dbt/task/rpc/project_commands.py index 68ed3c02b8d..501cc17373b 100644 --- a/core/dbt/task/rpc/project_commands.py +++ b/core/dbt/task/rpc/project_commands.py @@ -55,6 +55,8 @@ class RemoteCompileProjectTask( def set_args(self, params: RPCCompileParameters) -> None: self.args.models = self._listify(params.models) self.args.exclude = self._listify(params.exclude) + if params.threads is not None: + self.args.threads = params.threads class RemoteRunProjectTask(RPCCommandTask[RPCCompileParameters], RunTask): @@ -63,12 +65,16 @@ class RemoteRunProjectTask(RPCCommandTask[RPCCompileParameters], RunTask): def set_args(self, params: RPCCompileParameters) -> None: self.args.models = self._listify(params.models) self.args.exclude = self._listify(params.exclude) + if params.threads is not None: + self.args.threads = params.threads class RemoteSeedProjectTask(RPCCommandTask[RPCSeedParameters], SeedTask): METHOD_NAME = 'seed' def set_args(self, params: RPCSeedParameters) -> None: + if params.threads is not None: + self.args.threads = params.threads self.args.show = params.show @@ -80,6 +86,8 @@ def set_args(self, params: RPCTestParameters) -> None: self.args.exclude = self._listify(params.exclude) self.args.data = params.data self.args.schema = params.schema + if params.threads is not None: + self.args.threads = params.threads class RemoteDocsGenerateProjectTask( @@ -140,3 +148,5 @@ def set_args(self, params: RPCSnapshotParameters) -> None: # select has an argparse `dest` value of `models`. self.args.models = self._listify(params.select) self.args.exclude = self._listify(params.exclude) + if params.threads is not None: + self.args.threads = params.threads diff --git a/core/dbt/task/rpc/server.py b/core/dbt/task/rpc/server.py index 04b0eaeba3d..2e8ecd15d31 100644 --- a/core/dbt/task/rpc/server.py +++ b/core/dbt/task/rpc/server.py @@ -123,8 +123,7 @@ def run_forever(self): 'Send requests to http://{}:{}/jsonrpc'.format(display_host, port) ) - app = self.handle_request - app = DispatcherMiddleware(app, { + app = DispatcherMiddleware(self.handle_request, { '/jsonrpc': self.handle_jsonrpc_request, }) diff --git a/test/rpc/test_base.py b/test/rpc/test_base.py index ee003a4107c..50d1c2012dd 100644 --- a/test/rpc/test_base.py +++ b/test/rpc/test_base.py @@ -1,5 +1,6 @@ # flake8: disable=redefined-outer-name import time +import yaml from .util import ( ProjectDefinition, rpc_server, Querier, built_schema, get_querier, ) @@ -458,3 +459,137 @@ def test_snapshots_cli(project_root, profiles_root, postgres_profile, unique_sch token = querier.is_async_result(querier.cli_args(cli='snapshot --select=snapshot_actual')) results = querier.is_result(querier.async_wait(token)) assert len(results['results']) == 1 + + +def assert_has_threads(results, num_threads): + assert 'logs' in results + c_logs = [l for l in results['logs'] if 'Concurrency' in l['message']] + assert len(c_logs) == 1, \ + f'Got invalid number of concurrency logs ({len(c_logs)})' + assert 'message' in c_logs[0] + assert f'Concurrency: {num_threads} threads' in c_logs[0]['message'] + + +def test_rpc_run_threads(project_root, profiles_root, postgres_profile, unique_schema): + project = ProjectDefinition( + models={'my_model.sql': 'select 1 as id'} + ) + querier_ctx = get_querier( + project_def=project, + project_dir=project_root, + profiles_dir=profiles_root, + schema=unique_schema, + test_kwargs={}, + ) + with querier_ctx as querier: + token = querier.is_async_result(querier.run(threads=5)) + results = querier.is_result(querier.async_wait(token)) + assert_has_threads(results, 5) + + token = querier.is_async_result(querier.cli_args('run --threads=7')) + results = querier.is_result(querier.async_wait(token)) + assert_has_threads(results, 7) + + +def test_rpc_compile_threads(project_root, profiles_root, postgres_profile, unique_schema): + project = ProjectDefinition( + models={'my_model.sql': 'select 1 as id'} + ) + querier_ctx = get_querier( + project_def=project, + project_dir=project_root, + profiles_dir=profiles_root, + schema=unique_schema, + test_kwargs={}, + ) + with querier_ctx as querier: + token = querier.is_async_result(querier.compile(threads=5)) + results = querier.is_result(querier.async_wait(token)) + assert_has_threads(results, 5) + + token = querier.is_async_result(querier.cli_args('compile --threads=7')) + results = querier.is_result(querier.async_wait(token)) + assert_has_threads(results, 7) + + +def test_rpc_test_threads(project_root, profiles_root, postgres_profile, unique_schema): + schema_yaml = { + 'version': 2, + 'models': [{ + 'name': 'my_model', + 'columns': [ + { + 'name': 'id', + 'tests': ['not_null', 'unique'], + }, + ], + }], + } + project = ProjectDefinition( + models={ + 'my_model.sql': 'select 1 as id', + 'schema.yml': yaml.safe_dump(schema_yaml)} + ) + querier_ctx = get_querier( + project_def=project, + project_dir=project_root, + profiles_dir=profiles_root, + schema=unique_schema, + test_kwargs={}, + ) + with querier_ctx as querier: + # first run dbt to get the model built + token = querier.is_async_result(querier.run()) + querier.is_result(querier.async_wait(token)) + + token = querier.is_async_result(querier.test(threads=5)) + results = querier.is_result(querier.async_wait(token)) + assert_has_threads(results, 5) + + token = querier.is_async_result(querier.cli_args('test --threads=7')) + results = querier.is_result(querier.async_wait(token)) + assert_has_threads(results, 7) + + +def test_rpc_snapshot_threads(project_root, profiles_root, postgres_profile, unique_schema): + project = ProjectDefinition( + snapshots={'my_snapshots.sql': snapshot_data}, + ) + querier_ctx = get_querier( + project_def=project, + project_dir=project_root, + profiles_dir=profiles_root, + schema=unique_schema, + test_kwargs={}, + ) + + with querier_ctx as querier: + token = querier.is_async_result(querier.snapshot(threads=5)) + results = querier.is_result(querier.async_wait(token)) + assert_has_threads(results, 5) + + token = querier.is_async_result(querier.cli_args('snapshot --threads=7')) + results = querier.is_result(querier.async_wait(token)) + assert_has_threads(results, 7) + + +def test_rpc_seed_threads(project_root, profiles_root, postgres_profile, unique_schema): + project = ProjectDefinition( + seeds={'data.csv': 'a,b\n1,hello\n2,goodbye'} + ) + querier_ctx = get_querier( + project_def=project, + project_dir=project_root, + profiles_dir=profiles_root, + schema=unique_schema, + test_kwargs={}, + ) + + with querier_ctx as querier: + token = querier.is_async_result(querier.seed(threads=5)) + results = querier.is_result(querier.async_wait(token)) + assert_has_threads(results, 5) + + token = querier.is_async_result(querier.cli_args('seed --threads=7')) + results = querier.is_result(querier.async_wait(token)) + assert_has_threads(results, 7) diff --git a/test/rpc/util.py b/test/rpc/util.py index 0449db57101..77074b00aef 100644 --- a/test/rpc/util.py +++ b/test/rpc/util.py @@ -193,6 +193,7 @@ def compile( self, models: Optional[Union[str, List[str]]] = None, exclude: Optional[Union[str, List[str]]] = None, + threads: Optional[int] = None, request_id: int = 1, ): params = {} @@ -200,6 +201,8 @@ def compile( params['models'] = models if exclude is not None: params['exclude'] = exclude + if threads is not None: + params['threads'] = threads return self.request( method='compile', params=params, request_id=request_id ) @@ -208,6 +211,7 @@ def run( self, models: Optional[Union[str, List[str]]] = None, exclude: Optional[Union[str, List[str]]] = None, + threads: Optional[int] = None, request_id: int = 1, ): params = {} @@ -215,6 +219,8 @@ def run( params['models'] = models if exclude is not None: params['exclude'] = exclude + if threads is not None: + params['threads'] = threads return self.request( method='run', params=params, request_id=request_id ) @@ -232,10 +238,17 @@ def run_operation( method='run-operation', params=params, request_id=request_id ) - def seed(self, show: bool = None, request_id: int = 1): + def seed( + self, + show: bool = None, + threads: Optional[int] = None, + request_id: int = 1, + ): params = {} if show is not None: params['show'] = show + if threads is not None: + params['threads'] = threads return self.request( method='seed', params=params, request_id=request_id ) @@ -244,6 +257,7 @@ def snapshot( self, select: Optional[Union[str, List[str]]] = None, exclude: Optional[Union[str, List[str]]] = None, + threads: Optional[int] = None, request_id: int = 1, ): params = {} @@ -251,6 +265,8 @@ def snapshot( params['select'] = select if exclude is not None: params['exclude'] = exclude + if threads is not None: + params['threads'] = threads return self.request( method='snapshot', params=params, request_id=request_id ) @@ -259,6 +275,7 @@ def test( self, models: Optional[Union[str, List[str]]] = None, exclude: Optional[Union[str, List[str]]] = None, + threads: Optional[int] = None, data: bool = None, schema: bool = None, request_id: int = 1, @@ -272,6 +289,8 @@ def test( params['data'] = data if schema is not None: params['schema'] = schema + if threads is not None: + params['threads'] = threads return self.request( method='test', params=params, request_id=request_id ) @@ -406,6 +425,7 @@ def __init__( models=None, macros=None, snapshots=None, + seeds=None, ): self.project = { 'name': name, @@ -418,10 +438,11 @@ def __init__( self.models = models self.macros = macros self.snapshots = snapshots + self.seeds = seeds def _write_recursive(self, path, inputs): for name, value in inputs.items(): - if name.endswith('.sql'): + if name.endswith('.sql') or name.endswith('.csv'): path.join(name).write(value) elif name.endswith('.yml'): if isinstance(value, str): @@ -464,6 +485,9 @@ def write_macros(self, project_dir, remove=False): def write_snapshots(self, project_dir, remove=False): self._write_values(project_dir, remove, 'snapshots', self.snapshots) + def write_seeds(self, project_dir, remove=False): + self._write_values(project_dir, remove, 'data', self.seeds) + def write_to(self, project_dir, remove=False): if remove: project_dir.remove() @@ -473,6 +497,7 @@ def write_to(self, project_dir, remove=False): self.write_models(project_dir) self.write_macros(project_dir) self.write_snapshots(project_dir) + self.write_seeds(project_dir) class TestArgs: