Skip to content

Commit

Permalink
use fields and preview models from common
Browse files Browse the repository at this point in the history
  • Loading branch information
cutoffthetop committed Dec 6, 2024
1 parent ed69fcc commit b7a5dcc
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 212 deletions.
173 changes: 1 addition & 172 deletions mex/backend/fields.py
Original file line number Diff line number Diff line change
@@ -1,124 +1,4 @@
from collections.abc import Callable, Generator, Mapping
from types import NoneType, UnionType
from typing import Annotated, Any, Union, get_args, get_origin

from mex.common.models import (
ADDITIVE_MODEL_CLASSES_BY_NAME,
EXTRACTED_MODEL_CLASSES_BY_NAME,
MERGED_MODEL_CLASSES_BY_NAME,
PREVENTIVE_MODEL_CLASSES_BY_NAME,
SUBTRACTIVE_MODEL_CLASSES_BY_NAME,
BaseModel,
GenericFieldInfo,
)
from mex.common.types import MERGED_IDENTIFIER_CLASSES, Link, LiteralStringType, Text


def _get_inner_types(annotation: Any) -> Generator[type, None, None]:
"""Yield all inner types from unions, lists and type annotations (except NoneType).
Args:
annotation: A valid python type annotation
Returns:
A generator for all (non-NoneType) types found in the annotation
"""
if get_origin(annotation) == Annotated:
yield from _get_inner_types(get_args(annotation)[0])
elif get_origin(annotation) in (Union, UnionType, list):
for arg in get_args(annotation):
yield from _get_inner_types(arg)
elif annotation not in (None, NoneType):
yield annotation


def _contains_only_types(field: GenericFieldInfo, *types: type) -> bool:
"""Return whether a `field` is annotated as one of the given `types`.
Unions, lists and type annotations are checked for their inner types and only the
non-`NoneType` types are considered for the type-check.
Args:
field: A `GenericFieldInfo` instance
types: Types to look for in the field's annotation
Returns:
Whether the field contains any of the given types
"""
if inner_types := list(_get_inner_types(field.annotation)):
return all(inner_type in types for inner_type in inner_types)
return False


def _group_fields_by_class_name(
model_classes_by_name: Mapping[str, type[BaseModel]],
predicate: Callable[[GenericFieldInfo], bool],
) -> dict[str, list[str]]:
"""Group the field names by model class and filter them by the given predicate.
Args:
model_classes_by_name: Map from class names to model classes
predicate: Function to filter the fields of the classes by
Returns:
Dictionary mapping class names to a list of field names filtered by `predicate`
"""
return {
name: sorted(
{
field_name
for field_name, field_info in cls.get_all_fields().items()
if predicate(field_info)
}
)
for name, cls in model_classes_by_name.items()
}


# all models classes
ALL_MODEL_CLASSES_BY_NAME = {
**ADDITIVE_MODEL_CLASSES_BY_NAME,
**EXTRACTED_MODEL_CLASSES_BY_NAME,
**MERGED_MODEL_CLASSES_BY_NAME,
**PREVENTIVE_MODEL_CLASSES_BY_NAME,
**SUBTRACTIVE_MODEL_CLASSES_BY_NAME,
}

# fields that are immutable and can only be set once
FROZEN_FIELDS_BY_CLASS_NAME = _group_fields_by_class_name(
ALL_MODEL_CLASSES_BY_NAME,
lambda field_info: field_info.frozen is True,
)

# static fields that are set once on class-level to a literal type
LITERAL_FIELDS_BY_CLASS_NAME = _group_fields_by_class_name(
ALL_MODEL_CLASSES_BY_NAME,
lambda field_info: isinstance(field_info.annotation, LiteralStringType),
)

# fields typed as merged identifiers containing references to merged items
REFERENCE_FIELDS_BY_CLASS_NAME = _group_fields_by_class_name(
ALL_MODEL_CLASSES_BY_NAME,
lambda field_info: _contains_only_types(field_info, *MERGED_IDENTIFIER_CLASSES),
)

# nested fields that contain `Text` objects
TEXT_FIELDS_BY_CLASS_NAME = _group_fields_by_class_name(
ALL_MODEL_CLASSES_BY_NAME,
lambda field_info: _contains_only_types(field_info, Text),
)

# nested fields that contain `Link` objects
LINK_FIELDS_BY_CLASS_NAME = _group_fields_by_class_name(
ALL_MODEL_CLASSES_BY_NAME,
lambda field_info: _contains_only_types(field_info, Link),
)

