diff --git a/src/_nebari/initialize.py b/src/_nebari/initialize.py index 25e0f2d07..9c963ee70 100644 --- a/src/_nebari/initialize.py +++ b/src/_nebari/initialize.py @@ -77,21 +77,36 @@ def render_config( ] = """Welcome! Learn about Nebari's features and configurations in the documentation. If you have any questions or feedback, reach the team on Nebari's support forums.""" config["security"]["authentication"] = {"type": auth_provider} + if auth_provider == AuthenticationEnum.github: - if not disable_prompt: - config["security"]["authentication"]["config"] = { - "client_id": input("Github client_id: "), - "client_secret": input("Github client_secret: "), - } + config["security"]["authentication"]["config"] = { + "client_id": os.environ.get( + "GITHUB_CLIENT_ID", + "", + ), + "client_secret": os.environ.get( + "GITHUB_CLIENT_SECRET", + "", + ), + } elif auth_provider == AuthenticationEnum.auth0: if auth_auto_provision: auth0_config = create_client(config.domain, config.project_name) config["security"]["authentication"]["config"] = auth0_config else: config["security"]["authentication"]["config"] = { - "client_id": input("Auth0 client_id: "), - "client_secret": input("Auth0 client_secret: "), - "auth0_subdomain": input("Auth0 subdomain: "), + "client_id": os.environ.get( + "AUTH0_CLIENT_ID", + "", + ), + "client_secret": os.environ.get( + "AUTH0_CLIENT_SECRET", + "", + ), + "auth0_subdomain": os.environ.get( + "AUTH0_DOMAIN", + "", + ), } if cloud_provider == ProviderEnum.do: diff --git a/src/_nebari/stages/kubernetes_keycloak/__init__.py b/src/_nebari/stages/kubernetes_keycloak/__init__.py index aea9608d9..767c83189 100644 --- a/src/_nebari/stages/kubernetes_keycloak/__init__.py +++ b/src/_nebari/stages/kubernetes_keycloak/__init__.py @@ -1,6 +1,7 @@ import contextlib import enum import json +import os import secrets import string import sys @@ -61,14 +62,59 @@ def to_yaml(cls, representer, node): class GitHubConfig(schema.Base): - client_id: str - client_secret: str + client_id: str = pydantic.Field( + default_factory=lambda: os.environ.get("GITHUB_CLIENT_ID") + ) + client_secret: str = pydantic.Field( + default_factory=lambda: os.environ.get("GITHUB_CLIENT_SECRET") + ) + + @pydantic.root_validator(allow_reuse=True) + def validate_required(cls, values): + missing = [] + for k, v in { + "client_id": "GITHUB_CLIENT_ID", + "client_secret": "GITHUB_CLIENT_SECRET", + }.items(): + if not values.get(k): + missing.append(v) + + if len(missing) > 0: + raise ValueError( + f"Missing the following required environment variable(s): {', '.join(missing)}" + ) + + return values class Auth0Config(schema.Base): - client_id: str - client_secret: str - auth0_subdomain: str + client_id: str = pydantic.Field( + default_factory=lambda: os.environ.get("AUTH0_CLIENT_ID") + ) + client_secret: str = pydantic.Field( + default_factory=lambda: os.environ.get("AUTH0_CLIENT_SECRET") + ) + auth0_subdomain: str = pydantic.Field( + default_factory=lambda: os.environ.get("AUTH0_DOMAIN") + ) + + @pydantic.root_validator(allow_reuse=True) + def validate_required(cls, values): + missing = [] + for k, v in { + "client_id": "AUTH0_CLIENT_ID", + "client_secret": "AUTH0_CLIENT_SECRET", + "auth0_subdomain": "AUTH0_DOMAIN", + }.items(): + if not values.get(k): + missing.append(v) + + if len(missing) > 0: + raise ValueError( + f"Missing the following required environment variable(s): {', '.join(missing)}" + ) + + return values class Authentication(schema.Base, ABC): @@ -117,12 +163,12 @@ class PasswordAuthentication(Authentication): class Auth0Authentication(Authentication): _typ = AuthenticationEnum.auth0 - config: Auth0Config + config: Auth0Config = pydantic.Field(default_factory=lambda: Auth0Config()) class GitHubAuthentication(Authentication): _typ = AuthenticationEnum.github - config: GitHubConfig + config: GitHubConfig = pydantic.Field(default_factory=lambda: GitHubConfig()) class Keycloak(schema.Base): diff --git a/src/_nebari/subcommands/deploy.py b/src/_nebari/subcommands/deploy.py index 13ee7e8f5..0aa861027 100644 --- a/src/_nebari/subcommands/deploy.py +++ b/src/_nebari/subcommands/deploy.py @@ -1,6 +1,7 @@ import pathlib from typing import Optional +import rich import typer from _nebari.config import read_configuration @@ -65,10 +66,8 @@ def deploy( from nebari.plugins import nebari_plugin_manager if dns_provider or dns_auto_provision: - from rich import print - msg = "The [green]`--dns-provider`[/green] and [green]`--dns-auto-provision`[/green] flags have been removed in favor of configuring DNS via nebari-config.yaml" - print(msg) + rich.print(msg) raise typer.Abort() stages = nebari_plugin_manager.ordered_stages @@ -83,7 +82,7 @@ def deploy( for stage in stages: if stage.name == TERRAFORM_STATE_STAGE_NAME: stages.remove(stage) - print("Skipping remote state provision") + rich.print("Skipping remote state provision") deploy_configuration( config, diff --git a/src/_nebari/subcommands/init.py b/src/_nebari/subcommands/init.py index 4ce38f21c..ee5d8534e 100644 --- a/src/_nebari/subcommands/init.py +++ b/src/_nebari/subcommands/init.py @@ -213,18 +213,23 @@ def check_auth_provider_creds(ctx: typer.Context, auth_provider: str): ) ) - os.environ["AUTH0_CLIENT_ID"] = typer.prompt( - "Paste your AUTH0_CLIENT_ID", - hide_input=True, - ) - os.environ["AUTH0_CLIENT_SECRET"] = typer.prompt( - "Paste your AUTH0_CLIENT_SECRET", - hide_input=True, - ) - os.environ["AUTH0_DOMAIN"] = typer.prompt( - "Paste your AUTH0_DOMAIN", - hide_input=True, - ) + if not os.environ.get("AUTH0_CLIENT_ID"): + os.environ["AUTH0_CLIENT_ID"] = typer.prompt( + "Paste your AUTH0_CLIENT_ID", + hide_input=True, + ) + + if not os.environ.get("AUTH0_CLIENT_SECRET"): + os.environ["AUTH0_CLIENT_SECRET"] = typer.prompt( + "Paste your AUTH0_CLIENT_SECRET", + hide_input=True, + ) + + if not os.environ.get("AUTH0_DOMAIN"): + os.environ["AUTH0_DOMAIN"] = typer.prompt( + "Paste your AUTH0_DOMAIN", + hide_input=True, + ) # GitHub elif auth_provider == AuthenticationEnum.github.value.lower() and ( @@ -237,14 +242,17 @@ def check_auth_provider_creds(ctx: typer.Context, auth_provider: str): ) ) - os.environ["GITHUB_CLIENT_ID"] = typer.prompt( - "Paste your GITHUB_CLIENT_ID", - hide_input=True, - ) - os.environ["GITHUB_CLIENT_SECRET"] = typer.prompt( - "Paste your GITHUB_CLIENT_SECRET", - hide_input=True, - ) + if not os.environ.get("GITHUB_CLIENT_ID"): + os.environ["GITHUB_CLIENT_ID"] = typer.prompt( + "Paste your GITHUB_CLIENT_ID", + hide_input=True, + ) + + if not os.environ.get("GITHUB_CLIENT_SECRET"): + os.environ["GITHUB_CLIENT_SECRET"] = typer.prompt( + "Paste your GITHUB_CLIENT_SECRET", + hide_input=True, + ) return auth_provider diff --git a/tests/tests_unit/cli_validate/local.happy.auth0.yaml b/tests/tests_unit/cli_validate/local.happy.auth0.yaml new file mode 100644 index 000000000..d4ba2e18a --- /dev/null +++ b/tests/tests_unit/cli_validate/local.happy.auth0.yaml @@ -0,0 +1,9 @@ +provider: local +project_name: foobar +security: + authentication: + type: Auth0 + config: + client_id: test_client + client_secret: test_secret + auth0_subdomain: test_subdomain diff --git a/tests/tests_unit/cli_validate/local.happy.github.yaml b/tests/tests_unit/cli_validate/local.happy.github.yaml new file mode 100644 index 000000000..30011b469 --- /dev/null +++ b/tests/tests_unit/cli_validate/local.happy.github.yaml @@ -0,0 +1,8 @@ +provider: local +project_name: foobar +security: + authentication: + type: GitHub + config: + client_id: test_client + client_secret: test_secret diff --git a/tests/tests_unit/test_cli_init.py b/tests/tests_unit/test_cli_init.py index f1d970d22..0e336bf16 100644 --- a/tests/tests_unit/test_cli_init.py +++ b/tests/tests_unit/test_cli_init.py @@ -73,9 +73,7 @@ def generate_test_data_test_cli_init_happy_path(): for project_name in ["testproject"]: for domain_name in [f"{project_name}.example.com"]: for namespace in ["test-ns"]: - for auth_provider in [ - "password" - ]: # ["password", "Auth0", "GitHub"] # Auth0, Github failing as of 2023-08-23 + for auth_provider in ["password", "Auth0", "GitHub"]: for repository in ["github.com", "gitlab.com"]: for ci_provider in [ "none", diff --git a/tests/tests_unit/test_cli_validate.py b/tests/tests_unit/test_cli_validate.py index 44b7ce0f0..13955c1fc 100644 --- a/tests/tests_unit/test_cli_validate.py +++ b/tests/tests_unit/test_cli_validate.py @@ -197,6 +197,16 @@ def test_cli_validate_error_from_env( }, ), ("do", {"digital_ocean": {"kubernetes_version": "1.20", "region": "nyc3"}}), + pytest.param( + "local", + {"security": {"authentication": {"type": "Auth0"}}}, + id="auth-provider-auth0", + ), + pytest.param( + "local", + {"security": {"authentication": {"type": "GitHub"}}}, + id="auth-provider-github", + ), ], ) def test_cli_validate_error_missing_cloud_env( @@ -216,6 +226,11 @@ def test_cli_validate_error_missing_cloud_env( "DIGITALOCEAN_TOKEN", "SPACES_ACCESS_KEY_ID", "SPACES_SECRET_ACCESS_KEY", + "AUTH0_CLIENT_ID", + "AUTH0_CLIENT_SECRET", + "AUTH0_DOMAIN", + "GITHUB_CLIENT_ID", + "GITHUB_CLIENT_SECRET", ]: try: monkeypatch.delenv(e)