Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add --project-dir flag to allow specifying project directory #1549

Merged
merged 3 commits into from
Jun 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions core/dbt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,16 @@ def run_from_args(parsed):
def _build_base_subparser():
base_subparser = argparse.ArgumentParser(add_help=False)

base_subparser.add_argument(
'--project-dir',
default=None,
type=str,
help="""
Which directory to look in for the dbt_project.yml file.
Default is the current working directory and its parents.
"""
)

base_subparser.add_argument(
'--profiles-dir',
default=PROFILES_DIR,
Expand Down
31 changes: 20 additions & 11 deletions core/dbt/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,19 @@ def interpret_results(self, results):
return True


def get_nearest_project_dir():
def get_nearest_project_dir(args):
# If the user provides an explicit project directory, use that
# but don't look at parent directories.
if args.project_dir:
project_file = os.path.join(args.project_dir, "dbt_project.yml")
if os.path.exists(project_file):
return args.project_dir
else:
raise dbt.exceptions.RuntimeException(
"fatal: Invalid --project-dir flag. Not a dbt project. "
"Missing dbt_project.yml file"
)

root_path = os.path.abspath(os.sep)
cwd = os.getcwd()

Expand All @@ -102,24 +114,21 @@ def get_nearest_project_dir():
return cwd
cwd = os.path.dirname(cwd)

return None

raise dbt.exceptions.RuntimeException(
"fatal: Not a dbt project (or any of the parent directories). "
"Missing dbt_project.yml file"
)

def move_to_nearest_project_dir():
nearest_project_dir = get_nearest_project_dir()
if nearest_project_dir is None:
raise dbt.exceptions.RuntimeException(
"fatal: Not a dbt project (or any of the parent directories). "
"Missing dbt_project.yml file"
)

def move_to_nearest_project_dir(args):
nearest_project_dir = get_nearest_project_dir(args)
os.chdir(nearest_project_dir)


class RequiresProjectTask(BaseTask):
@classmethod
def from_args(cls, args):
move_to_nearest_project_dir()
move_to_nearest_project_dir(args)
return super(RequiresProjectTask, cls).from_args(args)


Expand Down
43 changes: 37 additions & 6 deletions test/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
from dbt.adapters.redshift import RedshiftCredentials
from dbt.contracts.project import PackageConfig
from dbt.semver import VersionSpecifier
from dbt.task.run_operation import RunOperationTask


INITIAL_ROOT = os.getcwd()


@contextmanager
Expand Down Expand Up @@ -65,7 +69,7 @@ def temp_cd(path):

class Args(object):
def __init__(self, profiles_dir=None, threads=None, profile=None,
cli_vars=None, version_check=None):
cli_vars=None, version_check=None, project_dir=None):
self.profile = profile
if threads is not None:
self.threads = threads
Expand All @@ -75,11 +79,13 @@ def __init__(self, profiles_dir=None, threads=None, profile=None,
self.vars = cli_vars
if version_check is not None:
self.version_check = version_check
if project_dir is not None:
self.project_dir = project_dir


class BaseConfigTest(unittest.TestCase):
"""Subclass this, and before calling the superclass setUp, set
profiles_dir.
self.profiles_dir and self.project_dir.
"""
def setUp(self):
self.default_project_data = {
Expand Down Expand Up @@ -147,7 +153,7 @@ def setUp(self):
}
}
self.args = Args(profiles_dir=self.profiles_dir, cli_vars='{}',
version_check=True)
version_check=True, project_dir=self.project_dir)
self.env_override = {
'env_value_type': 'postgres',
'env_value_host': 'env-postgres-host',
Expand Down Expand Up @@ -176,7 +182,7 @@ def tearDown(self):
except EnvironmentError:
pass

def proejct_path(self, name):
def project_path(self, name):
return os.path.join(self.project_dir, name)

def profile_path(self, name):
Expand All @@ -185,11 +191,11 @@ def profile_path(self, name):
def write_project(self, project_data=None):
if project_data is None:
project_data = self.project_data
with open(self.proejct_path('dbt_project.yml'), 'w') as fp:
with open(self.project_path('dbt_project.yml'), 'w') as fp:
yaml.dump(project_data, fp)

def write_packages(self, package_data):
with open(self.proejct_path('packages.yml'), 'w') as fp:
with open(self.project_path('packages.yml'), 'w') as fp:
yaml.dump(package_data, fp)

def write_profile(self, profile_data=None):
Expand All @@ -202,6 +208,7 @@ def write_profile(self, profile_data=None):
class TestProfile(BaseConfigTest):
def setUp(self):
self.profiles_dir = '/invalid-path'
self.project_dir = '/invalid-project-path'
super(TestProfile, self).setUp()

def from_raw_profiles(self):
Expand Down Expand Up @@ -928,6 +935,30 @@ def test_with_invalid_package(self):
dbt.config.Project.from_project_root(self.project_dir, {})


class TestRunOperationTask(BaseFileTest):
def setUp(self):
super(TestRunOperationTask, self).setUp()
self.write_project(self.default_project_data)
self.write_profile(self.default_profile_data)

def tearDown(self):
super(TestRunOperationTask, self).tearDown()
# These tests will change the directory to the project path,
# so it's necessary to change it back at the end.
os.chdir(INITIAL_ROOT)

def test_run_operation_task(self):
self.assertEqual(os.getcwd(), INITIAL_ROOT)
self.assertNotEqual(INITIAL_ROOT, self.project_dir)
new_task = RunOperationTask.from_args(self.args)
self.assertEqual(os.getcwd(), self.project_dir)

def test_run_operation_task_with_bad_path(self):
self.args.project_dir = 'bad_path'
with self.assertRaises(dbt.exceptions.RuntimeException):
new_task = RunOperationTask.from_args(self.args)


class TestVariableProjectFile(BaseFileTest):
def setUp(self):
super(TestVariableProjectFile, self).setUp()
Expand Down