diff --git a/src/input/datetime.rs b/src/input/datetime.rs index 0a3cdf929..ebb5675f2 100644 --- a/src/input/datetime.rs +++ b/src/input/datetime.rs @@ -563,8 +563,19 @@ impl TzInfo { hasher.finish() } - fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool { - op.matches(self.seconds.cmp(&other.seconds)) + fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> PyResult> { + let py = other.py(); + if other.is_instance_of::() { + let offset_delta = other.call_method1(intern!(py, "utcoffset"), (py.None(),))?; + if offset_delta.is_none() { + return Ok(py.NotImplemented()); + } + let offset_seconds: f64 = offset_delta.call_method0(intern!(py, "total_seconds"))?.extract()?; + let offset = offset_seconds.round() as i32; + Ok(op.matches(self.seconds.cmp(&offset)).into_py(py)) + } else { + Ok(py.NotImplemented()) + } } fn __deepcopy__(&self, py: Python, _memo: &PyDict) -> PyResult> { diff --git a/tests/test_tzinfo.py b/tests/test_tzinfo.py index cb67b737e..949c9175d 100644 --- a/tests/test_tzinfo.py +++ b/tests/test_tzinfo.py @@ -1,11 +1,15 @@ import copy import functools import pickle +import sys import unittest from datetime import datetime, timedelta, timezone, tzinfo from pydantic_core import SchemaValidator, TzInfo, core_schema +if sys.version_info >= (3, 9): + from zoneinfo import ZoneInfo + class _ALWAYS_EQ: """ @@ -80,6 +84,7 @@ class TestTzInfo(unittest.TestCase): def setUp(self): self.ACDT = TzInfo(timedelta(hours=9.5).total_seconds()) self.EST = TzInfo(-timedelta(hours=5).total_seconds()) + self.UTC = TzInfo(timedelta(0).total_seconds()) self.DT = datetime(2010, 1, 1) def test_str(self): @@ -163,6 +168,17 @@ def test_comparison(self): self.assertFalse(tz <= SMALLEST) self.assertTrue(tz >= SMALLEST) + # offset based comparion tests for tzinfo derived classes like datetime.timezone. + utcdatetime = self.DT.replace(tzinfo=timezone.utc) + self.assertTrue(tz == utcdatetime.tzinfo) + estdatetime = self.DT.replace(tzinfo=timezone(-timedelta(hours=5))) + self.assertTrue(self.EST == estdatetime.tzinfo) + self.assertTrue(tz > estdatetime.tzinfo) + if sys.version_info >= (3, 9) and sys.platform == 'linux': + self.assertFalse(tz == ZoneInfo('Europe/London')) + with self.assertRaises(TypeError): + tz > ZoneInfo('Europe/London') + def test_copy(self): for tz in self.ACDT, self.EST: tz_copy = copy.copy(tz)