Skip to content

Commit

Permalink
Fix chrono::Duration tests and extend them for DurationSecondsWithFrac
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasbb committed May 15, 2020
1 parent 6f3fc6c commit 1fb4552
Show file tree
Hide file tree
Showing 2 changed files with 455 additions and 45 deletions.
242 changes: 218 additions & 24 deletions src/chrono.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
//! [chrono]: https://docs.rs/chrono/

use crate::{
de::DeserializeAs, ser::SerializeAs, utils, DurationSeconds, DurationSecondsWithFrac, Format,
Integer, Strictness,
de::DeserializeAs, ser::SerializeAs, utils, DurationSeconds, DurationSecondsWithFrac, Flexible,
Format, Integer, Strict, Strictness,
};
use chrono_crate::{DateTime, Duration, NaiveDateTime, Utc};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde::{
de::{Error, Unexpected, Visitor},
Deserialize, Deserializer, Serialize, Serializer,
};
use std::fmt;
use utils::NANOS_PER_SEC;

/// Deserialize a Unix timestamp with optional subsecond precision into a `DateTime<Utc>`.
Expand Down Expand Up @@ -180,14 +184,38 @@ impl<'de> DeserializeAs<'de, NaiveDateTime> for DateTime<Utc> {
}
}

fn duration_subsec_nanos(dur: &Duration) -> u32 {
(*dur - Duration::seconds(dur.num_seconds()))
fn duration_subsec_nanos(mut dur: Duration) -> u32 {
if dur < Duration::zero() {
dur = Duration::zero() - dur;
}
(dur - Duration::seconds(dur.num_seconds()))
.num_nanoseconds()
.unwrap() as u32
}

fn duration_as_secs_f64(dur: &Duration) -> f64 {
(dur.num_seconds() as f64) + (duration_subsec_nanos(dur) as f64) / (NANOS_PER_SEC as f64)
fn duration_as_secs_f64(dur: Duration) -> f64 {
let mut secs = dur.num_seconds();
let subsecs = duration_subsec_nanos(dur);

// Properly round the value
if dur < Duration::zero() && subsecs > 0 {
secs -= 1;
}

(secs as f64) + (subsecs as f64) / (NANOS_PER_SEC as f64)
}

#[test]
fn test_duration_as_secs_f64() {
assert_eq!(duration_as_secs_f64(Duration::seconds(1)), 1.);
assert_eq!(
duration_as_secs_f64(Duration::nanoseconds(500_000_000)),
0.5
);
assert_eq!(
duration_as_secs_f64(Duration::nanoseconds(-500_000_000)),
-0.5
);
}

impl<STRICTNESS> SerializeAs<Duration> for DurationSeconds<Integer, STRICTNESS>
Expand All @@ -199,9 +227,15 @@ where
S: Serializer,
{
let mut secs = source.num_seconds();
let subsecs = duration_subsec_nanos(*source);

// Properly round the value
if duration_subsec_nanos(source) >= 500_000_000 {
secs += 1;
if subsecs >= 500_000_000 {
if *source < Duration::zero() {
secs -= 1;
} else {
secs += 1;
}
}
secs.serialize(serializer)
}
Expand All @@ -215,7 +249,7 @@ where
where
S: Serializer,
{
duration_as_secs_f64(source).round().serialize(serializer)
duration_as_secs_f64(*source).round().serialize(serializer)
}
}

Expand All @@ -228,9 +262,15 @@ where
S: Serializer,
{
let mut secs = source.num_seconds();
let subsecs = duration_subsec_nanos(*source);

// Properly round the value
if duration_subsec_nanos(source) >= 500_000_000 {
secs += 1;
if subsecs >= 500_000_000 {
if *source < Duration::zero() {
secs -= 1;
} else {
secs += 1;
}
}
secs.to_string().serialize(serializer)
}
Expand All @@ -244,7 +284,7 @@ where
where
S: Serializer,
{
duration_as_secs_f64(source).serialize(serializer)
duration_as_secs_f64(*source).serialize(serializer)
}
}

Expand All @@ -256,38 +296,192 @@ where
where
S: Serializer,
{
duration_as_secs_f64(source)
duration_as_secs_f64(*source)
.to_string()
.serialize(serializer)
}
}
fn duration_from_secs_f64(secs: f64) -> Result<Duration, String> {
const MAX_NANOS_F64: f64 = ((i64::max_value() as i128 + 1) * (NANOS_PER_SEC as i128)) as f64;
// TODO why are the seconds converted to nanoseconds first?
// Does it make sense to just truncate the value?
let nanos = secs * (NANOS_PER_SEC as f64);
if !nanos.is_finite() {
return Err("got non-finite value when converting float to duration".into());
}
if nanos >= MAX_NANOS_F64 || nanos < -MAX_NANOS_F64 {
return Err("overflow when converting float to duration".into());
}
let nanos = nanos as i128;
let secs = Duration::seconds((nanos / (NANOS_PER_SEC as i128)) as i64);
let subsec = Duration::nanoseconds((nanos % (NANOS_PER_SEC as i128)) as i64);
Ok(secs + subsec)
}

struct DurationVisitiorFlexible;
impl<'de> Visitor<'de> for DurationVisitiorFlexible {
type Value = Duration;

fn expecting(&self, formatter: &mut fmt::Formatter) -> ::std::fmt::Result {
formatter.write_str("an integer, a float, or a string containing a number")
}

fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
where
E: Error,
{
Ok(Duration::seconds(value))
}

fn visit_u64<E>(self, secs: u64) -> Result<Self::Value, E>
where
E: Error,
{
if secs <= i64::max_value() as u64 {
Ok(Duration::seconds(secs as i64))
} else {
Err(Error::custom(format!(
"Seconds larger than {} are not supported for chrono::Duration. Found {}",
i64::max_value(),
secs,
)))
}
}

fn visit_f64<E>(self, secs: f64) -> Result<Self::Value, E>
where
E: Error,
{
duration_from_secs_f64(secs).map_err(Error::custom)
}

fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: Error,
{
let parts: Vec<_> = value.split('.').collect();

match *parts.as_slice() {
[seconds] => {
if let Ok(seconds) = i64::from_str_radix(seconds, 10) {
Ok(Duration::seconds(seconds))
} else {
Err(Error::invalid_value(Unexpected::Str(value), &self))
}
}
[seconds, subseconds] => {
if let Ok(mut seconds) = i64::from_str_radix(seconds, 10) {
let subseclen = subseconds.chars().count() as u32;
if subseclen > 9 {
return Err(Error::custom(format!(
"Duration only support nanosecond precision but '{}' has more than 9 digits.",
value
)));
}

if let Ok(mut subseconds) = u32::from_str_radix(subseconds, 10) {
// convert subseconds to nanoseconds (10^-9), require 9 places for nanoseconds
subseconds *= 10u32.pow(9 - subseclen);

// Check if first char of seconds part is negative sign
if parts[0].starts_with('-') {
seconds -= 1;
}

Ok(Duration::seconds(seconds)
+ Duration::nanoseconds(i64::from(subseconds)))
} else {
Err(Error::invalid_value(Unexpected::Str(value), &self))
}
} else {
Err(Error::invalid_value(Unexpected::Str(value), &self))
}
}

_ => Err(Error::invalid_value(Unexpected::Str(value), &self)),
}
}
}

impl<'de> DeserializeAs<'de, Duration> for DurationSeconds<Integer, Strict> {
fn deserialize_as<D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
i64::deserialize(deserializer).map(Duration::seconds)
}
}

