Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: fix validation error related to provider #2054 #2056

Merged
merged 15 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/_nebari/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def render_config(
ssl_cert_email: str = None,
):
config = {
"provider": cloud_provider,
"provider": cloud_provider.value,
"namespace": namespace,
"nebari_version": __version__,
}
Expand Down
99 changes: 52 additions & 47 deletions src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,29 @@ class ExistingProvider(schema.Base):
}


provider_enum_model_map = {
schema.ProviderEnum.local: LocalProvider,
schema.ProviderEnum.existing: ExistingProvider,
schema.ProviderEnum.gcp: GoogleCloudPlatformProvider,
schema.ProviderEnum.aws: AmazonWebServicesProvider,
schema.ProviderEnum.azure: AzureProvider,
schema.ProviderEnum.do: DigitalOceanProvider,
}

provider_enum_name_map: Dict[schema.ProviderEnum, str] = {
schema.ProviderEnum.local: "local",
schema.ProviderEnum.existing: "existing",
schema.ProviderEnum.gcp: "google_cloud_platform",
schema.ProviderEnum.aws: "amazon_web_services",
schema.ProviderEnum.azure: "azure",
schema.ProviderEnum.do: "digital_ocean",
}

provider_name_abbreviation_map: Dict[str, str] = {
value: key.value for key, value in provider_enum_name_map.items()
}


class InputSchema(schema.Base):
local: typing.Optional[LocalProvider]
existing: typing.Optional[ExistingProvider]
Expand All @@ -512,54 +535,36 @@ class InputSchema(schema.Base):
azure: typing.Optional[AzureProvider]
digital_ocean: typing.Optional[DigitalOceanProvider]

@pydantic.root_validator
@pydantic.root_validator(pre=True)
def check_provider(cls, values):
if (
values["provider"] == schema.ProviderEnum.local
and values.get("local") is None
):
values["local"] = LocalProvider()
elif (
values["provider"] == schema.ProviderEnum.existing
and values.get("existing") is None
):
values["existing"] = ExistingProvider()
elif (
values["provider"] == schema.ProviderEnum.gcp
and values.get("google_cloud_platform") is None
):
values["google_cloud_platform"] = GoogleCloudPlatformProvider()
elif (
values["provider"] == schema.ProviderEnum.aws
and values.get("amazon_web_services") is None
):
values["amazon_web_services"] = AmazonWebServicesProvider()
elif (
values["provider"] == schema.ProviderEnum.azure
and values.get("azure") is None
):
values["azure"] = AzureProvider()
elif (
values["provider"] == schema.ProviderEnum.do
and values.get("digital_ocean") is None
):
values["digital_ocean"] = DigitalOceanProvider()

if (
sum(
(_ in values and values[_] is not None)
for _ in {
"local",
"existing",
"google_cloud_platform",
"amazon_web_services",
"azure",
"digital_ocean",
}
)
!= 1
):
raise ValueError("multiple providers set or wrong provider fields set")
if "provider" in values:
provider: str = values["provider"]
if hasattr(schema.ProviderEnum, provider):
# TODO: all cloud providers has required fields, but local and existing don't.
# And there is no way to initialize a model without user input here.
# We preserve the original behavior here, but we should find a better way to do this.
if provider in ["local", "existing"]:
values[provider] = provider_enum_model_map[provider]()
else:
# if the provider field is invalid, it won't be set when this validator is called
# so we need to check for it explicitly here, and set the `pre` to True
# TODO: this is a workaround, check if there is a better way to do this in Pydantic v2
fangchenli marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
f"'{provider}' is not a valid enumeration member; permitted: local, existing, do, aws, gcp, azure"
)
else:
setted_providers = [
provider
for provider in provider_name_abbreviation_map.keys()
if provider in values
]
num_providers = len(setted_providers)
if num_providers > 1:
raise ValueError(f"Multiple providers set: {setted_providers}")
elif num_providers == 1:
values["provider"] = provider_name_abbreviation_map[setted_providers[0]]
elif num_providers == 0:
values["provider"] = schema.ProviderEnum.local.value
return values