# fields annotated as `str` type
STRING_FIELDS_BY_CLASS_NAME = _group_fields_by_class_name(
ALL_MODEL_CLASSES_BY_NAME,
lambda field_info: _contains_only_types(field_info, str),
)
from mex.common.fields import STRING_FIELDS_BY_CLASS_NAME

# fields that should be indexed as searchable fields
SEARCHABLE_FIELDS = sorted(
Expand All @@ -133,54 +13,3 @@ def _group_fields_by_class_name(
SEARCHABLE_CLASSES = sorted(
{name for name, field_names in STRING_FIELDS_BY_CLASS_NAME.items() if field_names}
)

# fields with changeable values that are not nested objects or merged item references
MUTABLE_FIELDS_BY_CLASS_NAME = {
name: sorted(
{
field_name
for field_name in cls.get_all_fields()
if field_name
not in (
*FROZEN_FIELDS_BY_CLASS_NAME[name],
*REFERENCE_FIELDS_BY_CLASS_NAME[name],
*TEXT_FIELDS_BY_CLASS_NAME[name],
*LINK_FIELDS_BY_CLASS_NAME[name],
)
}
)
for name, cls in ALL_MODEL_CLASSES_BY_NAME.items()
}

# fields with mergeable values that are neither literal nor frozen
MERGEABLE_FIELDS_BY_CLASS_NAME = {
name: sorted(
{
field_name
for field_name in cls.model_fields
if field_name
not in (
*FROZEN_FIELDS_BY_CLASS_NAME[name],
*LITERAL_FIELDS_BY_CLASS_NAME[name],
)
}
)
for name, cls in MERGED_MODEL_CLASSES_BY_NAME.items()
}

# fields with values that should be set once but are neither literal nor references
FINAL_FIELDS_BY_CLASS_NAME = {
name: sorted(
{
field_name
for field_name in cls.get_all_fields()
if field_name in FROZEN_FIELDS_BY_CLASS_NAME[name]
and field_name
not in (
*LITERAL_FIELDS_BY_CLASS_NAME[name],
*REFERENCE_FIELDS_BY_CLASS_NAME[name],
)
}
)
for name, cls in ALL_MODEL_CLASSES_BY_NAME.items()
}
17 changes: 8 additions & 9 deletions mex/backend/graph/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,20 @@
from neo4j import Driver, GraphDatabase, NotificationDisabledCategory
from pydantic import Field

from mex.backend.fields import (
FINAL_FIELDS_BY_CLASS_NAME,
LINK_FIELDS_BY_CLASS_NAME,
MUTABLE_FIELDS_BY_CLASS_NAME,
REFERENCE_FIELDS_BY_CLASS_NAME,
SEARCHABLE_CLASSES,
SEARCHABLE_FIELDS,
TEXT_FIELDS_BY_CLASS_NAME,
)
from mex.backend.fields import SEARCHABLE_CLASSES, SEARCHABLE_FIELDS
from mex.backend.graph.models import Result
from mex.backend.graph.query import QueryBuilder
from mex.backend.graph.transform import expand_references_in_search_result
from mex.backend.settings import BackendSettings
from mex.common.connector import BaseConnector
from mex.common.exceptions import MExError
from mex.common.fields import (
FINAL_FIELDS_BY_CLASS_NAME,
LINK_FIELDS_BY_CLASS_NAME,
MUTABLE_FIELDS_BY_CLASS_NAME,
REFERENCE_FIELDS_BY_CLASS_NAME,
TEXT_FIELDS_BY_CLASS_NAME,
)
from mex.common.logging import logger
from mex.common.models import (
EXTRACTED_MODEL_CLASSES_BY_NAME,
Expand Down
27 changes: 14 additions & 13 deletions mex/backend/merged/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,24 @@

from pydantic import Field, TypeAdapter, ValidationError

from mex.backend.fields import MERGEABLE_FIELDS_BY_CLASS_NAME
from mex.backend.graph.connector import GraphConnector
from mex.backend.graph.exceptions import InconsistentGraphError
from mex.backend.merged.models import MergedItemSearch, PreviewItemSearch
from mex.backend.rules.helpers import transform_raw_rules_to_rule_set_response
from mex.backend.utils import extend_list_in_dict, prune_list_in_dict
from mex.common.exceptions import MExError
from mex.common.fields import MERGEABLE_FIELDS_BY_CLASS_NAME
from mex.common.logging import logger
from mex.common.models import (
ADDITIVE_MODEL_CLASSES_BY_NAME,
EXTRACTED_MODEL_CLASSES_BY_NAME,
MERGED_MODEL_CLASSES_BY_NAME,
PREVIEW_MODEL_CLASSES_BY_NAME,
RULE_MODEL_CLASSES_BY_NAME,
AnyAdditiveModel,
AnyExtractedModel,
AnyMergedModel,
AnyPreventiveModel,
AnyPreviewModel,
AnyRuleSetRequest,
AnyRuleSetResponse,
AnySubtractiveModel,
Expand Down Expand Up @@ -98,7 +99,7 @@ def create_merged_item(
extracted_items: list[AnyExtractedModel],
rule_set: AnyRuleSetRequest | AnyRuleSetResponse | None,
validate_cardinality: Literal[False],
) -> AnyAdditiveModel: ...
) -> AnyPreviewModel: ...


@overload
Expand All @@ -115,7 +116,7 @@ def create_merged_item(
extracted_items: list[AnyExtractedModel],
rule_set: AnyRuleSetRequest | AnyRuleSetResponse | None,
validate_cardinality: Literal[True, False],
) -> AnyAdditiveModel | AnyMergedModel:
) -> AnyPreviewModel | AnyMergedModel:
"""Merge a list of extracted items with a set of rules.
Args:
Expand All @@ -130,17 +131,17 @@ def create_merged_item(
InconsistentGraphError: When the graph response cannot be parsed
Returns:
Instance of a merged item
Instance of a merged or preview item
"""
model_class_lookup: (
dict[str, type[AnyAdditiveModel]] | dict[str, type[AnyMergedModel]]
dict[str, type[AnyPreviewModel]] | dict[str, type[AnyMergedModel]]
)
if validate_cardinality:
model_prefix = "Merged"
model_class_lookup = MERGED_MODEL_CLASSES_BY_NAME
else:
model_prefix = "Additive"
model_class_lookup = ADDITIVE_MODEL_CLASSES_BY_NAME
model_prefix = "Preview"
model_class_lookup = PREVIEW_MODEL_CLASSES_BY_NAME

if rule_set:
entity_type = ensure_prefix(rule_set.stemType, model_prefix)
Expand Down Expand Up @@ -170,7 +171,7 @@ def create_merged_item(
def merge_search_result_item(
item: dict[str, Any],
validate_cardinality: Literal[False],
) -> AnyAdditiveModel: ...
) -> AnyPreviewModel: ...


@overload
Expand All @@ -183,7 +184,7 @@ def merge_search_result_item(
def merge_search_result_item(
item: dict[str, Any],
validate_cardinality: Literal[True, False],
) -> AnyAdditiveModel | AnyMergedModel:
) -> AnyPreviewModel | AnyMergedModel:
"""Merge a single search result into a merged item.
Args:
Expand All @@ -196,7 +197,7 @@ def merge_search_result_item(
InconsistentGraphError: When the graph response item has inconsistencies
Returns:
AnyMergedModel instance
Instance of a merged or preview item
"""
extracted_items = [
EXTRACTED_MODEL_ADAPTER.validate_python(component)
Expand Down Expand Up @@ -267,7 +268,7 @@ def search_merged_items_in_graph( # noqa: PLR0913
InconsistentGraphError: When the graph response has inconsistencies
Returns:
MergedItemSearch instance
Search response for preview or merged items
"""
graph = GraphConnector.get()
result = graph.fetch_merged_items(
Expand All @@ -278,7 +279,7 @@ def search_merged_items_in_graph( # noqa: PLR0913
limit=limit,
)
total: int = result["total"]
items: list[AnyMergedModel | AnyAdditiveModel] = []
items: list[AnyPreviewModel | AnyMergedModel] = []
for item in result["items"]:
try:
items.append(
Expand Down
4 changes: 2 additions & 2 deletions mex/backend/merged/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pydantic import Field

from mex.common.models import AnyAdditiveModel, AnyMergedModel, BaseModel
from mex.common.models import AnyMergedModel, AnyPreviewModel, BaseModel


class MergedItemSearch(BaseModel):
Expand All @@ -15,5 +15,5 @@ class MergedItemSearch(BaseModel):
class PreviewItemSearch(BaseModel):
"""Response body for the preview item search endpoint."""

items: list[Annotated[AnyAdditiveModel, Field(discriminator="entityType")]]
items: list[Annotated[AnyPreviewModel, Field(discriminator="entityType")]]
total: int
6 changes: 3 additions & 3 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies = [
"fastapi>=0.115,<1",
"httpx>=0.27,<1",
"jinja2>=3,<4",
"mex-common @ git+https://github.com/robert-koch-institut/mex-common.git@0.42.0",
"mex-common @ git+https://github.com/robert-koch-institut/mex-common.git@feature/mx-1649-add-preview-models",
"neo4j>=5,<6",
"pydantic>=2,<3",
"starlette>=0.41,<1",
Expand Down
Loading

0 comments on commit b7a5dcc

Please sign in to comment.