From 1910688ed8e4df381e750142c960ecc7bf1dfc4e Mon Sep 17 00:00:00 2001 From: dcfidalgo Date: Tue, 18 Jan 2022 23:15:41 +0100 Subject: [PATCH] refactor: use root_validators instead of init --- src/rubrix/client/models.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/src/rubrix/client/models.py b/src/rubrix/client/models.py index 9b4e542890..fef4804eb0 100644 --- a/src/rubrix/client/models.py +++ b/src/rubrix/client/models.py @@ -21,7 +21,7 @@ import warnings from typing import Any, Dict, List, Optional, Tuple, Union -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, root_validator, validator from rubrix._constants import MAX_KEYWORD_LENGTH from rubrix.server.commons.helpers import limit_value_length @@ -67,21 +67,32 @@ class BaseRecord(BaseModel): def check_value_length(cls, metadata): return _limit_metadata_values(metadata) - def __init__(self, *args, **kwargs): - """Custom init to handle dynamic defaults""" - # noinspection PyArgumentList - super().__init__(*args, **kwargs) - self.status = self.status or ( - "Default" if self.annotation is None else "Validated" - ) - if self.annotation_agent is not None and self.annotation is None: + @root_validator + def _check_agents(cls, values): + """Triggers a warning when ONLY agents are provided""" + if ( + values.get("annotation_agent") is not None + and values.get("annotation") is None + ): warnings.warn( "You provided an `annotation_agent`, but no `annotation`. The `annotation_agent` will not be logged to the server." ) - if self.prediction_agent is not None and self.prediction is None: + if ( + values.get("prediction_agent") is not None + and values.get("prediction") is None + ): warnings.warn( "You provided an `prediction_agent`, but no `prediction`. The `prediction_agent` will not be logged to the server." ) + return values + + @root_validator + def _check_and_update_status(cls, values): + """Updates the status if an annotation is provided and no status is specified.""" + values["status"] = values.get("status") or ( + "Default" if values.get("annotation") is None else "Validated" + ) + return values def _limit_metadata_values(metadata: Dict[str, Any]) -> Dict[str, Any]: