From 3d2902f40e9faaf2e9292f0998502fa93e14d6f7 Mon Sep 17 00:00:00 2001 From: hofaflo Date: Tue, 14 Nov 2023 14:33:41 +0100 Subject: [PATCH] add fallback for calculation of num_data_records to avoid floating point errors --- CHANGELOG.md | 3 +++ edfio/edf.py | 14 ++++++++------ tests/test_edf.py | 25 ++++++++++++++++++++++++- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 51516fc..518384e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,9 @@ ### Changed - When `EdfSignal.physical_min` or `EdfSignal.physical_max` do not fit into their header fields, they are now always rounded down or up, respectively, to ensure all physical values lie within the physical range ([#2](https://github.com/the-siesta-group/edfio/pull/2)). +### Fixed +- The calculation of `num_data_records` from signal duration and `data_record_duration` is now more robust to floating point errors ([#3](https://github.com/the-siesta-group/edfio/pull/3)) + ## [0.1.0] - 2023-11-09 diff --git a/edfio/edf.py b/edfio/edf.py index 787a1f7..e8f89b4 100644 --- a/edfio/edf.py +++ b/edfio/edf.py @@ -9,6 +9,7 @@ import warnings from collections.abc import Iterable, Sequence from dataclasses import dataclass +from decimal import Decimal from fractions import Fraction from functools import singledispatch from pathlib import Path @@ -1327,12 +1328,13 @@ def _calculate_num_data_records( raise ValueError( f"data_record_duration must be positive, got {data_record_duration}" ) - required_num_data_records = signal_duration / data_record_duration - if required_num_data_records != int(required_num_data_records): - raise ValueError( - f"Signal duration of {signal_duration}s is not exactly divisible by data_record_duration of {data_record_duration}s" - ) - return int(required_num_data_records) + for f in (lambda x: x, lambda x: Decimal(str(x))): + required_num_data_records = f(signal_duration) / f(data_record_duration) + if required_num_data_records == int(required_num_data_records): + return int(required_num_data_records) + raise ValueError( + f"Signal duration of {signal_duration}s is not exactly divisible by data_record_duration of {data_record_duration}s" + ) def _calculate_data_record_duration(signals: Sequence[EdfSignal]) -> float: diff --git a/tests/test_edf.py b/tests/test_edf.py index 668a933..49e6cc3 100644 --- a/tests/test_edf.py +++ b/tests/test_edf.py @@ -19,7 +19,7 @@ read_edf, ) from edfio._utils import FloatRange, IntRange -from edfio.edf import _create_annotations_signal +from edfio.edf import _calculate_num_data_records, _create_annotations_signal from tests import TEST_DATA_DIR EDF_FILE = TEST_DATA_DIR / "short_psg.edf" @@ -957,3 +957,26 @@ def test_rounding_of_physical_range_does_not_produce_clipping_or_integer_overflo np.testing.assert_allclose(sig.data, data, atol=1e-11) # round((0.0000014999 - 0.0) / 0.000002 * 65535 + (-32768)) = 16380 assert sig._digital.tolist() == [-32768, 16380] + + +@pytest.mark.parametrize( + ("signal_duration", "data_record_duration", "result"), + [ + (11, 1, 11), + (1.1, 0.1, 11), + (1.01, 0.01, 101), + (1.001, 0.001, 1001), + (1.0001, 0.0001, 10001), + (1.00001, 0.00001, 100001), + (1.000001, 0.000001, 1000001), + (1.0000001, 0.0000001, 10000001), + (1, 1 / 3, 3), + (1, 1 / 333333, 333333), + ], +) +def test_calculate_num_data_records( + signal_duration: float, + data_record_duration: float, + result: int, +): + assert _calculate_num_data_records(signal_duration, data_record_duration) == result