Expand Down
5 changes: 5 additions & 0 deletions tests/tests_unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,8 @@ def nebari_render(nebari_config, nebari_stages, tmp_path):
write_configuration(config_filename, nebari_config)
render_template(tmp_path, nebari_config, nebari_stages)
return tmp_path, config_filename


@pytest.fixture
def config_schema():
return nebari_plugin_manager.config_schema
25 changes: 0 additions & 25 deletions tests/tests_unit/test_cli_upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,31 +233,6 @@ def test_cli_upgrade_fail_on_missing_file():
)


def test_cli_upgrade_fail_invalid_file():
with tempfile.TemporaryDirectory() as tmp:
tmp_file = Path(tmp).resolve() / "nebari-config.yaml"
assert tmp_file.exists() is False

nebari_config = yaml.safe_load(
"""
project_name: test
provider: fake
"""
)

with open(tmp_file.resolve(), "w") as f:
yaml.dump(nebari_config, f)

assert tmp_file.exists() is True
app = create_cli()

result = runner.invoke(app, ["upgrade", "--config", tmp_file.resolve()])

assert 1 == result.exit_code
assert result.exception
assert "provider" in str(result.exception)

Comment on lines -258 to -259
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this test here because it's not asserting the expected behavior, also the exception was raised before any upgrade steps, so it's not related to upgrade.


def test_cli_upgrade_fail_on_downgrade():
start_version = "9999.9.9" # way in the future
end_version = _nebari.upgrade.__version__
Expand Down
89 changes: 89 additions & 0 deletions tests/tests_unit/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from contextlib import nullcontext

import pytest
from pydantic.error_wrappers import ValidationError

from nebari import schema
from nebari.plugins import nebari_plugin_manager

Expand Down Expand Up @@ -48,3 +53,87 @@ def test_render_schema(nebari_config):
assert isinstance(nebari_config, schema.Main)
assert nebari_config.project_name == f"pytest{nebari_config.provider.value}"
assert nebari_config.namespace == "dev"


@pytest.mark.parametrize(
"provider, exception",
[
(
"fake",
pytest.raises(
ValueError,
match="'fake' is not a valid enumeration member; permitted: local, existing, do, aws, gcp, azure",
),
),
("aws", nullcontext()),
("gcp", nullcontext()),
("do", nullcontext()),
("azure", nullcontext()),
("existing", nullcontext()),
("local", nullcontext()),
],
)
def test_provider_validation(config_schema, provider, exception):
config_dict = {
"project_name": "test",
"provider": f"{provider}",
}
with exception:
config = config_schema(**config_dict)
assert config.provider == provider


@pytest.mark.parametrize(
"provider, full_name, default_fields",
[
("local", "local", {}),
("existing", "existing", {}),
(
"aws",
"amazon_web_services",
{"region": "us-east-1", "kubernetes_version": "1.18"},
),
(
"gcp",
"google_cloud_platform",
{
"region": "us-east1",
"project": "test-project",
"kubernetes_version": "1.18",
},
),
(
"do",
"digital_ocean",
{"region": "nyc3", "kubernetes_version": "1.19.2-do.3"},
),
(
"azure",
"azure",
{
"region": "eastus",
"kubernetes_version": "1.18",
"storage_account_postfix": "test",
},
),
],
)
def test_no_provider(config_schema, provider, full_name, default_fields):
config_dict = {
"project_name": "test",
f"{full_name}": default_fields,
}
config = config_schema(**config_dict)
assert config.provider == provider
assert full_name in config.dict()


def test_multiple_providers(config_schema):
config_dict = {
"project_name": "test",
"local": {},
"existing": {},
}
msg = r"Multiple providers set: \['local', 'existing'\]"
with pytest.raises(ValidationError, match=msg):
config_schema(**config_dict)