Skip to content

Commit

Permalink
Merge pull request #2736 from fishtown-analytics/feature/rpc-state-defer
Browse files Browse the repository at this point in the history
Feature: state and defer in RPC calls
  • Loading branch information
beckjake authored Sep 8, 2020
2 parents ae542dc + 60f4c96 commit 1fa149d
Show file tree
Hide file tree
Showing 11 changed files with 358 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
## dbt 0.19.0 (Release TBD)

### Features
- Added state and defer arguments to the RPC client, matching the CLI ([#2678](https://github.com/fishtown-analytics/dbt/issues/2678), [#2736](https://github.com/fishtown-analytics/dbt/pull/2736))

## dbt 0.18.0 (September 03, 2020)

### Under the hood
Expand Down
15 changes: 15 additions & 0 deletions core/dbt/contracts/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ class RPCCompileParameters(RPCParameters):
models: Union[None, str, List[str]] = None
exclude: Union[None, str, List[str]] = None
selector: Optional[str] = None
state: Optional[str] = None


@dataclass
class RPCRunParameters(RPCParameters):
threads: Optional[int] = None
models: Union[None, str, List[str]] = None
exclude: Union[None, str, List[str]] = None
selector: Optional[str] = None
state: Optional[str] = None
defer: Optional[bool] = None


@dataclass
Expand All @@ -53,12 +64,14 @@ class RPCSnapshotParameters(RPCParameters):
select: Union[None, str, List[str]] = None
exclude: Union[None, str, List[str]] = None
selector: Optional[str] = None
state: Optional[str] = None


@dataclass
class RPCTestParameters(RPCCompileParameters):
data: bool = False
schema: bool = False
state: Optional[str] = None


@dataclass
Expand All @@ -68,11 +81,13 @@ class RPCSeedParameters(RPCParameters):
exclude: Union[None, str, List[str]] = None
selector: Optional[str] = None
show: bool = False
state: Optional[str] = None


@dataclass
class RPCDocsGenerateParameters(RPCParameters):
compile: bool = True
state: Optional[str] = None


@dataclass
Expand Down
39 changes: 37 additions & 2 deletions core/dbt/task/rpc/project_commands.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Union

from dbt import flags
from dbt.contracts.graph.manifest import WritableManifest
from dbt.contracts.rpc import (
GetManifestParameters,
GetManifestResult,
RPCCompileParameters,
RPCDocsGenerateParameters,
RPCRunParameters,
RPCRunOperationParameters,
RPCSeedParameters,
RPCTestParameters,
Expand Down Expand Up @@ -54,6 +57,15 @@ def handle_request(self) -> RemoteExecutionResult:
return self.run()


def state_path(state: Optional[str]) -> Optional[Path]:
if state is not None:
return Path(state)
elif flags.ARTIFACT_STATE_PATH is not None:
return Path(flags.ARTIFACT_STATE_PATH)
else:
return None


class RemoteCompileProjectTask(
RPCCommandTask[RPCCompileParameters], CompileTask
):
Expand All @@ -66,16 +78,28 @@ def set_args(self, params: RPCCompileParameters) -> None:
if params.threads is not None:
self.args.threads = params.threads

self.args.state = state_path(params.state)

self.set_previous_state()

class RemoteRunProjectTask(RPCCommandTask[RPCCompileParameters], RunTask):

class RemoteRunProjectTask(RPCCommandTask[RPCRunParameters], RunTask):
METHOD_NAME = 'run'

def set_args(self, params: RPCCompileParameters) -> None:
def set_args(self, params: RPCRunParameters) -> None:
self.args.models = self._listify(params.models)
self.args.exclude = self._listify(params.exclude)
self.args.selector_name = params.selector

if params.threads is not None:
self.args.threads = params.threads
if params.defer is None:
self.args.defer = flags.DEFER_MODE
else:
self.args.defer = params.defer

self.args.state = state_path(params.state)
self.set_previous_state()


class RemoteSeedProjectTask(RPCCommandTask[RPCSeedParameters], SeedTask):
Expand All @@ -90,6 +114,9 @@ def set_args(self, params: RPCSeedParameters) -> None:
self.args.threads = params.threads
self.args.show = params.show

self.args.state = state_path(params.state)
self.set_previous_state()


class RemoteTestProjectTask(RPCCommandTask[RPCTestParameters], TestTask):
METHOD_NAME = 'test'
Expand All @@ -103,6 +130,9 @@ def set_args(self, params: RPCTestParameters) -> None:
if params.threads is not None:
self.args.threads = params.threads

self.args.state = state_path(params.state)
self.set_previous_state()


class RemoteDocsGenerateProjectTask(
RPCCommandTask[RPCDocsGenerateParameters],
Expand All @@ -116,6 +146,8 @@ def set_args(self, params: RPCDocsGenerateParameters) -> None:
self.args.selector_name = None
self.args.compile = params.compile

self.args.state = state_path(params.state)

def get_catalog_results(
self, nodes, sources, generated_at, compile_results, errors
) -> RemoteCatalogResults:
Expand Down Expand Up @@ -185,6 +217,9 @@ def set_args(self, params: RPCSnapshotParameters) -> None:
if params.threads is not None:
self.args.threads = params.threads

self.args.state = state_path(params.state)
self.set_previous_state()


class RemoteSourceFreshnessTask(
RPCCommandTask[RPCSourceFreshnessParameters],
Expand Down
3 changes: 3 additions & 0 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ def __init__(self, args, config):
self._skipped_children = {}
self._raise_next_tick = None
self.previous_state: Optional[PreviousState] = None
self.set_previous_state()

def set_previous_state(self):
if self.args.state is not None:
self.previous_state = PreviousState(self.args.state)

Expand Down
10 changes: 10 additions & 0 deletions test/rpc/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ def postgres_profile_data(unique_schema):
'dbname': 'dbt',
'schema': unique_schema,
},
'other_schema': {
'type': 'postgres',
'threads': 4,
'host': 'database',
'port': 5432,
'user': 'root',
'pass': 'password',
'dbname': 'dbt',
'schema': unique_schema+'_alt',
}
},
'target': 'default'
}
Expand Down
44 changes: 44 additions & 0 deletions test/rpc/test_compile.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import pytest
from .util import (
assert_has_threads,
get_querier,
get_write_manifest,
ProjectDefinition,
)

Expand All @@ -28,3 +30,45 @@ def test_rpc_compile_threads(
querier.cli_args('compile --threads=7')
)
assert_has_threads(results, 7)


@pytest.mark.supported('postgres')
def test_rpc_compile_state(
project_root, profiles_root, dbt_profile, unique_schema
):
project = ProjectDefinition(models={'my_model.sql': 'select 1 as id'})
querier_ctx = get_querier(
project_def=project,
project_dir=project_root,
profiles_dir=profiles_root,
schema=unique_schema,
test_kwargs={},
)
with querier_ctx as querier:
state_dir = os.path.join(project_root, 'state')
os.makedirs(state_dir)

results = querier.async_wait_for_result(
querier.compile()
)
assert len(results['results']) == 1

get_write_manifest(querier, os.path.join(state_dir, 'manifest.json'))

project.models['my_model.sql'] = 'select 2 as id'
project.write_models(project_root, remove=True)

querier.sighup()
assert querier.wait_for_status('ready') is True

results = querier.async_wait_for_result(
querier.compile(state='./state', models=['state:modified'])
)
assert len(results['results']) == 1

get_write_manifest(querier, os.path.join(state_dir, 'manifest.json'))

results = querier.async_wait_for_result(
querier.compile(state='./state', models=['state:modified']),
)
assert len(results['results']) == 0
55 changes: 55 additions & 0 deletions test/rpc/test_run.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import pytest
from .util import (
assert_has_threads,
get_querier,
get_write_manifest,
ProjectDefinition,
)

Expand Down Expand Up @@ -79,3 +81,56 @@ def test_rpc_run_vars_compiled(
results = querier.async_wait_for_result(querier.cli_args('run'))
assert len(results['results']) == 1
assert results['results'][0]['node']['config']['materialized'] == 'view'


@pytest.mark.supported('postgres')
def test_rpc_run_state_defer(
project_root, profiles_root, dbt_profile, unique_schema
):
project = ProjectDefinition(models={'my_model.sql': 'select 1 as id'})
querier_ctx = get_querier(
project_def=project,
project_dir=project_root,
profiles_dir=profiles_root,
schema=unique_schema,
test_kwargs={},
)
with querier_ctx as querier:
state_dir = os.path.join(project_root, 'state')
os.makedirs(state_dir)

results = querier.async_wait_for_result(
querier.run()
)
assert len(results['results']) == 1

get_write_manifest(querier, os.path.join(state_dir, 'manifest.json'))

project.models['my_model.sql'] = 'select 2 as id'
project.write_models(project_root, remove=True)
querier.sighup()
assert querier.wait_for_status('ready') is True

results = querier.async_wait_for_result(
querier.run(state='./state', models=['state:modified'])
)
assert len(results['results']) == 1

get_write_manifest(querier, os.path.join(state_dir, 'manifest.json'))

results = querier.async_wait_for_result(
querier.run(state='./state', models=['state:modified']),
)
assert len(results['results']) == 0

project.models['my_model.sql'] = '{% if execute %}{% do exceptions.raise_compiler_error("should not see this") %}{% endif %}select 2 as id'
project.models['my_second_model.sql'] = 'select * from {{ ref("my_model") }}'
project.write_models(project_root, remove=True)
querier.sighup()
assert querier.wait_for_status('ready') is True

# if 'defer' is ignored, this will fail
results = querier.async_wait_for_result(
querier.run(state='./state', models=['my_second_model'], defer=True)
)
assert len(results['results']) == 1
44 changes: 44 additions & 0 deletions test/rpc/test_seed.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest
import os
from .util import (
assert_has_threads,
get_querier,
get_write_manifest,
ProjectDefinition,
)

Expand Down Expand Up @@ -65,3 +67,45 @@ def test_rpc_seed_include_exclude(
assert len(results['results']) == 1
results = querier.async_wait_for_result(querier.cli_args('seed --exclude=data_2'))
assert len(results['results']) == 1


@pytest.mark.supported('postgres')
def test_rpc_seed_state(
project_root, profiles_root, dbt_profile, unique_schema
):
project = ProjectDefinition(seeds={'my_seed.csv': 'a,b\n1,hello'})
querier_ctx = get_querier(
project_def=project,
project_dir=project_root,
profiles_dir=profiles_root,
schema=unique_schema,
test_kwargs={},
)
with querier_ctx as querier:
state_dir = os.path.join(project_root, 'state')
os.makedirs(state_dir)

results = querier.async_wait_for_result(
querier.seed()
)
assert len(results['results']) == 1

get_write_manifest(querier, os.path.join(state_dir, 'manifest.json'))

project.seeds['my_seed.csv'] = 'a,b\n1,hello\n2,goodbye'
project.write_seeds(project_root, remove=True)

querier.sighup()
assert querier.wait_for_status('ready') is True

results = querier.async_wait_for_result(
querier.seed(state='./state', select=['state:modified'])
)
assert len(results['results']) == 1

get_write_manifest(querier, os.path.join(state_dir, 'manifest.json'))

results = querier.async_wait_for_result(
querier.seed(state='./state', select=['state:modified']),
)
assert len(results['results']) == 0
Loading

0 comments on commit 1fa149d

Please sign in to comment.