Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TzInfo equality check based on offset #1197

Merged
merged 9 commits into from
Feb 20, 2024
15 changes: 13 additions & 2 deletions src/input/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -563,8 +563,19 @@ impl TzInfo {
hasher.finish()
}

fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool {
op.matches(self.seconds.cmp(&other.seconds))
fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> PyResult<Py<PyAny>> {
let py = other.py();
if other.is_instance_of::<PyTzInfo>() {
let offset_delta = other.call_method1(intern!(py, "utcoffset"), (py.None(),))?;
if offset_delta.is_none() {
return Ok(py.NotImplemented());
}
let offset_seconds: f64 = offset_delta.call_method0(intern!(py, "total_seconds"))?.extract()?;
davidhewitt marked this conversation as resolved.
Show resolved Hide resolved
let offset = offset_seconds.round() as i32;
Ok(op.matches(self.seconds.cmp(&offset)).into_py(py))
} else {
Ok(py.NotImplemented())
}
}

fn __deepcopy__(&self, py: Python, _memo: &PyDict) -> PyResult<Py<Self>> {
Expand Down
15 changes: 15 additions & 0 deletions tests/test_tzinfo.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import copy
import functools
import pickle
import sys
import unittest
from datetime import datetime, timedelta, timezone, tzinfo

from pydantic_core import SchemaValidator, TzInfo, core_schema

if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo


class _ALWAYS_EQ:
"""
Expand Down Expand Up @@ -80,6 +84,7 @@ class TestTzInfo(unittest.TestCase):
def setUp(self):
self.ACDT = TzInfo(timedelta(hours=9.5).total_seconds())
self.EST = TzInfo(-timedelta(hours=5).total_seconds())
self.UTC = TzInfo(timedelta(0).total_seconds())
self.DT = datetime(2010, 1, 1)

def test_str(self):
Expand Down Expand Up @@ -163,6 +168,16 @@ def test_comparison(self):
self.assertFalse(tz <= SMALLEST)
self.assertTrue(tz >= SMALLEST)

# offset based comparion tests for tzinfo derived classes like datetime.timezone.
utcdatetime = self.DT.replace(tzinfo=timezone.utc)
self.assertTrue(tz == utcdatetime.tzinfo)
estdatetime = self.DT.replace(tzinfo=timezone(-timedelta(hours=5)))
self.assertTrue(self.EST == estdatetime.tzinfo)
self.assertTrue(tz > estdatetime.tzinfo)
if sys.version_info >= (3, 9) and sys.platform == 'linux':
with self.assertRaises(TypeError):
tz > ZoneInfo('Europe/London')
davidhewitt marked this conversation as resolved.
Show resolved Hide resolved

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should add tests to confirm that comparisons with zoneinfo.ZoneInfo("Europe/London") or similar don't succeed.

Copy link
Contributor Author

@13sin 13sin Feb 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test added for zoneinfo.ZoneInfo would check whether it return false. As, NotImplemented would not throw error. there would be identity check eventually as __eq__ is not implemented for right-side and left-side. Would this be right way of doing this? Other option would be to throw NotImplementedError.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think by returning NotImplemented then equality will return False and ordering operators will raise a TypeError. I see you added an equality test, maybe add a test that ordering throws TypeError and then we're good here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. Thank you so much @davidhewitt .

def test_copy(self):
for tz in self.ACDT, self.EST:
tz_copy = copy.copy(tz)
Expand Down