impl<'de, FORMAT, S> DeserializeAs<'de, Duration> for DurationSeconds<FORMAT, S>
impl<'de> DeserializeAs<'de, Duration> for DurationSeconds<f64, Strict> {
fn deserialize_as<D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
let val = f64::deserialize(deserializer)?;
duration_from_secs_f64(val).map_err(Error::custom)
}
}

impl<'de> DeserializeAs<'de, Duration> for DurationSeconds<String, Strict> {
fn deserialize_as<D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
crate::rust::display_fromstr::deserialize(deserializer).map(Duration::seconds)
}
}

impl<'de, FORMAT> DeserializeAs<'de, Duration> for DurationSeconds<FORMAT, Flexible>
where
FORMAT: Format,
S: Strictness,
{
fn deserialize_as<D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
// deserializer.deserialize_any(DurationVisitiorFlexible)
i32::deserialize(deserializer);
Ok(Duration::zero())
deserializer.deserialize_any(DurationVisitiorFlexible)
}
}

impl<'de> DeserializeAs<'de, Duration> for DurationSecondsWithFrac<Integer, Strict> {
fn deserialize_as<D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
i64::deserialize(deserializer).map(Duration::seconds)
}
}

impl<'de> DeserializeAs<'de, Duration> for DurationSecondsWithFrac<f64, Strict> {
fn deserialize_as<D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
let val = f64::deserialize(deserializer)?;
duration_from_secs_f64(val).map_err(Error::custom)
}
}

impl<'de> DeserializeAs<'de, Duration> for DurationSecondsWithFrac<String, Strict> {
fn deserialize_as<D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
let dur = String::deserialize(deserializer)?;
DurationVisitiorFlexible.visit_str(&*dur)
// crate::rust::display_fromstr::deserialize(deserializer)
// .and_then(|val| duration_from_secs_f64(val).map_err(Error::custom))
}
}

impl<'de, FORMAT, S> DeserializeAs<'de, Duration> for DurationSecondsWithFrac<FORMAT, S>
impl<'de, FORMAT> DeserializeAs<'de, Duration> for DurationSecondsWithFrac<FORMAT, Flexible>
where
FORMAT: Format,
S: Strictness,
{
fn deserialize_as<D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
// deserializer.deserialize_any(DurationVisitiorFlexible)
i32::deserialize(deserializer);
Ok(Duration::zero())
deserializer.deserialize_any(DurationVisitiorFlexible)
}
}
Loading

0 comments on commit 1fb4552

Please sign in to comment.