Skip to content

Commit

Permalink
skip invalid numbers when logging a metric or parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
jkasiraj committed May 5, 2022
1 parent dea989b commit d29b49a
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
19 changes: 16 additions & 3 deletions src/smexperiments/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import logging
import botocore
import json
from math import isnan, isinf
from numbers import Number
from smexperiments._utils import get_module
from os.path import join

Expand Down Expand Up @@ -231,7 +233,8 @@ def log_parameter(self, name, value):
name (str): The name of the parameter
value (str or numbers.Number): The value of the parameter
"""
self.trial_component.parameters[name] = value
if self._is_input_valid("parameter", name, value):
self.trial_component.parameters[name] = value

def log_parameters(self, parameters):
"""Record a collection of parameter values for this trial component.
Expand All @@ -245,7 +248,10 @@ def log_parameters(self, parameters):
Args:
parameters (dict[str, str or numbers.Number]): The parameters to record.
"""
self.trial_component.parameters.update(parameters)
filtered_parameters = {
key: value for (key, value) in parameters.items() if self._is_input_valid("parameter", key, value)
}
self.trial_component.parameters.update(filtered_parameters)

def log_input(self, name, value, media_type=None):
"""Record a single input artifact for this trial component.
Expand Down Expand Up @@ -402,7 +408,8 @@ def log_metric(self, metric_name, value, timestamp=None, iteration_number=None):
AttributeError: If the metrics writer is not initialized.
"""
try:
self._metrics_writer.log_metric(metric_name, value, timestamp, iteration_number)
if self._is_input_valid("metric", metric_name, value):
self._metrics_writer.log_metric(metric_name, value, timestamp, iteration_number)
except AttributeError:
if not self._metrics_writer:
if not self._warned_on_metrics:
Expand Down Expand Up @@ -654,6 +661,12 @@ def _log_graph_artifact(self, name, data, graph_type, output_artifact):
else:
self._lineage_artifact_tracker.add_input_artifact(artifact_name, s3_uri, etag, graph_type)

def _is_input_valid(self, input_type, field_name, field_value):
if isinstance(field_value, Number) and (isnan(field_value) or isinf(field_value)):
logging.warning(f"Failed to log {input_type} {field_name}. Received invalid value: {field_value}.")
return False
return True

def __enter__(self):
"""Updates the start time of the tracked trial component.
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/test_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import tempfile
import os
import datetime
from math import nan, inf
import numpy as np
from smexperiments import api_types, tracker, trial_component, _utils, _environment
import pandas as pd

Expand Down Expand Up @@ -171,6 +173,11 @@ def test_log_parameter(under_test):
assert under_test.trial_component.parameters["whizz"] == 1


def test_log_parameter_skip_invalid_value(under_test):
under_test.log_parameter("key", nan)
assert "key" not in under_test.trial_component.parameters


def test_enter(under_test):
under_test.__enter__()
assert isinstance(under_test.trial_component.start_time, datetime.datetime)
Expand Down Expand Up @@ -213,6 +220,11 @@ def test_log_parameters(under_test):
assert under_test.trial_component.parameters == {"a": "b", "c": "d", "e": 5}


def test_log_parameters_skip_invalid_values(under_test):
under_test.log_parameters({"a": "b", "c": "d", "e": 5, "f": nan})
assert under_test.trial_component.parameters == {"a": "b", "c": "d", "e": 5}


def test_log_input(under_test):
under_test.log_input("foo", "baz", "text/text")
assert under_test.trial_component.input_artifacts == {
Expand All @@ -233,6 +245,11 @@ def test_log_metric(under_test):
under_test._metrics_writer.log_metric.assert_called_with("foo", 1.0, 1, now)


def test_log_metric_skip_invalid_value(under_test):
under_test.log_metric(None, nan, None, None)
assert not under_test._metrics_writer.log_metric.called


def test_log_metric_attribute_error(under_test):
now = datetime.datetime.now()

Expand Down Expand Up @@ -630,3 +647,19 @@ def test_log_roc_curve(under_test):
)

under_test._lineage_artifact_tracker.add_input_artifact("TestROCCurve", "s3uri_value", "etag_value", "ROCCurve")


@pytest.mark.parametrize(
"metric_value",
[1.3, "nan", "inf", "-inf", None],
)
def test_is_input_valid(under_test, metric_value):
assert under_test._is_input_valid("metric", "Name", metric_value)


@pytest.mark.parametrize(
"metric_value",
[nan, inf, -inf],
)
def test__is_input_valid_false(under_test, metric_value):
assert not under_test._is_input_valid("parameter", "Name", metric_value)

0 comments on commit d29b49a

Please sign in to comment.