Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TST: Add test to ensure all optional fields have defaults #834

Merged
merged 1 commit into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@

import fmu.dataio as dio
from fmu.config import utilities as ut
from fmu.dataio._model import fields, global_configuration
from fmu.dataio._model import Root, fields, global_configuration
from fmu.dataio.dataio import ExportData, read_metadata
from fmu.dataio.providers._fmu import FmuEnv

from .utils import _metadata_examples
from .utils import _get_nested_pydantic_models, _metadata_examples

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -699,3 +699,9 @@ def fixture_drogon_volumes(rootpath):
rootpath / "tests/data/drogon/tabular/geogrid--vol.csv",
)
)


@pytest.fixture(scope="session")
def pydantic_models_from_root():
"""Return all nested pydantic models from Root and downwards"""
return _get_nested_pydantic_models(Root)
17 changes: 17 additions & 0 deletions tests/test_schema/test_pydantic_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
from copy import deepcopy
from typing import get_args

import pytest
from pydantic import ValidationError
Expand Down Expand Up @@ -32,6 +33,22 @@ def test_validate(file, example):
Root.model_validate(example)


def test_for_optional_fields_without_default(pydantic_models_from_root):
"""Test that all optional fields have a default value"""
optionals_without_default = []
for model in pydantic_models_from_root:
for field_name, field_info in model.model_fields.items():
if (
type(None) in get_args(field_info.annotation)
and field_info.is_required()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did not know about is_required(), neat

):
optionals_without_default.append(
f"{model.__module__}.{model.__name__}.{field_name}"
)

assert not optionals_without_default


def test_schema_file_block(metadata_examples):
"""Test variations on the file block."""

Expand Down
30 changes: 29 additions & 1 deletion tests/test_units/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
import os
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Dict, List, Optional, Union

import numpy as np
import pytest
from xtgeo import Grid, Polygons, RegularSurface

from fmu.dataio import _utils as utils
from fmu.dataio._model import fields

from ..utils import inside_rms
from ..utils import _get_pydantic_models_from_annotation, inside_rms


@pytest.mark.parametrize(
Expand Down Expand Up @@ -148,3 +150,29 @@ def test_read_named_envvar():

os.environ["MYTESTENV"] = "mytestvalue"
assert utils.read_named_envvar("MYTESTENV") == "mytestvalue"


def test_get_pydantic_models_from_annotation():
annotation = Union[List[fields.Access], fields.File]
assert _get_pydantic_models_from_annotation(annotation) == [
fields.Access,
fields.File,
]
annotation = Optional[Union[Dict[str, fields.Access], List[fields.File]]]
assert _get_pydantic_models_from_annotation(annotation) == [
fields.Access,
fields.File,
]

annotation = List[Union[fields.Access, fields.File, fields.Tracklog]]
assert _get_pydantic_models_from_annotation(annotation) == [
fields.Access,
fields.File,
fields.Tracklog,
]

annotation = List[List[List[List[fields.Tracklog]]]]
assert _get_pydantic_models_from_annotation(annotation) == [fields.Tracklog]

annotation = Union[str, List[int], Dict[str, int]]
assert not _get_pydantic_models_from_annotation(annotation)
29 changes: 29 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

import datetime
from functools import wraps
from pathlib import Path
from typing import Any, get_args

import pytest
import yaml
from pydantic import BaseModel


def inside_rms(func):
Expand Down Expand Up @@ -43,3 +47,28 @@ def _metadata_examples():
path.name: _isoformat_all_datetimes(_parse_yaml(path))
for path in Path(".").absolute().glob("schema/definitions/0.8.0/examples/*.yml")
}


def _get_pydantic_models_from_annotation(annotation: Any) -> list[Any]:
"""
Get a list of all pydantic models defined inside an annotation.
Example: Union[Model1, list[dict[str, Model2]]] returns [Model1, Model2]
"""
if isinstance(annotation, type(BaseModel)):
return [annotation]

annotations = []
for ann in get_args(annotation):
annotations += _get_pydantic_models_from_annotation(ann)
return annotations


def _get_nested_pydantic_models(model: type[BaseModel]) -> set[type[BaseModel]]:
"""Get a set of all nested pydantic models from a pydantic model"""
models = {model}

for field_info in model.model_fields.values():
for model in _get_pydantic_models_from_annotation(field_info.annotation):
if model not in models:
models.update(_get_nested_pydantic_models(model))
return models