diff --git a/src/_nebari/initialize.py b/src/_nebari/initialize.py
index 25e0f2d07..9c963ee70 100644
--- a/src/_nebari/initialize.py
+++ b/src/_nebari/initialize.py
@@ -77,21 +77,36 @@ def render_config(
] = """Welcome! Learn about Nebari's features and configurations in the documentation. If you have any questions or feedback, reach the team on Nebari's support forums."""
config["security"]["authentication"] = {"type": auth_provider}
+
if auth_provider == AuthenticationEnum.github:
- if not disable_prompt:
- config["security"]["authentication"]["config"] = {
- "client_id": input("Github client_id: "),
- "client_secret": input("Github client_secret: "),
- }
+ config["security"]["authentication"]["config"] = {
+ "client_id": os.environ.get(
+ "GITHUB_CLIENT_ID",
+ "",
+ ),
+ "client_secret": os.environ.get(
+ "GITHUB_CLIENT_SECRET",
+ "",
+ ),
+ }
elif auth_provider == AuthenticationEnum.auth0:
if auth_auto_provision:
auth0_config = create_client(config.domain, config.project_name)
config["security"]["authentication"]["config"] = auth0_config
else:
config["security"]["authentication"]["config"] = {
- "client_id": input("Auth0 client_id: "),
- "client_secret": input("Auth0 client_secret: "),
- "auth0_subdomain": input("Auth0 subdomain: "),
+ "client_id": os.environ.get(
+ "AUTH0_CLIENT_ID",
+ "",
+ ),
+ "client_secret": os.environ.get(
+ "AUTH0_CLIENT_SECRET",
+ "",
+ ),
+ "auth0_subdomain": os.environ.get(
+ "AUTH0_DOMAIN",
+ "",
+ ),
}
if cloud_provider == ProviderEnum.do:
diff --git a/src/_nebari/stages/kubernetes_keycloak/__init__.py b/src/_nebari/stages/kubernetes_keycloak/__init__.py
index aea9608d9..767c83189 100644
--- a/src/_nebari/stages/kubernetes_keycloak/__init__.py
+++ b/src/_nebari/stages/kubernetes_keycloak/__init__.py
@@ -1,6 +1,7 @@
import contextlib
import enum
import json
+import os
import secrets
import string
import sys
@@ -61,14 +62,59 @@ def to_yaml(cls, representer, node):
class GitHubConfig(schema.Base):
- client_id: str
- client_secret: str
+ client_id: str = pydantic.Field(
+ default_factory=lambda: os.environ.get("GITHUB_CLIENT_ID")
+ )
+ client_secret: str = pydantic.Field(
+ default_factory=lambda: os.environ.get("GITHUB_CLIENT_SECRET")
+ )
+
+ @pydantic.root_validator(allow_reuse=True)
+ def validate_required(cls, values):
+ missing = []
+ for k, v in {
+ "client_id": "GITHUB_CLIENT_ID",
+ "client_secret": "GITHUB_CLIENT_SECRET",
+ }.items():
+ if not values.get(k):
+ missing.append(v)
+
+ if len(missing) > 0:
+ raise ValueError(
+ f"Missing the following required environment variable(s): {', '.join(missing)}"
+ )
+
+ return values
class Auth0Config(schema.Base):
- client_id: str
- client_secret: str
- auth0_subdomain: str
+ client_id: str = pydantic.Field(
+ default_factory=lambda: os.environ.get("AUTH0_CLIENT_ID")
+ )
+ client_secret: str = pydantic.Field(
+ default_factory=lambda: os.environ.get("AUTH0_CLIENT_SECRET")
+ )
+ auth0_subdomain: str = pydantic.Field(
+ default_factory=lambda: os.environ.get("AUTH0_DOMAIN")
+ )
+
+ @pydantic.root_validator(allow_reuse=True)
+ def validate_required(cls, values):
+ missing = []
+ for k, v in {
+ "client_id": "AUTH0_CLIENT_ID",
+ "client_secret": "AUTH0_CLIENT_SECRET",
+ "auth0_subdomain": "AUTH0_DOMAIN",
+ }.items():
+ if not values.get(k):
+ missing.append(v)
+
+ if len(missing) > 0:
+ raise ValueError(
+ f"Missing the following required environment variable(s): {', '.join(missing)}"
+ )
+
+ return values
class Authentication(schema.Base, ABC):
@@ -117,12 +163,12 @@ class PasswordAuthentication(Authentication):
class Auth0Authentication(Authentication):
_typ = AuthenticationEnum.auth0
- config: Auth0Config
+ config: Auth0Config = pydantic.Field(default_factory=lambda: Auth0Config())
class GitHubAuthentication(Authentication):
_typ = AuthenticationEnum.github
- config: GitHubConfig
+ config: GitHubConfig = pydantic.Field(default_factory=lambda: GitHubConfig())
class Keycloak(schema.Base):
diff --git a/src/_nebari/subcommands/deploy.py b/src/_nebari/subcommands/deploy.py
index 13ee7e8f5..0aa861027 100644
--- a/src/_nebari/subcommands/deploy.py
+++ b/src/_nebari/subcommands/deploy.py
@@ -1,6 +1,7 @@
import pathlib
from typing import Optional
+import rich
import typer
from _nebari.config import read_configuration
@@ -65,10 +66,8 @@ def deploy(
from nebari.plugins import nebari_plugin_manager
if dns_provider or dns_auto_provision:
- from rich import print
-
msg = "The [green]`--dns-provider`[/green] and [green]`--dns-auto-provision`[/green] flags have been removed in favor of configuring DNS via nebari-config.yaml"
- print(msg)
+ rich.print(msg)
raise typer.Abort()
stages = nebari_plugin_manager.ordered_stages
@@ -83,7 +82,7 @@ def deploy(
for stage in stages:
if stage.name == TERRAFORM_STATE_STAGE_NAME:
stages.remove(stage)
- print("Skipping remote state provision")
+ rich.print("Skipping remote state provision")
deploy_configuration(
config,
diff --git a/src/_nebari/subcommands/init.py b/src/_nebari/subcommands/init.py
index 4ce38f21c..ee5d8534e 100644
--- a/src/_nebari/subcommands/init.py
+++ b/src/_nebari/subcommands/init.py
@@ -213,18 +213,23 @@ def check_auth_provider_creds(ctx: typer.Context, auth_provider: str):
)
)
- os.environ["AUTH0_CLIENT_ID"] = typer.prompt(
- "Paste your AUTH0_CLIENT_ID",
- hide_input=True,
- )
- os.environ["AUTH0_CLIENT_SECRET"] = typer.prompt(
- "Paste your AUTH0_CLIENT_SECRET",
- hide_input=True,
- )
- os.environ["AUTH0_DOMAIN"] = typer.prompt(
- "Paste your AUTH0_DOMAIN",
- hide_input=True,
- )
+ if not os.environ.get("AUTH0_CLIENT_ID"):
+ os.environ["AUTH0_CLIENT_ID"] = typer.prompt(
+ "Paste your AUTH0_CLIENT_ID",
+ hide_input=True,
+ )
+
+ if not os.environ.get("AUTH0_CLIENT_SECRET"):
+ os.environ["AUTH0_CLIENT_SECRET"] = typer.prompt(
+ "Paste your AUTH0_CLIENT_SECRET",
+ hide_input=True,
+ )
+
+ if not os.environ.get("AUTH0_DOMAIN"):
+ os.environ["AUTH0_DOMAIN"] = typer.prompt(
+ "Paste your AUTH0_DOMAIN",
+ hide_input=True,
+ )
# GitHub
elif auth_provider == AuthenticationEnum.github.value.lower() and (
@@ -237,14 +242,17 @@ def check_auth_provider_creds(ctx: typer.Context, auth_provider: str):
)
)
- os.environ["GITHUB_CLIENT_ID"] = typer.prompt(
- "Paste your GITHUB_CLIENT_ID",
- hide_input=True,
- )
- os.environ["GITHUB_CLIENT_SECRET"] = typer.prompt(
- "Paste your GITHUB_CLIENT_SECRET",
- hide_input=True,
- )
+ if not os.environ.get("GITHUB_CLIENT_ID"):
+ os.environ["GITHUB_CLIENT_ID"] = typer.prompt(
+ "Paste your GITHUB_CLIENT_ID",
+ hide_input=True,
+ )
+
+ if not os.environ.get("GITHUB_CLIENT_SECRET"):
+ os.environ["GITHUB_CLIENT_SECRET"] = typer.prompt(
+ "Paste your GITHUB_CLIENT_SECRET",
+ hide_input=True,
+ )
return auth_provider
diff --git a/tests/tests_unit/cli_validate/local.happy.auth0.yaml b/tests/tests_unit/cli_validate/local.happy.auth0.yaml
new file mode 100644
index 000000000..d4ba2e18a
--- /dev/null
+++ b/tests/tests_unit/cli_validate/local.happy.auth0.yaml
@@ -0,0 +1,9 @@
+provider: local
+project_name: foobar
+security:
+ authentication:
+ type: Auth0
+ config:
+ client_id: test_client
+ client_secret: test_secret
+ auth0_subdomain: test_subdomain
diff --git a/tests/tests_unit/cli_validate/local.happy.github.yaml b/tests/tests_unit/cli_validate/local.happy.github.yaml
new file mode 100644
index 000000000..30011b469
--- /dev/null
+++ b/tests/tests_unit/cli_validate/local.happy.github.yaml
@@ -0,0 +1,8 @@
+provider: local
+project_name: foobar
+security:
+ authentication:
+ type: GitHub
+ config:
+ client_id: test_client
+ client_secret: test_secret
diff --git a/tests/tests_unit/test_cli_init.py b/tests/tests_unit/test_cli_init.py
index f1d970d22..0e336bf16 100644
--- a/tests/tests_unit/test_cli_init.py
+++ b/tests/tests_unit/test_cli_init.py
@@ -73,9 +73,7 @@ def generate_test_data_test_cli_init_happy_path():
for project_name in ["testproject"]:
for domain_name in [f"{project_name}.example.com"]:
for namespace in ["test-ns"]:
- for auth_provider in [
- "password"
- ]: # ["password", "Auth0", "GitHub"] # Auth0, Github failing as of 2023-08-23
+ for auth_provider in ["password", "Auth0", "GitHub"]:
for repository in ["github.com", "gitlab.com"]:
for ci_provider in [
"none",
diff --git a/tests/tests_unit/test_cli_validate.py b/tests/tests_unit/test_cli_validate.py
index 44b7ce0f0..13955c1fc 100644
--- a/tests/tests_unit/test_cli_validate.py
+++ b/tests/tests_unit/test_cli_validate.py
@@ -197,6 +197,16 @@ def test_cli_validate_error_from_env(
},
),
("do", {"digital_ocean": {"kubernetes_version": "1.20", "region": "nyc3"}}),
+ pytest.param(
+ "local",
+ {"security": {"authentication": {"type": "Auth0"}}},
+ id="auth-provider-auth0",
+ ),
+ pytest.param(
+ "local",
+ {"security": {"authentication": {"type": "GitHub"}}},
+ id="auth-provider-github",
+ ),
],
)
def test_cli_validate_error_missing_cloud_env(
@@ -216,6 +226,11 @@ def test_cli_validate_error_missing_cloud_env(
"DIGITALOCEAN_TOKEN",
"SPACES_ACCESS_KEY_ID",
"SPACES_SECRET_ACCESS_KEY",
+ "AUTH0_CLIENT_ID",
+ "AUTH0_CLIENT_SECRET",
+ "AUTH0_DOMAIN",
+ "GITHUB_CLIENT_ID",
+ "GITHUB_CLIENT_SECRET",
]:
try:
monkeypatch.delenv(e)