Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: fix validation error related to provider #2054 #2056

Merged
merged 15 commits into from
Oct 16, 2023
Merged
11 changes: 10 additions & 1 deletion src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,8 +512,17 @@ class InputSchema(schema.Base):
azure: typing.Optional[AzureProvider]
digital_ocean: typing.Optional[DigitalOceanProvider]

@pydantic.root_validator
@pydantic.root_validator(pre=True)
def check_provider(cls, values):
# if the provider field is invalid, it won't be set when this validator is called
# so we need to check for it explicitly here, and set the `pre` to True
# TODO: this is a workaround, check if there is a better way to do this in Pydantic v2
iameskild marked this conversation as resolved.
Show resolved Hide resolved
# TODO: all the cloud providers are initialized without required fields, so they are not working here
if not hasattr(schema.ProviderEnum, values["provider"]):
provider = values["provider"]
msg = f"'{provider}' is not a valid enumeration member; permitted: {', '.join(schema.ProviderEnum.__members__.keys())}"
raise ValueError(msg)

if (
values["provider"] == schema.ProviderEnum.local
and values.get("local") is None
Expand Down
5 changes: 5 additions & 0 deletions tests/tests_unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,8 @@ def nebari_render(nebari_config, nebari_stages, tmp_path):
write_configuration(config_filename, nebari_config)
render_template(tmp_path, nebari_config, nebari_stages)
return tmp_path, config_filename


@pytest.fixture
def config_schema():
return nebari_plugin_manager.config_schema
25 changes: 0 additions & 25 deletions tests/tests_unit/test_cli_upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,31 +233,6 @@ def test_cli_upgrade_fail_on_missing_file():
)


def test_cli_upgrade_fail_invalid_file():
with tempfile.TemporaryDirectory() as tmp:
tmp_file = Path(tmp).resolve() / "nebari-config.yaml"
assert tmp_file.exists() is False

nebari_config = yaml.safe_load(
"""
project_name: test
provider: fake
"""
)

with open(tmp_file.resolve(), "w") as f:
yaml.dump(nebari_config, f)

assert tmp_file.exists() is True
app = create_cli()

result = runner.invoke(app, ["upgrade", "--config", tmp_file.resolve()])

assert 1 == result.exit_code
assert result.exception
assert "provider" in str(result.exception)

Comment on lines -258 to -259
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this test here because it's not asserting the expected behavior, also the exception was raised before any upgrade steps, so it's not related to upgrade.


def test_cli_upgrade_fail_on_downgrade():
start_version = "9999.9.9" # way in the future
end_version = _nebari.upgrade.__version__
Expand Down
35 changes: 35 additions & 0 deletions tests/tests_unit/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from contextlib import nullcontext

import pytest
from pydantic.error_wrappers import ValidationError

from nebari import schema
from nebari.plugins import nebari_plugin_manager

Expand Down Expand Up @@ -48,3 +53,33 @@ def test_render_schema(nebari_config):
assert isinstance(nebari_config, schema.Main)
assert nebari_config.project_name == f"pytest{nebari_config.provider.value}"
assert nebari_config.namespace == "dev"


@pytest.mark.parametrize(
"provider, exception",
[
(
"fake",
pytest.raises(
ValueError,
match="'fake' is not a valid enumeration member; permitted: local, existing, do, aws, gcp, azure",
),
),
("aws", pytest.raises(ValidationError)),
("gcp", pytest.raises(ValidationError)),
("do", pytest.raises(ValidationError)),
("azure", pytest.raises(ValidationError)),
fangchenli marked this conversation as resolved.
Show resolved Hide resolved
("existing", nullcontext()),
("local", nullcontext()),
],
)
def test_provider_validation(config_schema, provider, exception):
# TODO: for cloud providers, we are currently not testing the expected behaviours,
# there should be no validation error for aws, gcp, do, azure.
config_dict = {
"project_name": "test",
"provider": f"{provider}",
}
with exception:
config = config_schema(**config_dict)
assert config.provider == provider
Loading