Skip to content

Commit

Permalink
Refac: Unit tests (#34)
Browse files Browse the repository at this point in the history
* add test for dataset_profiles
* refactor delete profiles
* add tests + refactors for monitor_helpers
* made dataset and org ids optional
* refactor dataclasses init
  • Loading branch information
murilommen committed Jul 28, 2023
1 parent 1837099 commit f365071
Show file tree
Hide file tree
Showing 10 changed files with 336 additions and 97 deletions.
77 changes: 77 additions & 0 deletions tests/helpers/test_dataset_profiles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import os
from datetime import datetime

import pytest

from whylabs_toolkit.helpers.dataset_profiles import (
delete_all_profiles_for_period,
validate_timestamp_in_millis,
process_date_input
)

def test_validate_timestamp_in_millis() -> None:
    """Valid epoch-milliseconds ints are accepted; negatives and non-ints are not."""
    # PEP 8 (E712): test boolean results with `is`, not `== True` / `== False`.
    assert validate_timestamp_in_millis(1627233600000) is True
    assert validate_timestamp_in_millis(-1231214) is False
    assert validate_timestamp_in_millis("some_string") is False
    assert validate_timestamp_in_millis(None) is False
    assert validate_timestamp_in_millis(3.1415) is False

def test_process_date_input() -> None:
    """process_date_input passes millis through, converts datetimes, rejects the rest."""
    # Millisecond integers are returned unchanged.
    millis = 1627233600000
    assert process_date_input(millis) == millis

    # A datetime is converted to epoch milliseconds (local-timezone semantics,
    # matching datetime.timestamp()).
    dt = datetime(2023, 7, 25)
    assert process_date_input(dt) == int(dt.timestamp() * 1000.0)

    # Unsupported types and negative epochs are rejected.
    for bad_input in ("invalid", -12498127412):
        with pytest.raises(ValueError):
            process_date_input(bad_input)


## -- Note:
# Calling delete_dataset_profiles only *schedules* the deletion, which currently
# runs hourly, so there is no trivial way to verify it in unit tests. For that
# reason, these tests only assert that the call succeeds; the actual deletion
# logic is tested and maintained by Songbird.

def test_delete_profile_for_datetime_range():
    """A datetime start/end pair yields a successful scheduling response."""
    org_id = os.environ["ORG_ID"]
    dataset_id = os.environ["DATASET_ID"]

    response = delete_all_profiles_for_period(
        start=datetime(2023, 7, 5),
        end=datetime(2023, 7, 6),
        dataset_id=dataset_id,
        org_id=org_id,
    )

    # The API echoes back "<org>/<dataset>" as the id of the scheduled deletion.
    assert response.get("id") == f"{org_id}/{dataset_id}"


def test_delete_profiles_for_milliseconds_range():
    """An epoch-milliseconds start/end pair yields a successful scheduling response."""
    org_id = os.environ["ORG_ID"]
    dataset_id = os.environ["DATASET_ID"]
    start_millis = int(datetime(2023, 7, 5).timestamp() * 1000.0)
    end_millis = int(datetime(2023, 7, 6).timestamp() * 1000.0)

    response = delete_all_profiles_for_period(
        start=start_millis,
        end=end_millis,
        dataset_id=dataset_id,
        org_id=org_id,
    )

    # The API echoes back "<org>/<dataset>" as the id of the scheduled deletion.
    assert response.get("id") == f"{org_id}/{dataset_id}"


def test_delete_profiles_raises_if_other_format_is_passed():
    """Invalid `start` values (negative epoch, wrong type) raise ValueError."""
    valid_end = int(datetime(2023, 7, 6).timestamp() * 1000.0)

    for bad_start in (-123123123123, "string_example"):
        with pytest.raises(ValueError):
            delete_all_profiles_for_period(
                start=bad_start,
                end=valid_end,
                dataset_id=os.environ["DATASET_ID"],
                org_id=os.environ["ORG_ID"],
            )

2 changes: 2 additions & 0 deletions tests/helpers/test_entity_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def test_change_columns_discreteness() -> None:
assert update_discreteness.current_entity_schema["columns"]["prediction_temperature"][
"discreteness"] == "continuous"


def test_same_column_on_both_parameters_should_raise():
columns = ColumnsDiscreteness(
discrete=["temperature"],
Expand Down Expand Up @@ -107,6 +108,7 @@ def test_same_column_on_both_parameters_should_raise():
with pytest.raises(ValueError):
update_entity.update()


def test_change_columns_schema():
columns_schema = {"temperature": ColumnDataType.boolean}

Expand Down
99 changes: 91 additions & 8 deletions tests/helpers/test_monitor_helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os
from typing import List, Dict

from whylabs_toolkit.helpers.monitor_helpers import (
delete_monitor,
get_model_granularity,
get_monitor_config
get_monitor_config,
get_analyzer_ids,
get_monitor
)
from whylabs_toolkit.helpers.utils import get_monitor_api
from whylabs_toolkit.utils.granularity import Granularity
Expand Down Expand Up @@ -46,6 +49,7 @@
}
}


