Skip to content

Commit

Permalink
BUG: fix validation error related to provider #2054 (#2056)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
fangchenli and pre-commit-ci[bot] authored Oct 16, 2023
1 parent 67e967f commit 4db8400
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 73 deletions.
2 changes: 1 addition & 1 deletion src/_nebari/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def render_config(
ssl_cert_email: str = None,
):
config = {
"provider": cloud_provider,
"provider": cloud_provider.value,
"namespace": namespace,
"nebari_version": __version__,
}
Expand Down
99 changes: 52 additions & 47 deletions src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,29 @@ class ExistingProvider(schema.Base):
}


provider_enum_model_map = {
schema.ProviderEnum.local: LocalProvider,
schema.ProviderEnum.existing: ExistingProvider,
schema.ProviderEnum.gcp: GoogleCloudPlatformProvider,
schema.ProviderEnum.aws: AmazonWebServicesProvider,
schema.ProviderEnum.azure: AzureProvider,
schema.ProviderEnum.do: DigitalOceanProvider,
}

provider_enum_name_map: Dict[schema.ProviderEnum, str] = {
schema.ProviderEnum.local: "local",
schema.ProviderEnum.existing: "existing",
schema.ProviderEnum.gcp: "google_cloud_platform",
schema.ProviderEnum.aws: "amazon_web_services",
schema.ProviderEnum.azure: "azure",
schema.ProviderEnum.do: "digital_ocean",
}

provider_name_abbreviation_map: Dict[str, str] = {
value: key.value for key, value in provider_enum_name_map.items()
}


class InputSchema(schema.Base):
local: typing.Optional[LocalProvider]
existing: typing.Optional[ExistingProvider]
Expand All @@ -512,54 +535,36 @@ class InputSchema(schema.Base):
azure: typing.Optional[AzureProvider]
digital_ocean: typing.Optional[DigitalOceanProvider]

@pydantic.root_validator
@pydantic.root_validator(pre=True)
def check_provider(cls, values):
if (
values["provider"] == schema.ProviderEnum.local
and values.get("local") is None
):
values["local"] = LocalProvider()
elif (
values["provider"] == schema.ProviderEnum.existing
and values.get("existing") is None
):
values["existing"] = ExistingProvider()
elif (
values["provider"] == schema.ProviderEnum.gcp
and values.get("google_cloud_platform") is None
):
values["google_cloud_platform"] = GoogleCloudPlatformProvider()
elif (
values["provider"] == schema.ProviderEnum.aws
and values.get("amazon_web_services") is None
):
values["amazon_web_services"] = AmazonWebServicesProvider()
elif (
values["provider"] == schema.ProviderEnum.azure
and values.get("azure") is None
):
values["azure"] = AzureProvider()
elif (
values["provider"] == schema.ProviderEnum.do
and values.get("digital_ocean") is None
):
values["digital_ocean"] = DigitalOceanProvider()

if (
sum(
(_ in values and values[_] is not None)
for _ in {
"local",
"existing",
"google_cloud_platform",
"amazon_web_services",
"azure",
"digital_ocean",
}
)
!= 1
):
raise ValueError("multiple providers set or wrong provider fields set")
if "provider" in values:
provider: str = values["provider"]
if hasattr(schema.ProviderEnum, provider):
# TODO: all cloud providers has required fields, but local and existing don't.
# And there is no way to initialize a model without user input here.
# We preserve the original behavior here, but we should find a better way to do this.
if provider in ["local", "existing"]:
values[provider] = provider_enum_model_map[provider]()
else:
# if the provider field is invalid, it won't be set when this validator is called
# so we need to check for it explicitly here, and set the `pre` to True
# TODO: this is a workaround, check if there is a better way to do this in Pydantic v2
raise ValueError(
f"'{provider}' is not a valid enumeration member; permitted: local, existing, do, aws, gcp, azure"
)
else:
setted_providers = [
provider
for provider in provider_name_abbreviation_map.keys()
if provider in values
]
num_providers = len(setted_providers)
if num_providers > 1:
raise ValueError(f"Multiple providers set: {setted_providers}")
elif num_providers == 1:
values["provider"] = provider_name_abbreviation_map[setted_providers[0]]
elif num_providers == 0:
values["provider"] = schema.ProviderEnum.local.value
return values


