From 1ca635788446f4aae6ab930798b7558785bc49f1 Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Thu, 10 Oct 2024 13:23:14 -0400 Subject: [PATCH] disallow non-finite Box coords --- rastervision_core/rastervision/core/box.py | 4 ++++ tests/core/test_box.py | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/rastervision_core/rastervision/core/box.py b/rastervision_core/rastervision/core/box.py index 164459fa7..3521a5716 100644 --- a/rastervision_core/rastervision/core/box.py +++ b/rastervision_core/rastervision/core/box.py @@ -40,6 +40,10 @@ def __init__(self, ymin: int, xmin: int, ymax: int, xmax: int): ymax: maximum y value xmax: maximum x value """ + if not all(math.isfinite(v) for v in (ymin, xmin, ymax, xmax)): + raise ValueError( + f'Invalid Box coordinates: {(ymin, xmin, ymax, xmax)}.') + self.ymin = ymin self.xmin = xmin self.ymax = ymax diff --git a/tests/core/test_box.py b/tests/core/test_box.py index b3b55d023..e13c24207 100644 --- a/tests/core/test_box.py +++ b/tests/core/test_box.py @@ -443,6 +443,11 @@ def test_in(self): with self.assertRaises(NotImplementedError): _ = '' in Box(0, 0, 1, 1) + def test_error_on_nonfinite_inputs(self): + self.assertRaises(ValueError, lambda: Box(np.inf, 0, 0, 0)) + self.assertRaises(ValueError, lambda: Box(-np.inf, 0, 0, 0)) + self.assertRaises(ValueError, lambda: Box(np.nan, 0, 0, 0)) + if __name__ == '__main__': unittest.main()