Skip to content

Commit

Permalink
add fallback for calculation of num_data_records to avoid floating po…
Browse files Browse the repository at this point in the history
…int errors (#3)
  • Loading branch information
hofaflo authored Nov 16, 2023
1 parent 80fedaa commit 46a8e7b
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
- When `EdfSignal.physical_min` or `EdfSignal.physical_max` do not fit into their header fields, they are now always rounded down or up, respectively, to ensure all physical values lie within the physical range ([#2](https://github.com/the-siesta-group/edfio/pull/2)).
- Support non-standard header fields (not encoded as UTF-8) by replacing incompatible characters with "�" ([#4](https://github.com/the-siesta-group/edfio/pull/4)).

### Fixed
- The calculation of `num_data_records` from signal duration and `data_record_duration` is now more robust to floating-point errors ([#3](https://github.com/the-siesta-group/edfio/pull/3)).


## [0.1.0] - 2023-11-09

Expand Down
14 changes: 8 additions & 6 deletions edfio/edf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from decimal import Decimal
from fractions import Fraction
from functools import singledispatch
from pathlib import Path
Expand Down Expand Up @@ -1327,12 +1328,13 @@ def _calculate_num_data_records(
raise ValueError(
f"data_record_duration must be positive, got {data_record_duration}"
)
required_num_data_records = signal_duration / data_record_duration
if required_num_data_records != int(required_num_data_records):
raise ValueError(
f"Signal duration of {signal_duration}s is not exactly divisible by data_record_duration of {data_record_duration}s"
)
return int(required_num_data_records)
for f in (lambda x: x, lambda x: Decimal(str(x))):
required_num_data_records = f(signal_duration) / f(data_record_duration)
if required_num_data_records == int(required_num_data_records):
return int(required_num_data_records)
raise ValueError(
f"Signal duration of {signal_duration}s is not exactly divisible by data_record_duration of {data_record_duration}s"
)


def _calculate_data_record_duration(signals: Sequence[EdfSignal]) -> float:
Expand Down
25 changes: 24 additions & 1 deletion tests/test_edf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
read_edf,
)
from edfio._utils import FloatRange, IntRange
from edfio.edf import _create_annotations_signal
from edfio.edf import _calculate_num_data_records, _create_annotations_signal
from tests import TEST_DATA_DIR

EDF_FILE = TEST_DATA_DIR / "short_psg.edf"
Expand Down Expand Up @@ -957,3 +957,26 @@ def test_rounding_of_physical_range_does_not_produce_clipping_or_integer_overflo
np.testing.assert_allclose(sig.data, data, atol=1e-11)
# round((0.0000014999 - 0.0) / 0.000002 * 65535 + (-32768)) = 16380
assert sig._digital.tolist() == [-32768, 16380]


@pytest.mark.parametrize(
    ("signal_duration", "data_record_duration", "result"),
    [
        # Integer durations.
        (11, 1, 11),
        # Decimal durations with progressively smaller record durations; the
        # plain float quotient of such pairs may not be an exact integer.
        (1.1, 0.1, 11),
        (1.01, 0.01, 101),
        (1.001, 0.001, 1001),
        (1.0001, 0.0001, 10001),
        (1.00001, 0.00001, 100001),
        (1.000001, 0.000001, 1000001),
        (1.0000001, 0.0000001, 10000001),
        # Record durations given as reciprocals of integers.
        (1, 1 / 3, 3),
        (1, 1 / 333333, 333333),
    ],
)
def test_calculate_num_data_records(
    signal_duration: float,
    data_record_duration: float,
    result: int,
) -> None:
    """Check the record count computed for each parametrized duration pair.

    Calls ``_calculate_num_data_records`` with ``signal_duration`` and
    ``data_record_duration`` and asserts that the returned value equals
    ``result``.
    """
    num_data_records = _calculate_num_data_records(
        signal_duration,
        data_record_duration,
    )
    assert num_data_records == result

0 comments on commit 46a8e7b

Please sign in to comment.