Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor header fields #42

Merged
merged 7 commits into from
Apr 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 26 additions & 110 deletions edfio/_header_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@
import datetime
import math
import re
from abc import ABC, abstractmethod
from collections.abc import Iterator
from typing import Any, Generic, TypeVar, overload

T = TypeVar("T", str, int, float, datetime.date, datetime.time)

_one_or_two_digits = "([ ]?\\d{1,2})"
_separator = "[.:'\\-\\/ ]"
Expand Down Expand Up @@ -41,10 +36,10 @@ def encode_int(value: int, length: int) -> bytes:
return encode_str(str(value), length)


def encode_float(value: float, length: int) -> bytes:
def encode_float(value: float) -> bytes:
if float(value).is_integer():
value = int(value)
return encode_str(str(value), length)
return encode_str(str(value), 8)


def decode_float(field: bytes) -> float:
Expand All @@ -54,112 +49,33 @@ def decode_float(field: bytes) -> float:
return value


class RawHeaderField(ABC, Generic[T]):
def __set_name__(self, owner: Any, name: str) -> None:
self.name = name
self.private_name = "_" + name

def __init__(self, length: int, *, is_settable: bool) -> None:
self.length = length
self.is_settable = is_settable

@overload
def __get__(self, instance: None, owner: Any) -> RawHeaderField[T]: ...

@overload
def __get__(self, instance: Any, owner: Any) -> T: ...

def __get__(self, instance: Any, owner: Any = None) -> RawHeaderField[T] | T:
if instance is None:
return self
return self.decode(getattr(instance, self.private_name))

def __set__(self, instance: Any, value: T) -> None:
if not self.is_settable:
raise AttributeError(f"can't set attribute {self.name}")
setattr(instance, self.private_name, self.encode(value))

@abstractmethod
def decode(self, field: bytes) -> T:
raise NotImplementedError

@abstractmethod
def encode(self, value: T) -> bytes:
raise NotImplementedError


class RawHeaderFieldStr(RawHeaderField[str]):
def __init__(self, length: int, *, is_settable: bool = False) -> None:
super().__init__(length, is_settable=is_settable)

def decode(self, field: bytes) -> str:
return decode_str(field)

def encode(self, value: str) -> bytes:
return encode_str(value, self.length)


class RawHeaderFieldInt(RawHeaderField[int]):
def __init__(self, length: int, *, is_settable: bool = False) -> None:
super().__init__(length, is_settable=is_settable)

def decode(self, field: bytes) -> int:
return int(decode_str(field))

def encode(self, value: int) -> bytes:
return encode_int(value, self.length)


class RawHeaderFieldFloat(RawHeaderField[float]):
def __init__(self, length: int, *, is_settable: bool = False) -> None:
super().__init__(length, is_settable=is_settable)

def decode(self, field: bytes) -> float:
return decode_float(field)

def encode(self, value: float) -> bytes:
return encode_float(value, self.length)


class RawHeaderFieldDate(RawHeaderField[datetime.date]):
def __init__(self, length: int, *, is_settable: bool = False) -> None:
super().__init__(length, is_settable=is_settable)

def decode(self, field: bytes) -> datetime.date:
date = decode_str(field)
match = DATE_OR_TIME_PATTERN.fullmatch(date)
if match is None:
raise ValueError(f"Invalid date for format DD.MM.YY: {date!r}")
day, month, year = (int(g) for g in match.groups())
if year >= 85: # noqa: PLR2004
year += 1900
else:
year += 2000
return datetime.date(year, month, day)

def encode(self, value: datetime.date) -> bytes:
if not 1985 <= value.year <= 2084: # noqa: PLR2004
raise ValueError("EDF only allows dates from 1985 to 2084")
return encode_str(value.strftime("%d.%m.%y"), self.length)
def decode_date(field: bytes) -> datetime.date:
date = decode_str(field)
match = DATE_OR_TIME_PATTERN.fullmatch(date)
if match is None:
raise ValueError(f"Invalid date for format DD.MM.YY: {date!r}")
day, month, year = (int(g) for g in match.groups())
if year >= 85: # noqa: PLR2004
year += 1900
else:
year += 2000
return datetime.date(year, month, day)


class RawHeaderFieldTime(RawHeaderField[datetime.time]):
def __init__(self, length: int, *, is_settable: bool = False) -> None:
super().__init__(length, is_settable=is_settable)
def encode_date(value: datetime.date) -> bytes:
if not 1985 <= value.year <= 2084: # noqa: PLR2004
raise ValueError("EDF only allows dates from 1985 to 2084")
return encode_str(value.strftime("%d.%m.%y"), 8)

def decode(self, field: bytes) -> datetime.time:
time = decode_str(field)
match = DATE_OR_TIME_PATTERN.fullmatch(time)
if match is None:
raise ValueError(f"Invalid time for format hh.mm.ss: {time!r}")
hours, minutes, seconds = (int(g) for g in match.groups())
return datetime.time(hours, minutes, seconds)

def encode(self, value: datetime.time) -> bytes:
return encode_str(value.isoformat().replace(":", "."), self.length)
def decode_time(field: bytes) -> datetime.time:
time = decode_str(field)
match = DATE_OR_TIME_PATTERN.fullmatch(time)
if match is None:
raise ValueError(f"Invalid time for format hh.mm.ss: {time!r}")
hours, minutes, seconds = (int(g) for g in match.groups())
return datetime.time(hours, minutes, seconds)


def get_header_fields(cls: type) -> Iterator[tuple[str, int]]:
for name, value in cls.__dict__.items():
if isinstance(value, RawHeaderField):
yield name, value.length
def encode_time(value: datetime.time) -> bytes:
return encode_str(value.isoformat().replace(":", "."), 8)
Loading