From 8ed576f59d8b810b43c3284b6001a092ef9312d2 Mon Sep 17 00:00:00 2001 From: Tomi Valkeinen Date: Mon, 30 Sep 2024 18:30:01 +0300 Subject: [PATCH] py: Fix contextmanager issues This might need another look... --- py/rwmem/mappedregisterfile.py | 2 ++ py/rwmem/registerfile.py | 14 ++++++++++---- py/tests/test_mmap_regs.py | 12 ++++++------ py/tests/test_registerfile.py | 8 ++++---- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/py/rwmem/mappedregisterfile.py b/py/rwmem/mappedregisterfile.py index cad73fa..70de57c 100644 --- a/py/rwmem/mappedregisterfile.py +++ b/py/rwmem/mappedregisterfile.py @@ -151,6 +151,8 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_tb): self._map.close() + del self._regblock + self._registers.clear() def __getitem__(self, key: str): if key not in self._registers: diff --git a/py/rwmem/registerfile.py b/py/rwmem/registerfile.py index 4a408d1..bd80e44 100644 --- a/py/rwmem/registerfile.py +++ b/py/rwmem/registerfile.py @@ -209,13 +209,17 @@ def __init__(self, source: str | bytes | BinaryIO) -> None: self._map = mmap.mmap(self.fd, 0, mmap.MAP_SHARED, access=mmap.ACCESS_COPY) finally: os.close(self.fd) + + self._mmap = self._map elif isinstance(source, bytes): # XXX ctypes requires a writeable buffer... self._map = bytearray(source) + self._mmap = None else: self.fd = source.fileno() # ctypes requires a writeable mmap, so we use ACCESS_COPY self._map = mmap.mmap(self.fd, 0, mmap.MAP_SHARED, access=mmap.ACCESS_COPY) + self._mmap = self._map self.rfd = RegisterFileData.from_buffer(self._map) @@ -243,15 +247,17 @@ def __init__(self, source: str | bytes | BinaryIO) -> None: self._regblock_infos: dict[str, RegisterBlock | None] = dict.fromkeys(rb_names) def close(self): - if self.fd: - self._map.close() + if self._mmap: + self._mmap.close() def __enter__(self): return self def __exit__(self, exc_type, exc_value, exc_tb): - if self.fd: - self._map.close() + del self.rfd + self._regblock_infos.clear() + if self._mmap: + self._mmap.close() @property def num_blocks(self) -> int: diff --git a/py/tests/test_mmap_regs.py b/py/tests/test_mmap_regs.py index e170d24..bf2d0ac 100755 --- a/py/tests/test_mmap_regs.py +++ b/py/tests/test_mmap_regs.py @@ -10,12 +10,12 @@ REGS_PATH = os.path.dirname(os.path.abspath(__file__)) + '/test.regs' BIN_PATH = os.path.dirname(os.path.abspath(__file__)) + '/test.bin' -#class ContextManagerTests(unittest.TestCase): -# def test(self): -# with rw.RegisterFile(REGS_PATH) as rf: -# with rw.MappedRegisterBlock(BIN_PATH, rf['BLOCK1'], -# mode=rw.MapMode.Read) as map: -# self.assertEqual(map['REG1'].value, 0xf00dbaad) +class ContextManagerTests(unittest.TestCase): + def test(self): + with rw.RegisterFile(REGS_PATH) as rf: + with rw.MappedRegisterBlock(BIN_PATH, rf['BLOCK1'], + mode=rw.MapMode.Read) as map: + self.assertEqual(map['REG1'].value, 0xf00dbaad) class MmapRegsTests(unittest.TestCase): def setUp(self): diff --git a/py/tests/test_registerfile.py b/py/tests/test_registerfile.py index 76a49ce..4f8ec62 100644 --- a/py/tests/test_registerfile.py +++ b/py/tests/test_registerfile.py @@ -6,10 +6,10 @@ REGS_PATH = os.path.dirname(os.path.abspath(__file__)) + '/test.regs' -#class ContextManagerTests(unittest.TestCase): -# def test(self): -# with rw.RegisterFile(REGS_PATH) as rf: -# pass +class ContextManagerTests(unittest.TestCase): + def test(self): + with rw.RegisterFile(REGS_PATH) as rf: + pass class MmapRegsTests(unittest.TestCase): def setUp(self):