class BaseTestMonitor:
@classmethod
def setup_class(cls) -> None:
Expand Down Expand Up @@ -78,17 +82,96 @@ class TestDeleteMonitor(BaseTestMonitor):
def teardown_class(cls) -> None:
pass

def test_get_analyzer_id(self) -> None:
pass

def test_get_analyzer_ids(self) -> None:
pass
analyzer_ids = get_analyzer_ids(
org_id=ORG_ID,
dataset_id = DATASET_ID,
monitor_id= MONITOR_ID,
)
assert analyzer_ids is not None
assert isinstance(analyzer_ids, List)
for analyzer in analyzer_ids:
assert analyzer == f"{MONITOR_ID}-analyzer"

def test_get_analyzer_ids_that_dont_exist(self) -> None:
analyzer_ids = get_analyzer_ids(
org_id=ORG_ID,
dataset_id = DATASET_ID,
monitor_id= "dont_exist",
)
assert analyzer_ids is None

analyzer_ids = get_analyzer_ids(
org_id="wrong_org",
dataset_id = DATASET_ID,
monitor_id= MONITOR_ID,
)

assert analyzer_ids is None

analyzer_ids = get_analyzer_ids(
org_id=ORG_ID,
dataset_id = "model-X",
monitor_id= MONITOR_ID,
)

assert analyzer_ids is None


def test_get_monitor_config(self) -> None:
pass

monitor_config = get_monitor_config(
org_id=ORG_ID,
dataset_id = DATASET_ID,
)

assert monitor_config is not None
assert isinstance(monitor_config, Dict)
for key in monitor_config.keys():
assert key in ['orgId', 'datasetId', 'granularity', 'metadata', 'allowPartialTargetBatches', 'analyzers', 'monitors']

def test_get_monitor_config_not_existing_dataset_id(self, caplog) -> None:
with caplog.at_level("WARNING"):
monitor_config = get_monitor_config(
org_id=ORG_ID,
dataset_id = "fake-dataset-id",
)

assert monitor_config is None
assert "Could not find a monitor config for fake-dataset-id" in caplog.text

def test_get_monitor(self) -> None:
pass
monitor = get_monitor(
monitor_id=MONITOR_ID,
dataset_id=DATASET_ID,
org_id=ORG_ID
)

assert monitor is not None
assert isinstance(monitor, Dict)

for key in monitor.keys():
assert key in ['id', 'analyzerIds', 'schedule', 'mode', 'disabled', 'actions', 'metadata']


def test_get_monitor_with_wrong_configs(self, caplog) -> None:
with caplog.at_level("WARNING"):
monitor = get_monitor(
monitor_id="fake-monitor",
dataset_id=DATASET_ID,
org_id=ORG_ID
)
assert monitor is None
assert f"Could not find a monitor with id fake-monitor for {DATASET_ID}." in caplog.text
with caplog.at_level("WARNING"):
monitor = get_monitor(
monitor_id=MONITOR_ID,
dataset_id="fake-dataset-id",
org_id=ORG_ID
)

