Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fallback for calculation of num_data_records to avoid floating point errors #3

Merged
merged 1 commit into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
### Changed
- When `EdfSignal.physical_min` or `EdfSignal.physical_max` do not fit into their header fields, they are now always rounded down or up, respectively, to ensure all physical values lie within the physical range ([#2](https://github.com/the-siesta-group/edfio/pull/2)).

### Fixed
- The calculation of `num_data_records` from signal duration and `data_record_duration` is now more robust to floating point errors ([#3](https://github.com/the-siesta-group/edfio/pull/3))


## [0.1.0] - 2023-11-09

Expand Down
14 changes: 8 additions & 6 deletions edfio/edf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from decimal import Decimal
from fractions import Fraction
from functools import singledispatch
from pathlib import Path
Expand Down Expand Up @@ -1327,12 +1328,13 @@ def _calculate_num_data_records(
raise ValueError(
f"data_record_duration must be positive, got {data_record_duration}"
)
required_num_data_records = signal_duration / data_record_duration
if required_num_data_records != int(required_num_data_records):
raise ValueError(
f"Signal duration of {signal_duration}s is not exactly divisible by data_record_duration of {data_record_duration}s"
)
return int(required_num_data_records)
for f in (lambda x: x, lambda x: Decimal(str(x))):
marcoross marked this conversation as resolved.
Show resolved Hide resolved
required_num_data_records = f(signal_duration) / f(data_record_duration)
if required_num_data_records == int(required_num_data_records):
return int(required_num_data_records)
raise ValueError(
f"Signal duration of {signal_duration}s is not exactly divisible by data_record_duration of {data_record_duration}s"
)


def _calculate_data_record_duration(signals: Sequence[EdfSignal]) -> float:
Expand Down
25 changes: 24 additions & 1 deletion tests/test_edf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
read_edf,
)
from edfio._utils import FloatRange, IntRange
from edfio.edf import _create_annotations_signal
from edfio.edf import _calculate_num_data_records, _create_annotations_signal
from tests import TEST_DATA_DIR

EDF_FILE = TEST_DATA_DIR / "short_psg.edf"
Expand Down Expand Up @@ -957,3 +957,26 @@ def test_rounding_of_physical_range_does_not_produce_clipping_or_integer_overflo
np.testing.assert_allclose(sig.data, data, atol=1e-11)
# round((0.0000014999 - 0.0) / 0.000002 * 65535 + (-32768)) = 16380
assert sig._digital.tolist() == [-32768, 16380]


@pytest.mark.parametrize(
("signal_duration", "data_record_duration", "result"),
[
(11, 1, 11),
(1.1, 0.1, 11),
(1.01, 0.01, 101),
(1.001, 0.001, 1001),
(1.0001, 0.0001, 10001),
(1.00001, 0.00001, 100001),
(1.000001, 0.000001, 1000001),
(1.0000001, 0.0000001, 10000001),
(1, 1 / 3, 3),
(1, 1 / 333333, 333333),
],
)
def test_calculate_num_data_records(
signal_duration: float,
data_record_duration: float,
result: int,
):
assert _calculate_num_data_records(signal_duration, data_record_duration) == result