Skip to content

Commit

Permalink
fix up Datum, DFTFunctional, CPUInfo, and serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
loriab committed Sep 10, 2024
1 parent 842eace commit 8e5e9f5
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 31 deletions.
2 changes: 1 addition & 1 deletion qcelemental/datum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from decimal import Decimal
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

import numpy as np
from pydantic import (
Expand Down
4 changes: 3 additions & 1 deletion qcelemental/info/cpu_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from pydantic import BeforeValidator, Field
from typing_extensions import Annotated

from ..models import ProtoModel
from ..models.v2 import ProtoModel

# ProcessorInfo models don't become parts of QCSchema models afaik, so pure pydantic v2 API


class VendorEnum(str, Enum):
Expand Down
4 changes: 3 additions & 1 deletion qcelemental/info/dft_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from pydantic import Field
from typing_extensions import Annotated

from ..models import ProtoModel
from ..models.v2 import ProtoModel

# DFTFunctional models don't become parts of QCSchema models afaik, so pure pydantic v2 API


class DFTFunctionalInfo(ProtoModel):
Expand Down
20 changes: 11 additions & 9 deletions qcelemental/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@
from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Union

import numpy as np

try:
from pydantic.v1 import BaseModel
except ImportError: # Will also trap ModuleNotFoundError
from pydantic import BaseModel
import pydantic

if TYPE_CHECKING:
from qcelemental.models import ProtoModel # TODO: recheck if .v1 needed
Expand Down Expand Up @@ -313,10 +309,16 @@ def _compare_recursive(expected, computed, atol, rtol, _prefix=False, equal_phas
prefix = name + "."

# Initial conversions if required
if isinstance(expected, BaseModel):
if isinstance(expected, pydantic.BaseModel):
expected = expected.model_dump()

if isinstance(computed, pydantic.BaseModel):
computed = computed.model_dump()

if isinstance(expected, pydantic.v1.BaseModel):
expected = expected.dict()

if isinstance(computed, BaseModel):
if isinstance(computed, pydantic.v1.BaseModel):
computed = computed.dict()

if isinstance(expected, (str, int, bool, complex)):
Expand Down Expand Up @@ -381,8 +383,8 @@ def _compare_recursive(expected, computed, atol, rtol, _prefix=False, equal_phas


def compare_recursive(
expected: Union[Dict, BaseModel, "ProtoModel"], # type: ignore
computed: Union[Dict, BaseModel, "ProtoModel"], # type: ignore
expected: Union[Dict, pydantic.BaseModel, pydantic.v1.BaseModel, "ProtoModel"], # type: ignore
computed: Union[Dict, pydantic.BaseModel, pydantic.v1.BaseModel, "ProtoModel"], # type: ignore
label: str = None,
*,
atol: float = 1.0e-6,
Expand Down
10 changes: 3 additions & 7 deletions qcelemental/tests/test_datum.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from decimal import Decimal

import numpy as np

try:
import pydantic.v1 as pydantic
except ImportError: # Will also trap ModuleNotFoundError
import pydantic
import pydantic
import pytest

import qcelemental as qcel
Expand Down Expand Up @@ -46,10 +42,10 @@ def test_creation_nonnum(dataset):


def test_creation_error():
with pytest.raises(pydantic.ValidationError):
with pytest.raises(pydantic.ValidationError) as e:
qcel.Datum("ze lbl", "ze unit", "ze data")

# assert 'Datum data should be float' in str(e)
assert "Datum data should be float" in str(e.value)


@pytest.mark.parametrize(
Expand Down
10 changes: 6 additions & 4 deletions qcelemental/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import qcelemental as qcel
from qcelemental.testing import compare_recursive, compare_values

from .addons import serialize_extensions
from .addons import schema_versions, serialize_extensions


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -313,7 +313,7 @@ def test_serialization(obj, encoding):


@pytest.fixture
def atomic_result():
def atomic_result_data():
"""Mock AtomicResult output which can be tested against for complex serialization methods"""

data = {
Expand Down Expand Up @@ -385,10 +385,12 @@ def atomic_result():
"success": True,
"error": None,
}
return data

yield qcel.models.results.AtomicResult(**data)

def test_json_dumps(atomic_result_data, schema_versions):
AtomicResult = schema_versions.AtomicResult

def test_json_dumps(atomic_result):
atomic_result = AtomicResult(**atomic_result_data)
ret = qcel.util.json_dumps(atomic_result)
assert isinstance(ret, str)
5 changes: 1 addition & 4 deletions qcelemental/util/autodocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,7 @@ def is_pydantic(test_object):

def parse_type_str(prop) -> str:
# Import here to minimize issues
try:
from pydantic.v1 import fields
except ImportError: # Will also trap ModuleNotFoundError
from pydantic import fields
from pydantic.v1 import fields

typing_map = {
fields.SHAPE_TUPLE: "Tuple",
Expand Down
25 changes: 21 additions & 4 deletions qcelemental/util/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import pydantic
from pydantic.v1.json import pydantic_encoder
from pydantic_core import PydanticSerializationError, to_jsonable_python

from .importing import which_import
Expand Down Expand Up @@ -41,7 +42,14 @@ def msgpackext_encode(obj: Any) -> Any:
try:
return to_jsonable_python(obj)
except ValueError:
pass
# above to_jsonable_python is for Pydantic v2 API models
# below pydatnic_encoder is for Pydantic v1 API models
# tentative whether handling both together will work beyond tests
# or if separate files called by models.v1 and .v2 will be req'd
try:
return pydantic_encoder(obj)
except TypeError:
pass

if isinstance(obj, np.ndarray):
if obj.shape:
Expand Down Expand Up @@ -127,7 +135,10 @@ def default(self, obj: Any) -> Any:
try:
return to_jsonable_python(obj)
except ValueError:
pass
try:
return pydantic_encoder(obj)
except TypeError:
pass

if isinstance(obj, np.ndarray):
if obj.shape:
Expand Down Expand Up @@ -199,7 +210,10 @@ def default(self, obj: Any) -> Any:
try:
return to_jsonable_python(obj)
except ValueError:
pass
try:
return pydantic_encoder(obj)
except TypeError:
pass

# See if pydantic model can be just serialized if the above couldn't be dumped
if isinstance(obj, pydantic.BaseModel):
Expand Down Expand Up @@ -273,7 +287,10 @@ def msgpack_encode(obj: Any) -> Any:
try:
return to_jsonable_python(obj)
except ValueError:
pass
try:
return pydantic_encoder(obj)
except TypeError:
pass

if isinstance(obj, np.ndarray):
if obj.shape:
Expand Down

0 comments on commit 8e5e9f5

Please sign in to comment.