assert monitor is None
assert f"Could not find a monitor with id {MONITOR_ID} for fake-dataset-id." in caplog.text


def test_get_granularity(self) -> None:
granularity = get_model_granularity(org_id=ORG_ID, dataset_id=DATASET_ID)
Expand Down
3 changes: 0 additions & 3 deletions tests/monitor/manager/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ def test_monitor_running_eagerly(self, existing_monitor_setup: MonitorSetup):

assert new_expected_result["allowPartialTargetBatches"] == False




class TestNotificationActions(TestCase):
def setUp(self) -> None:
Expand All @@ -103,7 +101,6 @@ def setUp(self) -> None:
notifications_api=self.notifications_api,
monitor_api=self.monitor_api
)


def test_notification_actions_are_updated(self) -> None:
self.monitor_manager._update_notification_actions()
Expand Down
31 changes: 28 additions & 3 deletions tests/monitor/manager/test_monitor_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,30 @@ def test_set_fixed_dates_baseline(monitor_setup: MonitorSetup) -> None:
end=datetime(2023,1,2, tzinfo=timezone.utc)
)
)

monitor_setup.apply()

assert monitor_setup.config.baseline == TimeRangeBaseline(
range=TimeRange(
start=datetime(2023,1,1, tzinfo=timezone.utc),
end=datetime(2023,1,2, tzinfo=timezone.utc)
)
)

def test_exclude_target_columns(monitor_setup):
    """Excluding a column is recorded before apply() and propagated after it."""
    excluded = ["prediction_temperature"]

    monitor_setup.exclude_target_columns(columns=excluded)
    assert monitor_setup._exclude_columns == excluded

    monitor_setup.apply()

    # Both the setup-level and analyzer-level target matrices carry the exclusion.
    for matrix in (monitor_setup.target_matrix, monitor_setup.analyzer.targetMatrix):
        assert isinstance(matrix, ColumnMatrix)
        assert matrix.exclude == excluded


def test_set_target_columns(monitor_setup):
Expand All @@ -38,9 +55,15 @@ def test_set_target_columns(monitor_setup):
)

assert monitor_setup._target_columns == ["prediction_temperature"]

monitor_setup.apply()

assert isinstance(monitor_setup.target_matrix, ColumnMatrix)
assert monitor_setup.target_matrix.include == ["prediction_temperature"]
assert isinstance(monitor_setup.analyzer.targetMatrix, ColumnMatrix)
assert monitor_setup.analyzer.targetMatrix.include == ["prediction_temperature"]


def test_setup(monitor_setup):
def test_setup_apply(monitor_setup):
assert not monitor_setup.monitor
assert not monitor_setup.analyzer

Expand Down Expand Up @@ -154,7 +177,9 @@ def test_apply_wont_change_monitor_columns(monitor_setup):
monitor_setup.apply()

assert monitor_setup.analyzer.targetMatrix != ColumnMatrix(include=["*"] , exclude=[], segments=[])


assert monitor_setup.target_matrix == ColumnMatrix(include=["prediction_temperature", "temperature"] , exclude=[], segments=[])
assert monitor_setup.analyzer.targetMatrix == ColumnMatrix(include=["prediction_temperature", "temperature"] , exclude=[], segments=[])

