From f8a3d7f5f43e05f5852bab2436dcabdb8f557a51 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 4 Jan 2023 11:49:05 +0100 Subject: [PATCH] [Tune] Fix best trial in ProgressReporter with nan (#31276) If a trial reports a metric value of `np.nan` or similar (which is not vanilla Python `None`), any comparison made between that value and any other value will return false, thus leading to the `nan` value being considered the best value for the purpose of determining best trial in `TuneReporterBase` if it was reported first. ``` Trial 1 -> np.nan Trial 2 -> 1 1 > np.nan == False -> np.nan is the best trial ``` To guard against that issue, we use `pd.isnull` check which is more general than the `is None` check used previously. Signed-off-by: Antoni Baum --- python/ray/tune/progress_reporter.py | 5 ++-- .../ray/tune/tests/test_progress_reporter.py | 28 +++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/python/ray/tune/progress_reporter.py b/python/ray/tune/progress_reporter.py index 6ebc9ebd108d9..f032953ee955f 100644 --- a/python/ray/tune/progress_reporter.py +++ b/python/ray/tune/progress_reporter.py @@ -12,9 +12,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -from ray._private.dict import flatten_dict +import pandas as pd import ray +from ray._private.dict import flatten_dict from ray.tune.callback import Callback from ray.tune.logger import pretty_print from ray.tune.result import ( @@ -432,7 +433,7 @@ def _current_best_trial(self, trials: List[Trial]): if not t.last_result: continue metric_value = unflattened_lookup(metric, t.last_result, default=None) - if metric_value is None: + if pd.isnull(metric_value): continue if not best_trial or metric_value * metric_op > best_metric: best_metric = metric_value * metric_op diff --git a/python/ray/tune/tests/test_progress_reporter.py b/python/ray/tune/tests/test_progress_reporter.py index 90ad04f3416e5..61ae854b56d3f 100644 --- a/python/ray/tune/tests/test_progress_reporter.py +++ b/python/ray/tune/tests/test_progress_reporter.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest +import numpy as np from ray import tune from ray._private.test_utils import run_string_as_driver @@ -513,6 +514,33 @@ def testBestTrialZero(self): best_trial, metric = reporter._current_best_trial([trial1, trial2, trial3]) assert best_trial == trial2 + def testBestTrialNan(self): + trial1 = Trial("", config={}, stub=True) + trial1.last_result = {"metric": np.nan, "config": {}} + + trial2 = Trial("", config={}, stub=True) + trial2.last_result = {"metric": 0, "config": {}} + + trial3 = Trial("", config={}, stub=True) + trial3.last_result = {"metric": 2, "config": {}} + + reporter = TuneReporterBase(metric="metric", mode="min") + best_trial, metric = reporter._current_best_trial([trial1, trial2, trial3]) + assert best_trial == trial2 + + trial1 = Trial("", config={}, stub=True) + trial1.last_result = {"metric": np.nan, "config": {}} + + trial2 = Trial("", config={}, stub=True) + trial2.last_result = {"metric": 0, "config": {}} + + trial3 = Trial("", config={}, stub=True) + trial3.last_result = {"metric": 2, "config": {}} + + reporter = TuneReporterBase(metric="metric", mode="max") + best_trial, metric = reporter._current_best_trial([trial1, trial2, trial3]) + assert best_trial == trial3 + def testTimeElapsed(self): # Sun Feb 7 14:18:40 2016 -0800 # (time of the first Ray commit)