From 02810cad05df6a669b3e09e1e008cf344a4309de Mon Sep 17 00:00:00 2001 From: dcfidalgo Date: Mon, 17 Jan 2022 17:04:19 +0100 Subject: [PATCH 1/5] feat: add warning when only agent is provided --- src/rubrix/client/models.py | 143 ++++++++++++++++++------------------ 1 file changed, 71 insertions(+), 72 deletions(-) diff --git a/src/rubrix/client/models.py b/src/rubrix/client/models.py index b1465b71c8..9b4e542890 100644 --- a/src/rubrix/client/models.py +++ b/src/rubrix/client/models.py @@ -27,6 +27,74 @@ from rubrix.server.commons.helpers import limit_value_length +class BaseRecord(BaseModel): + """Base class for our record models + + Args: + prediction: + The predictions for your record. + annotation: + The annotations for your record. + prediction_agent: + Name of the prediction agent. By default, this is set to the hostname of your machine. + annotation_agent: + Name of the prediction agent. By default, this is set to the hostname of your machine. + id: + The id of the record. By default (`None`), we will generate a unique ID for you. + metadata: + Meta data for the record. Defaults to `{}`. + status: + The status of the record. Options: 'Default', 'Edited', 'Discarded', 'Validated'. + If an annotation is provided, this defaults to 'Validated', otherwise 'Default'. + event_timestamp: + The timestamp of the record. + metrics: + READ ONLY! Metrics at record level provided by the server when using `rb.load`. + This attribute will be ignored when using `rb.log`. + """ + + prediction: Optional[Any] = None + annotation: Optional[Any] = None + prediction_agent: Optional[str] = None + annotation_agent: Optional[str] = None + id: Optional[Union[int, str]] = None + metadata: Dict[str, Any] = Field(default_factory=dict) + status: Optional[str] = None + event_timestamp: Optional[datetime.datetime] = None + metrics: Optional[Dict[str, Any]] = None + + @validator("metadata", pre=True) + def check_value_length(cls, metadata): + return _limit_metadata_values(metadata) + + def __init__(self, *args, **kwargs): + """Custom init to handle dynamic defaults""" + # noinspection PyArgumentList + super().__init__(*args, **kwargs) + self.status = self.status or ( + "Default" if self.annotation is None else "Validated" + ) + if self.annotation_agent is not None and self.annotation is None: + warnings.warn( + "You provided an `annotation_agent`, but no `annotation`. The `annotation_agent` will not be logged to the server." + ) + if self.prediction_agent is not None and self.prediction is None: + warnings.warn( + "You provided an `prediction_agent`, but no `prediction`. The `prediction_agent` will not be logged to the server." + ) + + +def _limit_metadata_values(metadata: Dict[str, Any]) -> Dict[str, Any]: + """Checks metadata values length and apply value truncation for large values""" + new_value = limit_value_length(metadata, max_length=MAX_KEYWORD_LENGTH) + if new_value != metadata: + warnings.warn( + "Some metadata values exceed the max length. " + f"Those values will be truncated by keeping only the last {MAX_KEYWORD_LENGTH} characters." + ) + return new_value + + class BulkResponse(BaseModel): """Summary response when logging records to the Rubrix server. @@ -55,7 +123,7 @@ class TokenAttributions(BaseModel): attributions: Dict[str, float] = Field(default_factory=dict) -class TextClassificationRecord(BaseModel): +class TextClassificationRecord(BaseRecord): """Record for text classification Args: @@ -100,18 +168,10 @@ class TextClassificationRecord(BaseModel): prediction: Optional[List[Tuple[str, float]]] = None annotation: Optional[Union[str, List[str]]] = None - prediction_agent: Optional[str] = None - annotation_agent: Optional[str] = None multi_label: bool = False explanation: Optional[Dict[str, List[TokenAttributions]]] = None - id: Optional[Union[int, str]] = None - metadata: Dict[str, Any] = Field(default_factory=dict) - status: Optional[str] = None - event_timestamp: Optional[datetime.datetime] = None - metrics: Optional[Dict[str, Any]] = None - @validator("inputs", pre=True) def input_as_dict(cls, inputs): """Preprocess record inputs and wraps as dictionary if needed""" @@ -119,20 +179,8 @@ def input_as_dict(cls, inputs): return inputs return dict(text=inputs) - @validator("metadata", pre=True) - def check_value_length(cls, metadata): - return _limit_metadata_values(metadata) - - def __init__(self, *args, **kwargs): - """Custom init to handle dynamic defaults""" - # noinspection PyArgumentList - super().__init__(*args, **kwargs) - self.status = self.status or ( - "Default" if self.annotation is None else "Validated" - ) - -class TokenClassificationRecord(BaseModel): +class TokenClassificationRecord(BaseRecord): """Record for a token classification task Args: @@ -181,28 +229,9 @@ class TokenClassificationRecord(BaseModel): List[Union[Tuple[str, int, int], Tuple[str, int, int, float]]] ] = None annotation: Optional[List[Tuple[str, int, int]]] = None - prediction_agent: Optional[str] = None - annotation_agent: Optional[str] = None - - id: Optional[Union[int, str]] = None - metadata: Dict[str, Any] = Field(default_factory=dict) - status: Optional[str] = None - event_timestamp: Optional[datetime.datetime] = None - metrics: Optional[Dict[str, Any]] = None - - @validator("metadata", pre=True) - def check_value_length(cls, metadata): - return _limit_metadata_values(metadata) - - def __init__(self, *args, **kwargs): - """Custom init to handle dynamic defaults""" - super().__init__(*args, **kwargs) - self.status = self.status or ( - "Default" if self.annotation is None else "Validated" - ) -class Text2TextRecord(BaseModel): +class Text2TextRecord(BaseRecord): """Record for a text to text task Args: @@ -242,14 +271,6 @@ class Text2TextRecord(BaseModel): prediction: Optional[List[Union[str, Tuple[str, float]]]] = None annotation: Optional[str] = None - prediction_agent: Optional[str] = None - annotation_agent: Optional[str] = None - - id: Optional[Union[int, str]] = None - metadata: Dict[str, Any] = Field(default_factory=dict) - status: Optional[str] = None - event_timestamp: Optional[datetime.datetime] = None - metrics: Optional[Dict[str, Any]] = None @validator("prediction") def prediction_as_tuples( @@ -262,27 +283,5 @@ def prediction_as_tuples( return prediction return [(text, 1.0) for text in prediction] - @validator("metadata", pre=True) - def check_value_length(cls, metadata): - return _limit_metadata_values(metadata) - - def __init__(self, *args, **kwargs): - """Custom init to handle dynamic defaults""" - super().__init__(*args, **kwargs) - self.status = self.status or ( - "Default" if self.annotation is None else "Validated" - ) - - -def _limit_metadata_values(metadata: Dict[str, Any]) -> Dict[str, Any]: - """Checks metadata values length and apply value truncation for large values""" - new_value = limit_value_length(metadata, max_length=MAX_KEYWORD_LENGTH) - if new_value != metadata: - warnings.warn( - "Some metadata values exceed the max length. " - f"Those values will be truncated by keeping only the last {MAX_KEYWORD_LENGTH} characters." - ) - return new_value - Record = Union[TextClassificationRecord, TokenClassificationRecord, Text2TextRecord] From f8ca9cbc42c0d7cc488076c8d866393af1336fff Mon Sep 17 00:00:00 2001 From: dcfidalgo Date: Mon, 17 Jan 2022 17:04:58 +0100 Subject: [PATCH 2/5] refactor: avoid global import --- src/rubrix/__init__.py | 11 ++++++++-- .../token_classification/service/service.py | 20 ++++++------------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/rubrix/__init__.py b/src/rubrix/__init__.py index 51199ba106..737d735811 100644 --- a/src/rubrix/__init__.py +++ b/src/rubrix/__init__.py @@ -21,14 +21,21 @@ import logging import os import re -from typing import Iterable +from typing import Any, Dict, Iterable, List, Optional, Union import pandas import pkg_resources from rubrix._constants import DEFAULT_API_KEY from rubrix.client import RubrixClient -from rubrix.client.models import * +from rubrix.client.models import ( + BulkResponse, + Record, + Text2TextRecord, + TextClassificationRecord, + TokenAttributions, + TokenClassificationRecord, +) from rubrix.monitoring.model_monitor import monitor try: diff --git a/src/rubrix/server/tasks/token_classification/service/service.py b/src/rubrix/server/tasks/token_classification/service/service.py index d9878bcb4b..8297498bcb 100644 --- a/src/rubrix/server/tasks/token_classification/service/service.py +++ b/src/rubrix/server/tasks/token_classification/service/service.py @@ -17,27 +17,22 @@ from fastapi import Depends -from rubrix import MAX_KEYWORD_LENGTH -from rubrix.server.commons.es_helpers import ( - aggregations, - sort_by2elasticsearch, -) +from rubrix._constants import MAX_KEYWORD_LENGTH +from rubrix.server.commons.es_helpers import aggregations, sort_by2elasticsearch from rubrix.server.datasets.model import Dataset from rubrix.server.tasks.commons import ( BulkResponse, EsRecordDataFieldNames, SortableField, ) -from rubrix.server.tasks.commons.dao import ( - extends_index_properties, -) +from rubrix.server.tasks.commons.dao import extends_index_properties from rubrix.server.tasks.commons.dao.dao import DatasetRecordsDAO, dataset_records_dao from rubrix.server.tasks.commons.dao.model import RecordSearch from rubrix.server.tasks.commons.metrics.service import MetricsService from rubrix.server.tasks.token_classification.api.model import ( - CreationTokenClassificationRecord, MENTIONS_ES_FIELD_NAME, PREDICTED_MENTIONS_ES_FIELD_NAME, + CreationTokenClassificationRecord, TokenClassificationAggregations, TokenClassificationQuery, TokenClassificationRecord, @@ -177,7 +172,7 @@ def search( ), size=size, record_from=record_from, - exclude_fields=["metrics"] if exclude_metrics else None + exclude_fields=["metrics"] if exclude_metrics else None, ) return TokenClassificationSearchResults( total=results.total, @@ -239,8 +234,5 @@ def token_classification_service( """ global _instance if not _instance: - _instance = TokenClassificationService( - dao=dao, - metrics=metrics - ) + _instance = TokenClassificationService(dao=dao, metrics=metrics) return _instance From b1b37a825ab9347d719eb3fac1f39f1966eb1c87 Mon Sep 17 00:00:00 2001 From: dcfidalgo Date: Mon, 17 Jan 2022 17:05:28 +0100 Subject: [PATCH 3/5] docs: remove not important members from docs --- docs/reference/python/python_client.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/reference/python/python_client.rst b/docs/reference/python/python_client.rst index a6d816ac51..c5c3019121 100644 --- a/docs/reference/python/python_client.rst +++ b/docs/reference/python/python_client.rst @@ -19,3 +19,4 @@ Models .. automodule:: rubrix.client.models :members: + :exclude-members: BaseRecord, BulkResponse From 2429d984a3cf11cc4e589be32e26b2c5d0b2096e Mon Sep 17 00:00:00 2001 From: dcfidalgo Date: Mon, 17 Jan 2022 17:29:29 +0100 Subject: [PATCH 4/5] test: add test --- tests/client/test_models.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/tests/client/test_models.py b/tests/client/test_models.py index a14c06ff0b..8adf3e3b88 100644 --- a/tests/client/test_models.py +++ b/tests/client/test_models.py @@ -17,10 +17,14 @@ import numpy import pytest from pydantic import ValidationError -from rubrix._constants import MAX_KEYWORD_LENGTH -from rubrix.client.models import Text2TextRecord, TextClassificationRecord -from rubrix.client.models import TokenClassificationRecord +from rubrix._constants import MAX_KEYWORD_LENGTH +from rubrix.client.models import ( + BaseRecord, + Text2TextRecord, + TextClassificationRecord, + TokenClassificationRecord, +) @pytest.mark.parametrize( @@ -92,3 +96,14 @@ def test_model_serialization_with_numpy_nan(): ) json_record = json.loads(record.json()) + + +def test_warning_when_only_agent(): + with pytest.warns( + UserWarning, match="`prediction_agent` will not be logged to the server." + ): + BaseRecord(prediction_agent="mock") + with pytest.warns( + UserWarning, match="`annotation_agent` will not be logged to the server." + ): + BaseRecord(annotation_agent="mock") From 1910688ed8e4df381e750142c960ecc7bf1dfc4e Mon Sep 17 00:00:00 2001 From: dcfidalgo Date: Tue, 18 Jan 2022 23:15:41 +0100 Subject: [PATCH 5/5] refactor: use root_validators instead of init --- src/rubrix/client/models.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/src/rubrix/client/models.py b/src/rubrix/client/models.py index 9b4e542890..fef4804eb0 100644 --- a/src/rubrix/client/models.py +++ b/src/rubrix/client/models.py @@ -21,7 +21,7 @@ import warnings from typing import Any, Dict, List, Optional, Tuple, Union -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, root_validator, validator from rubrix._constants import MAX_KEYWORD_LENGTH from rubrix.server.commons.helpers import limit_value_length @@ -67,21 +67,32 @@ class BaseRecord(BaseModel): def check_value_length(cls, metadata): return _limit_metadata_values(metadata) - def __init__(self, *args, **kwargs): - """Custom init to handle dynamic defaults""" - # noinspection PyArgumentList - super().__init__(*args, **kwargs) - self.status = self.status or ( - "Default" if self.annotation is None else "Validated" - ) - if self.annotation_agent is not None and self.annotation is None: + @root_validator + def _check_agents(cls, values): + """Triggers a warning when ONLY agents are provided""" + if ( + values.get("annotation_agent") is not None + and values.get("annotation") is None + ): warnings.warn( "You provided an `annotation_agent`, but no `annotation`. The `annotation_agent` will not be logged to the server." ) - if self.prediction_agent is not None and self.prediction is None: + if ( + values.get("prediction_agent") is not None + and values.get("prediction") is None + ): warnings.warn( "You provided an `prediction_agent`, but no `prediction`. The `prediction_agent` will not be logged to the server." ) + return values + + @root_validator + def _check_and_update_status(cls, values): + """Updates the status if an annotation is provided and no status is specified.""" + values["status"] = values.get("status") or ( + "Default" if values.get("annotation") is None else "Validated" + ) + return values def _limit_metadata_values(metadata: Dict[str, Any]) -> Dict[str, Any]: