Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auth0/Github auth-provider config validation fix #2009

Merged
merged 12 commits into from
Sep 18, 2023
Merged
31 changes: 23 additions & 8 deletions src/_nebari/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,21 +77,36 @@ def render_config(
] = """Welcome! Learn about Nebari's features and configurations in <a href="https://www.nebari.dev/docs">the documentation</a>. If you have any questions or feedback, reach the team on <a href="https://www.nebari.dev/docs/community#getting-support">Nebari's support forums</a>."""

config["security"]["authentication"] = {"type": auth_provider}

if auth_provider == AuthenticationEnum.github:
if not disable_prompt:
config["security"]["authentication"]["config"] = {
"client_id": input("Github client_id: "),
"client_secret": input("Github client_secret: "),
}
config["security"]["authentication"]["config"] = {
"client_id": os.environ.get(
"GITHUB_CLIENT_ID",
"<enter client id or remove to use GITHUB_CLIENT_ID environment variable (preferred)>",
),
"client_secret": os.environ.get(
"GITHUB_CLIENT_SECRET",
"<enter client secret or remove to use GITHUB_CLIENT_SECRET environment variable (preferred)>",
),
}
elif auth_provider == AuthenticationEnum.auth0:
if auth_auto_provision:
auth0_config = create_client(config.domain, config.project_name)
config["security"]["authentication"]["config"] = auth0_config
else:
config["security"]["authentication"]["config"] = {
"client_id": input("Auth0 client_id: "),
"client_secret": input("Auth0 client_secret: "),
"auth0_subdomain": input("Auth0 subdomain: "),
"client_id": os.environ.get(
"AUTH0_CLIENT_ID",
"<enter client id or remove to use AUTH0_CLIENT_ID environment variable (preferred)>",
),
"client_secret": os.environ.get(
"AUTH0_CLIENT_SECRET",
"<enter client secret or remove to use AUTH0_CLIENT_SECRET environment variable (preferred)>",
),
"auth0_subdomain": os.environ.get(
"AUTH0_DOMAIN",
"<enter subdomain (without .auth0.com) or remove to use AUTH0_DOMAIN environment variable>",
),
}

if cloud_provider == ProviderEnum.do:
Expand Down
60 changes: 53 additions & 7 deletions src/_nebari/stages/kubernetes_keycloak/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import contextlib
import enum
import json
import os
import secrets
import string
import sys
Expand Down Expand Up @@ -61,14 +62,59 @@ def to_yaml(cls, representer, node):


class GitHubConfig(schema.Base):
client_id: str
client_secret: str
client_id: str = pydantic.Field(
default_factory=lambda: os.environ.get("GITHUB_CLIENT_ID")
)
client_secret: str = pydantic.Field(
default_factory=lambda: os.environ.get("GITHUB_CLIENT_SECRET")
)

@pydantic.root_validator(allow_reuse=True)
def validate_required(cls, values):
missing = []
for k, v in {
"client_id": "GITHUB_CLIENT_ID",
"client_secret": "GITHUB_CLIENT_SECRET",
}.items():
if not values.get(k):
missing.append(v)

if len(missing) > 0:
raise ValueError(
f"Missing the following required environment variable(s): {', '.join(missing)}"
)

return values


class Auth0Config(schema.Base):
client_id: str
client_secret: str
auth0_subdomain: str
client_id: str = pydantic.Field(
default_factory=lambda: os.environ.get("AUTH0_CLIENT_ID")
)
client_secret: str = pydantic.Field(
default_factory=lambda: os.environ.get("AUTH0_CLIENT_SECRET")
)
auth0_subdomain: str = pydantic.Field(
default_factory=lambda: os.environ.get("AUTH0_DOMAIN")
)

@pydantic.root_validator(allow_reuse=True)
def validate_required(cls, values):
missing = []
for k, v in {
"client_id": "AUTH0_CLIENT_ID",
"client_secret": "AUTH0_CLIENT_SECRET",
"auth0_subdomain": "AUTH0_DOMAIN",
}.items():
if not values.get(k):
missing.append(v)

if len(missing) > 0:
raise ValueError(
f"Missing the following required environment variable(s): {', '.join(missing)}"
)

return values


class Authentication(schema.Base, ABC):
Expand Down Expand Up @@ -117,12 +163,12 @@ class PasswordAuthentication(Authentication):

class Auth0Authentication(Authentication):
_typ = AuthenticationEnum.auth0
config: Auth0Config
config: Auth0Config = pydantic.Field(default_factory=lambda: Auth0Config())


class GitHubAuthentication(Authentication):
_typ = AuthenticationEnum.github
config: GitHubConfig
config: GitHubConfig = pydantic.Field(default_factory=lambda: GitHubConfig())


class Keycloak(schema.Base):
Expand Down
7 changes: 3 additions & 4 deletions src/_nebari/subcommands/deploy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pathlib
from typing import Optional

import rich
import typer

