Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(#735): add warning when agent but no prediction/annotation is provided #987

Merged
merged 5 commits into from
Jan 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/reference/python/python_client.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ Models

.. automodule:: rubrix.client.models
:members:
:exclude-members: BaseRecord, BulkResponse
11 changes: 9 additions & 2 deletions src/rubrix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,21 @@
import logging
import os
import re
from typing import Iterable
from typing import Any, Dict, Iterable, List, Optional, Union

import pandas
import pkg_resources

from rubrix._constants import DEFAULT_API_KEY
from rubrix.client import RubrixClient
from rubrix.client.models import *
from rubrix.client.models import (
BulkResponse,
Record,
Text2TextRecord,
TextClassificationRecord,
TokenAttributions,
TokenClassificationRecord,
)
from rubrix.monitoring.model_monitor import monitor

try:
Expand Down
156 changes: 83 additions & 73 deletions src/rubrix/client/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,91 @@
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, root_validator, validator

from rubrix._constants import MAX_KEYWORD_LENGTH
from rubrix.server.commons.helpers import limit_value_length


class BaseRecord(BaseModel):
    """Base class for our record models

    Args:
        prediction:
            The predictions for your record.
        annotation:
            The annotations for your record.
        prediction_agent:
            Name of the prediction agent. By default, this is set to the hostname of your machine.
        annotation_agent:
            Name of the annotation agent. By default, this is set to the hostname of your machine.
        id:
            The id of the record. By default (`None`), we will generate a unique ID for you.
        metadata:
            Meta data for the record. Defaults to `{}`.
        status:
            The status of the record. Options: 'Default', 'Edited', 'Discarded', 'Validated'.
            If an annotation is provided, this defaults to 'Validated', otherwise 'Default'.
        event_timestamp:
            The timestamp of the record.
        metrics:
            READ ONLY! Metrics at record level provided by the server when using `rb.load`.
            This attribute will be ignored when using `rb.log`.
    """

    prediction: Optional[Any] = None
    annotation: Optional[Any] = None
    prediction_agent: Optional[str] = None
    annotation_agent: Optional[str] = None
    id: Optional[Union[int, str]] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    status: Optional[str] = None
    event_timestamp: Optional[datetime.datetime] = None
    metrics: Optional[Dict[str, Any]] = None

    @validator("metadata", pre=True)
    def check_value_length(cls, metadata):
        """Truncates overly long metadata values, warning the user when it happens."""
        return _limit_metadata_values(metadata)

    @root_validator
    def _check_agents(cls, values):
        """Triggers a warning when ONLY agents are provided"""
        # Agents are only meaningful together with their prediction/annotation;
        # the server drops a lone agent, so tell the user up front.
        if (
            values.get("annotation_agent") is not None
            and values.get("annotation") is None
        ):
            warnings.warn(
                "You provided an `annotation_agent`, but no `annotation`. The `annotation_agent` will not be logged to the server."
            )
        if (
            values.get("prediction_agent") is not None
            and values.get("prediction") is None
        ):
            warnings.warn(
                "You provided a `prediction_agent`, but no `prediction`. The `prediction_agent` will not be logged to the server."
            )
        return values

    @root_validator
    def _check_and_update_status(cls, values):
        """Updates the status if an annotation is provided and no status is specified."""
        # The default status depends on another field (annotation), hence a
        # root_validator instead of a plain field default.
        values["status"] = values.get("status") or (
            "Default" if values.get("annotation") is None else "Validated"
        )
        return values


def _limit_metadata_values(metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Return *metadata* with overly long values truncated.

    Emits a ``UserWarning`` whenever at least one value had to be shortened.
    """
    truncated = limit_value_length(metadata, max_length=MAX_KEYWORD_LENGTH)
    if truncated != metadata:
        warnings.warn(
            "Some metadata values exceed the max length. "
            f"Those values will be truncated by keeping only the last {MAX_KEYWORD_LENGTH} characters."
        )
    return truncated


class BulkResponse(BaseModel):
"""Summary response when logging records to the Rubrix server.

Expand Down Expand Up @@ -55,7 +134,7 @@ class TokenAttributions(BaseModel):
attributions: Dict[str, float] = Field(default_factory=dict)


class TextClassificationRecord(BaseModel):
class TextClassificationRecord(BaseRecord):
"""Record for text classification

Args:
Expand Down Expand Up @@ -100,39 +179,19 @@ class TextClassificationRecord(BaseModel):

prediction: Optional[List[Tuple[str, float]]] = None
annotation: Optional[Union[str, List[str]]] = None
prediction_agent: Optional[str] = None
annotation_agent: Optional[str] = None
multi_label: bool = False

explanation: Optional[Dict[str, List[TokenAttributions]]] = None

id: Optional[Union[int, str]] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
status: Optional[str] = None
event_timestamp: Optional[datetime.datetime] = None
metrics: Optional[Dict[str, Any]] = None

@validator("inputs", pre=True)
def input_as_dict(cls, inputs):
    """Wrap plain record inputs into a ``{"text": ...}`` dict if needed."""
    return inputs if isinstance(inputs, dict) else {"text": inputs}

@validator("metadata", pre=True)
def check_value_length(cls, metadata):
    """Truncates overly long metadata values, warning the user when it happens."""
    return _limit_metadata_values(metadata)

def __init__(self, *args, **kwargs):
    """Custom init to handle dynamic defaults

    If no explicit ``status`` is given, it is derived from ``annotation``:
    annotated records default to "Validated", others to "Default".
    """
    # noinspection PyArgumentList
    super().__init__(*args, **kwargs)
    # Done post-init because the default depends on another field (annotation).
    self.status = self.status or (
        "Default" if self.annotation is None else "Validated"
    )


class TokenClassificationRecord(BaseModel):
class TokenClassificationRecord(BaseRecord):
"""Record for a token classification task

Args:
Expand Down Expand Up @@ -181,28 +240,9 @@ class TokenClassificationRecord(BaseModel):
List[Union[Tuple[str, int, int], Tuple[str, int, int, float]]]
] = None
annotation: Optional[List[Tuple[str, int, int]]] = None
prediction_agent: Optional[str] = None
annotation_agent: Optional[str] = None

id: Optional[Union[int, str]] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
status: Optional[str] = None
event_timestamp: Optional[datetime.datetime] = None
metrics: Optional[Dict[str, Any]] = None

@validator("metadata", pre=True)
def check_value_length(cls, metadata):
    """Truncates overly long metadata values, warning the user when it happens."""
    return _limit_metadata_values(metadata)

def __init__(self, *args, **kwargs):
    """Custom init to handle dynamic defaults

    If no explicit ``status`` is given, it is derived from ``annotation``:
    annotated records default to "Validated", others to "Default".
    """
    super().__init__(*args, **kwargs)
    # Done post-init because the default depends on another field (annotation).
    self.status = self.status or (
        "Default" if self.annotation is None else "Validated"
    )


class Text2TextRecord(BaseModel):
class Text2TextRecord(BaseRecord):
"""Record for a text to text task

Args:
Expand Down Expand Up @@ -242,14 +282,6 @@ class Text2TextRecord(BaseModel):

prediction: Optional[List[Union[str, Tuple[str, float]]]] = None
annotation: Optional[str] = None
prediction_agent: Optional[str] = None
annotation_agent: Optional[str] = None

id: Optional[Union[int, str]] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
status: Optional[str] = None
event_timestamp: Optional[datetime.datetime] = None
metrics: Optional[Dict[str, Any]] = None

@validator("prediction")
def prediction_as_tuples(
Expand All @@ -262,27 +294,5 @@ def prediction_as_tuples(
return prediction
return [(text, 1.0) for text in prediction]

@validator("metadata", pre=True)
def check_value_length(cls, metadata):
    """Truncates overly long metadata values, warning the user when it happens."""
    return _limit_metadata_values(metadata)

def __init__(self, *args, **kwargs):
    """Custom init to handle dynamic defaults

    If no explicit ``status`` is given, it is derived from ``annotation``:
    annotated records default to "Validated", others to "Default".
    """
    super().__init__(*args, **kwargs)
    # Done post-init because the default depends on another field (annotation).
    self.status = self.status or (
        "Default" if self.annotation is None else "Validated"
    )


def _limit_metadata_values(metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Checks metadata values length and apply value truncation for large values"""
    new_value = limit_value_length(metadata, max_length=MAX_KEYWORD_LENGTH)
    # Warn instead of raising so logging still succeeds with truncated values.
    if new_value != metadata:
        warnings.warn(
            "Some metadata values exceed the max length. "
            f"Those values will be truncated by keeping only the last {MAX_KEYWORD_LENGTH} characters."
        )
    return new_value


# Type alias covering every supported client record type.
Record = Union[TextClassificationRecord, TokenClassificationRecord, Text2TextRecord]
20 changes: 6 additions & 14 deletions src/rubrix/server/tasks/token_classification/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,22 @@

from fastapi import Depends

from rubrix import MAX_KEYWORD_LENGTH
from rubrix.server.commons.es_helpers import (
aggregations,
sort_by2elasticsearch,
)
from rubrix._constants import MAX_KEYWORD_LENGTH
from rubrix.server.commons.es_helpers import aggregations, sort_by2elasticsearch
from rubrix.server.datasets.model import Dataset
from rubrix.server.tasks.commons import (
BulkResponse,
EsRecordDataFieldNames,
SortableField,
)
from rubrix.server.tasks.commons.dao import (
extends_index_properties,
)
from rubrix.server.tasks.commons.dao import extends_index_properties
from rubrix.server.tasks.commons.dao.dao import DatasetRecordsDAO, dataset_records_dao
from rubrix.server.tasks.commons.dao.model import RecordSearch
from rubrix.server.tasks.commons.metrics.service import MetricsService
from rubrix.server.tasks.token_classification.api.model import (
CreationTokenClassificationRecord,
MENTIONS_ES_FIELD_NAME,
PREDICTED_MENTIONS_ES_FIELD_NAME,
CreationTokenClassificationRecord,
TokenClassificationAggregations,
TokenClassificationQuery,
TokenClassificationRecord,
Expand Down Expand Up @@ -177,7 +172,7 @@ def search(
),
size=size,
record_from=record_from,
exclude_fields=["metrics"] if exclude_metrics else None
exclude_fields=["metrics"] if exclude_metrics else None,
)
return TokenClassificationSearchResults(
total=results.total,
Expand Down Expand Up @@ -239,8 +234,5 @@ def token_classification_service(
"""
global _instance
if not _instance:
_instance = TokenClassificationService(
dao=dao,
metrics=metrics
)
_instance = TokenClassificationService(dao=dao, metrics=metrics)
return _instance
21 changes: 18 additions & 3 deletions tests/client/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@
import numpy
import pytest
from pydantic import ValidationError
from rubrix._constants import MAX_KEYWORD_LENGTH

from rubrix.client.models import Text2TextRecord, TextClassificationRecord
from rubrix.client.models import TokenClassificationRecord
from rubrix._constants import MAX_KEYWORD_LENGTH
from rubrix.client.models import (
BaseRecord,
Text2TextRecord,
TextClassificationRecord,
TokenClassificationRecord,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -92,3 +96,14 @@ def test_model_serialization_with_numpy_nan():
)

json_record = json.loads(record.json())


def test_warning_when_only_agent():
    """An agent given without its prediction/annotation must raise a UserWarning."""
    for agent_field in ("prediction_agent", "annotation_agent"):
        with pytest.warns(
            UserWarning, match=f"`{agent_field}` will not be logged to the server."
        ):
            BaseRecord(**{agent_field: "mock"})