Skip to content

Commit

Permalink
refactor: use root_validators instead of init
Browse files Browse the repository at this point in the history
  • Loading branch information
dcfidalgo committed Jan 18, 2022
1 parent 2429d98 commit 1910688
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions src/rubrix/client/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, root_validator, validator

from rubrix._constants import MAX_KEYWORD_LENGTH
from rubrix.server.commons.helpers import limit_value_length
Expand Down Expand Up @@ -67,21 +67,32 @@ class BaseRecord(BaseModel):
def check_value_length(cls, metadata):
return _limit_metadata_values(metadata)

def __init__(self, *args, **kwargs):
"""Custom init to handle dynamic defaults"""
# noinspection PyArgumentList
super().__init__(*args, **kwargs)
self.status = self.status or (
"Default" if self.annotation is None else "Validated"
)
if self.annotation_agent is not None and self.annotation is None:
@root_validator
def _check_agents(cls, values):
"""Triggers a warning when ONLY agents are provided"""
if (
values.get("annotation_agent") is not None
and values.get("annotation") is None
):
warnings.warn(
"You provided an `annotation_agent`, but no `annotation`. The `annotation_agent` will not be logged to the server."
)
if self.prediction_agent is not None and self.prediction is None:
if (
values.get("prediction_agent") is not None
and values.get("prediction") is None
):
warnings.warn(
"You provided an `prediction_agent`, but no `prediction`. The `prediction_agent` will not be logged to the server."
)
return values

@root_validator
def _check_and_update_status(cls, values):
"""Updates the status if an annotation is provided and no status is specified."""
values["status"] = values.get("status") or (
"Default" if values.get("annotation") is None else "Validated"
)
return values


def _limit_metadata_values(metadata: Dict[str, Any]) -> Dict[str, Any]:
Expand Down

0 comments on commit 1910688

Please sign in to comment.