from _nebari.config import read_configuration
Expand Down Expand Up @@ -65,10 +66,8 @@ def deploy(
from nebari.plugins import nebari_plugin_manager

if dns_provider or dns_auto_provision:
from rich import print

msg = "The [green]`--dns-provider`[/green] and [green]`--dns-auto-provision`[/green] flags have been removed in favor of configuring DNS via nebari-config.yaml"
print(msg)
rich.print(msg)
raise typer.Abort()

stages = nebari_plugin_manager.ordered_stages
Expand All @@ -83,7 +82,7 @@ def deploy(
for stage in stages:
if stage.name == TERRAFORM_STATE_STAGE_NAME:
stages.remove(stage)
print("Skipping remote state provision")
rich.print("Skipping remote state provision")

deploy_configuration(
config,
Expand Down
48 changes: 28 additions & 20 deletions src/_nebari/subcommands/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,18 +213,23 @@ def check_auth_provider_creds(ctx: typer.Context, auth_provider: str):
)
)

os.environ["AUTH0_CLIENT_ID"] = typer.prompt(
"Paste your AUTH0_CLIENT_ID",
hide_input=True,
)
os.environ["AUTH0_CLIENT_SECRET"] = typer.prompt(
"Paste your AUTH0_CLIENT_SECRET",
hide_input=True,
)
os.environ["AUTH0_DOMAIN"] = typer.prompt(
"Paste your AUTH0_DOMAIN",
hide_input=True,
)
if not os.environ.get("AUTH0_CLIENT_ID"):
os.environ["AUTH0_CLIENT_ID"] = typer.prompt(
"Paste your AUTH0_CLIENT_ID",
hide_input=True,
)

if not os.environ.get("AUTH0_CLIENT_SECRET"):
os.environ["AUTH0_CLIENT_SECRET"] = typer.prompt(
"Paste your AUTH0_CLIENT_SECRET",
hide_input=True,
)

if not os.environ.get("AUTH0_DOMAIN"):
os.environ["AUTH0_DOMAIN"] = typer.prompt(
"Paste your AUTH0_DOMAIN",
hide_input=True,
)

# GitHub
elif auth_provider == AuthenticationEnum.github.value.lower() and (
Expand All @@ -237,14 +242,17 @@ def check_auth_provider_creds(ctx: typer.Context, auth_provider: str):
)
)

os.environ["GITHUB_CLIENT_ID"] = typer.prompt(
"Paste your GITHUB_CLIENT_ID",
hide_input=True,
)
os.environ["GITHUB_CLIENT_SECRET"] = typer.prompt(
"Paste your GITHUB_CLIENT_SECRET",
hide_input=True,
)
if not os.environ.get("GITHUB_CLIENT_ID"):
os.environ["GITHUB_CLIENT_ID"] = typer.prompt(
"Paste your GITHUB_CLIENT_ID",
hide_input=True,
)

if not os.environ.get("GITHUB_CLIENT_SECRET"):
os.environ["GITHUB_CLIENT_SECRET"] = typer.prompt(
"Paste your GITHUB_CLIENT_SECRET",
hide_input=True,
)

return auth_provider

Expand Down
9 changes: 9 additions & 0 deletions tests/tests_unit/cli_validate/local.happy.auth0.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
provider: local
project_name: foobar
security:
authentication:
type: Auth0
config:
client_id: test_client
client_secret: test_secret
auth0_subdomain: test_subdomain
8 changes: 8 additions & 0 deletions tests/tests_unit/cli_validate/local.happy.github.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
provider: local
project_name: foobar
security:
authentication:
type: GitHub
config:
client_id: test_client
client_secret: test_secret
4 changes: 1 addition & 3 deletions tests/tests_unit/test_cli_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,7 @@ def generate_test_data_test_cli_init_happy_path():
for project_name in ["testproject"]:
for domain_name in [f"{project_name}.example.com"]:
for namespace in ["test-ns"]:
for auth_provider in [
"password"
]: # ["password", "Auth0", "GitHub"] # Auth0, Github failing as of 2023-08-23
for auth_provider in ["password", "Auth0", "GitHub"]:
for repository in ["github.com", "gitlab.com"]:
for ci_provider in [
"none",
Expand Down
15 changes: 15 additions & 0 deletions tests/tests_unit/test_cli_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,16 @@ def test_cli_validate_error_from_env(
},
),
("do", {"digital_ocean": {"kubernetes_version": "1.20", "region": "nyc3"}}),
pytest.param(
"local",
{"security": {"authentication": {"type": "Auth0"}}},
id="auth-provider-auth0",
),
pytest.param(
"local",
{"security": {"authentication": {"type": "GitHub"}}},
id="auth-provider-github",
),
],
)
def test_cli_validate_error_missing_cloud_env(
Expand All @@ -216,6 +226,11 @@ def test_cli_validate_error_missing_cloud_env(
"DIGITALOCEAN_TOKEN",
"SPACES_ACCESS_KEY_ID",
"SPACES_SECRET_ACCESS_KEY",
"AUTH0_CLIENT_ID",
"AUTH0_CLIENT_SECRET",
"AUTH0_DOMAIN",
"GITHUB_CLIENT_ID",
"GITHUB_CLIENT_SECRET",
]:
try:
monkeypatch.delenv(e)
Expand Down