Skip to content

Commit

Permalink
Merge pull request #41 from the-siesta-group/read-edf-with-invalid-fi…
Browse files Browse the repository at this point in the history
…lesize

Read edf files with where filesize does not match the header information
  • Loading branch information
marcoross authored Mar 29, 2024
2 parents 3537cb1 + aaca193 commit 9a0a1fa
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Added
- Make EDF+ header fields `patient` and `recording` more tolerant regarding non-compliant inputs: omitted subfields are returned as `X` instead of throwing an exception ([#18](https://github.com/the-siesta-group/edfio/pull/18)).
- Allow reading from `tempfile.SpooledTemporaryFile[bytes]` ([#36](https://github.com/the-siesta-group/edfio/pull/36)).
- Enable the reading of EDF files where the filesize does not match the header information. Incomplete data records will be truncated ([#41](https://github.com/the-siesta-group/edfio/pull/41)).

### Fixed
- Allow reading more non-standard starttime and startdate fields ([#30](https://github.com/the-siesta-group/edfio/pull/30)).
Expand Down
25 changes: 23 additions & 2 deletions edfio/edf.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,16 +166,37 @@ def __repr__(self) -> str:
def _load_data(self, file: Path | io.BufferedReader | io.BytesIO) -> None:
lens = [signal.samples_per_data_record for signal in self._signals]
datarecord_len = sum(lens)
truncated = False
if not isinstance(file, Path):
datarecords = np.frombuffer(file.read(), dtype=np.int16)
data_bytes = file.read()
actual_records = len(data_bytes) // (datarecord_len * 2)
if actual_records * datarecord_len * 2 < len(data_bytes):
truncated = True
datarecords = np.frombuffer(
data_bytes, dtype=np.int16, count=actual_records * datarecord_len
)
datarecords.shape = (actual_records, datarecord_len)
else:
remaining_bytes = file.stat().st_size - self.bytes_in_header_record
actual_records = remaining_bytes // (datarecord_len * 2)
if actual_records * datarecord_len * 2 < remaining_bytes:
truncated = True
datarecords = np.memmap(
file,
dtype=np.int16,
mode="r",
offset=self.bytes_in_header_record,
shape=(actual_records, datarecord_len),
)
if truncated:
warnings.warn(
"Incomplete data record at the end of the EDF file. Data was truncated."
)
if self.num_data_records not in (-1, actual_records):
warnings.warn(
f"EDF header indicates {self.num_data_records} data records, but file contains {actual_records} records. Updating header."
)
datarecords.shape = (self.num_data_records, datarecord_len)
self._num_data_records = Edf.num_data_records.encode(actual_records)
ends = np.cumsum(lens)
starts = ends - lens

Expand Down
37 changes: 37 additions & 0 deletions tests/test_edf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import tempfile
from pathlib import Path
from shutil import copyfile
from typing import Literal

import numpy as np
Expand Down Expand Up @@ -1160,3 +1161,39 @@ def test_sampling_frequencies_leading_to_floating_point_issues_in_signal_duratio
assert edf.num_data_records == 10
assert edf.signals[0].samples_per_data_record == 22
assert edf.signals[1].samples_per_data_record == 9


# fmt: off
@pytest.mark.parametrize(
("extra_bytes", "num_records_in_header", "expected_warning"),
[
# extra bytes num records field expected warning
(1, 10, "Incomplete data record at the end of the EDF file"),
(15, 11, "Incomplete data record at the end of the EDF file"),
(0, 9, "EDF header indicates 9 data records, but file contains 10 records. Updating header."),
(0, 11, "EDF header indicates 11 data records, but file contains 10 records. Updating header."),
],
)
# fmt: on
def test_read_edf_with_invalid_number_of_records(
tmp_path: Path,
extra_bytes: int,
num_records_in_header: int,
expected_warning: str,
):
invalid_edf = tmp_path / "invalid.edf"
copyfile(EDF_FILE, invalid_edf)
with invalid_edf.open("rb+") as file:
file.seek(236)
file.write(f"{num_records_in_header: <8}".encode("ascii"))
if extra_bytes > 0:
file.seek(0, 2)
file.write(b"\0" * extra_bytes)

for io_obj in (invalid_edf, invalid_edf.read_bytes()):
with pytest.warns(UserWarning, match=expected_warning):
edf = read_edf(io_obj)
assert edf.num_data_records == 10
comparison_edf = read_edf(EDF_FILE)
for signal, comparison_signal in zip(edf.signals, comparison_edf.signals):
np.testing.assert_array_equal(signal.data, comparison_signal.data)

0 comments on commit 9a0a1fa

Please sign in to comment.