Expand Down
5 changes: 5 additions & 0 deletions tests/tests_unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,8 @@ def nebari_render(nebari_config, nebari_stages, tmp_path):
write_configuration(config_filename, nebari_config)
render_template(tmp_path, nebari_config, nebari_stages)
return tmp_path, config_filename


@pytest.fixture
def config_schema():
return nebari_plugin_manager.config_schema
25 changes: 0 additions & 25 deletions tests/tests_unit/test_cli_upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,31 +233,6 @@ def test_cli_upgrade_fail_on_missing_file():
)


def test_cli_upgrade_fail_invalid_file():
with tempfile.TemporaryDirectory() as tmp:
tmp_file = Path(tmp).resolve() / "nebari-config.yaml"
assert tmp_file.exists() is False

nebari_config = yaml.safe_load(
"""
project_name: test
provider: fake
"""
)

with open(tmp_file.resolve(), "w") as f:
yaml.dump(nebari_config, f)

assert tmp_file.exists() is True
app = create_cli()

result = runner.invoke(app, ["upgrade", "--config", tmp_file.resolve()])

assert 1 == result.exit_code
assert result.exception
assert "provider" in str(result.exception)


def test_cli_upgrade_fail_on_downgrade():
start_version = "9999.9.9" # way in the future
end_version = _nebari.upgrade.__version__
Expand Down
89 changes: 89 additions & 0 deletions tests/tests_unit/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from contextlib import nullcontext

import pytest
from pydantic.error_wrappers import ValidationError

from nebari import schema
from nebari.plugins import nebari_plugin_manager

Expand Down Expand Up @@ -48,3 +53,87 @@ def test_render_schema(nebari_config):
assert isinstance(nebari_config, schema.Main)
assert nebari_config.project_name == f"pytest{nebari_config.provider.value}"
assert nebari_config.namespace == "dev"


@pytest.mark.parametrize(
"provider, exception",
[
(
"fake",
pytest.raises(
ValueError,
match="'fake' is not a valid enumeration member; permitted: local, existing, do, aws, gcp, azure",
),
),
("aws", nullcontext()),
("gcp", nullcontext()),
("do", nullcontext()),
("azure", nullcontext()),
("existing", nullcontext()),
("local", nullcontext()),
],
)
def test_provider_validation(config_schema, provider, exception):
config_dict = {
"project_name": "test",
"provider": f"{provider}",
}
with exception:
config = config_schema(**config_dict)
assert config.provider == provider


@pytest.mark.parametrize(
"provider, full_name, default_fields",
[
("local", "local", {}),
("existing", "existing", {}),
(
"aws",
"amazon_web_services",
{"region": "us-east-1", "kubernetes_version": "1.18"},
),
(
"gcp",
"google_cloud_platform",
{
"region": "us-east1",
"project": "test-project",
"kubernetes_version": "1.18",
},
),
(
"do",
"digital_ocean",
{"region": "nyc3", "kubernetes_version": "1.19.2-do.3"},
),
(
"azure",
"azure",
{
"region": "eastus",
"kubernetes_version": "1.18",
"storage_account_postfix": "test",
},
),
],
)
def test_no_provider(config_schema, provider, full_name, default_fields):
config_dict = {
"project_name": "test",
f"{full_name}": default_fields,
}
config = config_schema(**config_dict)
assert config.provider == provider
assert full_name in config.dict()


def test_multiple_providers(config_schema):
config_dict = {
"project_name": "test",
"local": {},
"existing": {},
}
msg = r"Multiple providers set: \['local', 'existing'\]"
with pytest.raises(ValidationError, match=msg):
config_schema(**config_dict)

0 comments on commit 4db8400

Please sign in to comment.