def test_apply_wont_erase_existing_preconfig(monitor_setup):
monitor_setup.config = FixedThresholdsConfig(
Expand Down
40 changes: 33 additions & 7 deletions whylabs_toolkit/helpers/dataset_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,49 @@
from whylabs_client.api.dataset_profile_api import DeleteDatasetProfilesResponse

from whylabs_toolkit.helpers.utils import get_dataset_profile_api
from whylabs_toolkit.helpers.config import Config

date_or_millis = Union[datetime, int]


# TODO test and make sure it's working
def validate_timestamp_in_millis(epoch_milliseconds: int) -> bool:
    """Return True if *epoch_milliseconds* is a plausible Unix epoch timestamp in milliseconds.

    A valid value is a non-negative ``int`` (``bool`` is rejected even though it
    subclasses ``int``) that ``datetime`` can actually represent.
    """
    # bool is a subclass of int, so exclude it explicitly.
    if not isinstance(epoch_milliseconds, int) or isinstance(epoch_milliseconds, bool):
        return False
    # Compare on the raw integer rather than a converted datetime:
    # datetime.fromtimestamp() interprets the value in the *local* timezone, so
    # the old `dt >= datetime(1970, 1, 1)` check wrongly accepted small negative
    # epochs on machines in UTC+ timezones.
    if epoch_milliseconds < 0:
        return False
    try:
        # Still attempt the conversion to reject values datetime cannot hold
        # (e.g. beyond year 9999). OSError covers platform timestamp limits.
        datetime.fromtimestamp(epoch_milliseconds / 1000)
        return True
    except (ValueError, OverflowError, OSError):
        return False


def process_date_input(date_input: date_or_millis) -> int:
    """Normalize *date_input* to a Unix epoch timestamp in milliseconds.

    :param date_input: either a ``datetime`` or an integer epoch in milliseconds
    :return: the epoch timestamp in milliseconds
    :raises ValueError: if the input is neither a ``datetime`` nor a valid epoch int
    """
    if isinstance(date_input, int):
        # Validate explicitly instead of via `assert`: asserts are stripped under
        # `python -O`, which would silently disable this check.
        if not validate_timestamp_in_millis(epoch_milliseconds=date_input):
            raise ValueError("You must provide a valid date input")
        return date_input
    if isinstance(date_input, datetime):
        # Naive datetimes are interpreted in the local timezone by timestamp().
        return int(date_input.timestamp() * 1000.0)
    raise ValueError(f"The date object {date_input} input must be a datetime or an integer Epoch!")


def delete_all_profiles_for_period(
start: date_or_millis,
end: date_or_millis,
dataset_id: str,
org_id: Optional[str],
) -> None:
config: Config = Config(),
org_id: Optional[str] = None,
dataset_id: Optional[str] = None,
) -> DeleteDatasetProfilesResponse:
api = get_dataset_profile_api()

profile_start_timestamp = start if isinstance(start, int) else int(start.timestamp() * 1000.0)
profile_end_timestamp = end if isinstance(end, int) else int(end.timestamp() * 1000.0)
profile_start_timestamp = process_date_input(date_input=start)
profile_end_timestamp = process_date_input(date_input=end)

org_id = org_id or config.get_default_org_id()
dataset_id = dataset_id or config.get_default_dataset_id()

result: DeleteDatasetProfilesResponse = api.delete_dataset_profiles(
org_id=org_id,
Expand All @@ -30,4 +56,4 @@ def delete_all_profiles_for_period(
profile_end_timestamp=profile_end_timestamp,
)

print(result)
return result
12 changes: 8 additions & 4 deletions whylabs_toolkit/helpers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


def update_model_metadata(
dataset_id: str,
dataset_id: Optional[str] = None,
org_id: Optional[str] = None,
time_period: Optional[str] = None,
model_type: Optional[str] = None,
Expand All @@ -22,6 +22,9 @@ def update_model_metadata(
"""
Update model attributes like model type and period.
"""
org_id = org_id or config.get_default_org_id()
dataset_id = dataset_id or config.get_default_dataset_id()

api = get_models_api(config=config)

model_metadata = api.get_model(org_id=org_id, model_id=dataset_id)
Expand All @@ -41,16 +44,17 @@ def update_model_metadata(


def add_custom_metric(
dataset_id: str,
label: str,
column: str,
default_metric: str,
org_id: Optional[str] = None,
dataset_id: Optional[str] = None,
config: Config = Config(),
) -> None:

if not org_id:
org_id = config.get_default_org_id()
org_id = org_id or config.get_default_org_id()
dataset_id = dataset_id or config.get_default_dataset_id()

api = get_models_api(config=config)
metric_schema = MetricSchema(label=label, column=column, default_metric=default_metric)

Expand Down
Loading

0 comments on commit f365071

Please sign in to comment.