diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index 2760e0824..1f6f46829 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -19,7 +19,6 @@ use crate::serializers::type_serializers; use crate::serializers::type_serializers::format::serialize_via_str; use crate::serializers::SerializationState; use crate::tools::{extract_int, py_err, safe_repr}; -use crate::url::{PyMultiHostUrl, PyUrl}; use super::config::InfNanMode; use super::errors::SERIALIZATION_ERR_MARKER; @@ -168,7 +167,13 @@ pub(crate) fn infer_to_python_known<'py>( let either_delta = EitherTimedelta::try_from(value)?; state.config.temporal_mode.timedelta_to_json(value.py(), either_delta)? } - ObType::Url | ObType::MultiHostUrl | ObType::Path => serialize_via_str(value, serialize_to_python())?, + ObType::Url + | ObType::MultiHostUrl + | ObType::Path + | ObType::Ipv4Address + | ObType::Ipv6Address + | ObType::Ipv4Network + | ObType::Ipv6Network => serialize_via_str(value, serialize_to_python())?, ObType::Uuid => { let uuid = super::type_serializers::uuid::uuid_to_string(value)?; uuid.into_py_any(py)? @@ -413,9 +418,13 @@ pub(crate) fn infer_serialize_known<'py, S: Serializer>( let either_delta = EitherTimedelta::try_from(value).map_err(py_err_se_err)?; state.config.temporal_mode.timedelta_serialize(either_delta, serializer) } - ObType::Url | ObType::MultiHostUrl | ObType::Path => { - serialize_via_str(value, serialize_to_json(serializer)).map_err(unwrap_ser_error) - } + ObType::Url + | ObType::MultiHostUrl + | ObType::Path + | ObType::Ipv4Address + | ObType::Ipv6Address + | ObType::Ipv4Network + | ObType::Ipv6Network => serialize_via_str(value, serialize_to_json(serializer)).map_err(unwrap_ser_error), ObType::PydanticSerializable => { call_pydantic_serializer(value, state, serialize_to_json(serializer)).map_err(unwrap_ser_error) } @@ -546,13 +555,15 @@ pub(crate) fn infer_json_key_known<'a, 'py>( let either_delta = EitherTimedelta::try_from(key)?; state.config.temporal_mode.timedelta_json_key(&either_delta) } - ObType::Url => { - let py_url: PyUrl = key.extract()?; - Ok(Cow::Owned(py_url.__str__(key.py()).to_string())) - } - ObType::MultiHostUrl => { - let py_url: PyMultiHostUrl = key.extract()?; - Ok(Cow::Owned(py_url.__str__(key.py()))) + ObType::Url + | ObType::MultiHostUrl + | ObType::Path + | ObType::Ipv4Address + | ObType::Ipv6Address + | ObType::Ipv4Network + | ObType::Ipv6Network => { + // FIXME it would be nice to have a "PyCow" which carries ownership of the Python type too + Ok(Cow::Owned(key.str()?.to_string_lossy().into_owned())) } ObType::Tuple => { let mut key_build = super::type_serializers::tuple::KeyBuilder::new(); @@ -574,10 +585,6 @@ pub(crate) fn infer_json_key_known<'a, 'py>( let k = key.getattr(intern!(key.py(), "value"))?; infer_json_key(&k, state).map(|cow| Cow::Owned(cow.into_owned())) } - ObType::Path => { - // FIXME it would be nice to have a "PyCow" which carries ownership of the Python type too - Ok(Cow::Owned(key.str()?.to_string_lossy().into_owned())) - } ObType::Complex => { let v = key.downcast::()?; Ok(type_serializers::complex::complex_to_str(v).into()) diff --git a/src/serializers/ob_type.rs b/src/serializers/ob_type.rs index 8d291ab3d..97d683efa 100644 --- a/src/serializers/ob_type.rs +++ b/src/serializers/ob_type.rs @@ -50,6 +50,11 @@ pub struct ObTypeLookup { uuid_object: Py, // `complex` builtin complex: usize, + // ip address types + ipv4_address: Py, + ipv6_address: Py, + ipv4_network: Py, + ipv6_network: Py, } static TYPE_LOOKUP: PyOnceLock = PyOnceLock::new(); @@ -89,6 +94,10 @@ impl ObTypeLookup { pattern_object: py.import("re").unwrap().getattr("Pattern").unwrap().unbind(), uuid_object: py.import("uuid").unwrap().getattr("UUID").unwrap().unbind(), complex: PyComplex::type_object_raw(py) as usize, + ipv4_address: py.import("ipaddress").unwrap().getattr("IPv4Address").unwrap().unbind(), + ipv6_address: py.import("ipaddress").unwrap().getattr("IPv6Address").unwrap().unbind(), + ipv4_network: py.import("ipaddress").unwrap().getattr("IPv4Network").unwrap().unbind(), + ipv6_network: py.import("ipaddress").unwrap().getattr("IPv6Network").unwrap().unbind(), } } @@ -159,6 +168,10 @@ impl ObTypeLookup { ObType::Pattern => self.pattern_object.as_ptr() as usize == ob_type, ObType::Uuid => self.uuid_object.as_ptr() as usize == ob_type, ObType::Complex => self.complex == ob_type, + ObType::Ipv4Address => self.ipv4_address.as_ptr() as usize == ob_type, + ObType::Ipv6Address => self.ipv6_address.as_ptr() as usize == ob_type, + ObType::Ipv4Network => self.ipv4_network.as_ptr() as usize == ob_type, + ObType::Ipv6Network => self.ipv6_network.as_ptr() as usize == ob_type, ObType::Unknown => false, }; @@ -254,6 +267,10 @@ impl ObTypeLookup { ObType::Path } else if ob_type == self.pattern_object.as_ptr() as usize { ObType::Pattern + } else if ob_type == self.ipv4_address.as_ptr() as usize { + ObType::Ipv4Address + } else if ob_type == self.ipv6_address.as_ptr() as usize { + ObType::Ipv6Address } else { // this allows for subtypes of the supported class types, // if `ob_type` didn't match any member of self, we try again with the next base type pointer @@ -334,6 +351,16 @@ impl ObTypeLookup { ObType::Path } else if value.is_instance(self.pattern_object.bind(py)).unwrap_or(false) { ObType::Pattern + } else if value.is_instance_of::() { + ObType::Complex + } else if value.is_instance(self.ipv4_address.bind(py)).unwrap_or(false) { + ObType::Ipv4Address + } else if value.is_instance(self.ipv6_address.bind(py)).unwrap_or(false) { + ObType::Ipv6Address + } else if value.is_instance(self.ipv4_network.bind(py)).unwrap_or(false) { + ObType::Ipv4Network + } else if value.is_instance(self.ipv6_network.bind(py)).unwrap_or(false) { + ObType::Ipv6Network } else { ObType::Unknown } @@ -417,6 +444,11 @@ pub enum ObType { Uuid, // complex builtin Complex, + // ip address types + Ipv4Address, + Ipv6Address, + Ipv4Network, + Ipv6Network, // unknown type Unknown, } diff --git a/tests/serializers/test_any.py b/tests/serializers/test_any.py index f17628381..085b2b94b 100644 --- a/tests/serializers/test_any.py +++ b/tests/serializers/test_any.py @@ -1,4 +1,5 @@ import dataclasses +import ipaddress import json import platform import re @@ -722,3 +723,56 @@ class MyEnum(Enum): assert v.to_json({MyEnum.A: 'x'}) == b'{"1":"x"}' assert v.to_python(1) == 1 assert v.to_json(1) == b'1' + + +class SubIpV4(ipaddress.IPv4Address): + def __str__(self): + return super().__str__() + '_subclassed' + + +class SubIpV6(ipaddress.IPv6Address): + def __str__(self): + return super().__str__() + '_subclassed' + + +class SubNetV4(ipaddress.IPv4Network): + def __str__(self): + return super().__str__() + '_subclassed' + + +class SubNetV6(ipaddress.IPv6Network): + def __str__(self): + return super().__str__() + '_subclassed' + + +class SubInterfaceV4(ipaddress.IPv4Interface): + def __str__(self): + return super().__str__() + '_subclassed' + + +class SubInterfaceV6(ipaddress.IPv6Interface): + def __str__(self): + return super().__str__() + '_subclassed' + + +@pytest.mark.parametrize( + ('value', 'expected_json'), + [ + (ipaddress.IPv4Address('192.168.1.1'), '192.168.1.1'), + (ipaddress.IPv6Address('2001:0db8:85a3:0000:0000:8a2e:0370:7334'), '2001:db8:85a3::8a2e:370:7334'), + (SubIpV4('192.168.1.1'), '192.168.1.1_subclassed'), + (SubIpV6('2001:0db8:85a3:0000:0000:8a2e:0370:7334'), '2001:db8:85a3::8a2e:370:7334_subclassed'), + (ipaddress.IPv4Network('192.168.1.0/24'), '192.168.1.0/24'), + (ipaddress.IPv6Network('2001:0db8:85a3:0000:0000:8a2e:0370:7334'), '2001:db8:85a3::8a2e:370:7334/128'), + (SubNetV4('192.168.1.0/24'), '192.168.1.0/24_subclassed'), + (SubNetV6('2001:0db8:85a3:0000:0000:8a2e:0370:7334'), '2001:db8:85a3::8a2e:370:7334/128_subclassed'), + (ipaddress.IPv4Interface('192.168.1.1/24'), '192.168.1.1/24'), + (ipaddress.IPv6Interface('2001:0db8:85a3:0000:0000:8a2e:0370:7334'), '2001:db8:85a3::8a2e:370:7334/128'), + (SubInterfaceV4('192.168.1.1/24'), '192.168.1.1/24_subclassed'), + (SubInterfaceV6('2001:0db8:85a3:0000:0000:8a2e:0370:7334'), '2001:db8:85a3::8a2e:370:7334/128_subclassed'), + ], +) +def test_ipaddress_type_inference(any_serializer, value, expected_json): + assert any_serializer.to_python(value) == value + assert any_serializer.to_python(value, mode='json') == expected_json + assert any_serializer.to_json(value) == f'"{expected_json}"'.encode()