Skip to content

Commit

Permalink
Add types to disk storage
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Dec 1, 2022
1 parent 7f170be commit d23f09d
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 16 deletions.
42 changes: 28 additions & 14 deletions openfisca_core/data_storage/on_disk_storage.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,69 @@
from __future__ import annotations

from typing import Any, Dict, KeysView, Optional, Type

import os
import shutil

import numpy

from openfisca_core import periods
from openfisca_core.indexed_enums import EnumArray
from openfisca_core import periods, indexed_enums as enums, types


class OnDiskStorage:
"""
Low-level class responsible for storing and retrieving calculated vectors on disk
"""

# NOTE(review): this is a rendered diff without +/- markers; the untyped
# signature on the next line is presumably the removed one, superseded by the
# annotated signature further down.
def __init__(self, storage_dir, is_eternal = False, preserve_storage_dir = False):
# Class-level annotations for the two internal maps:
#   _files: period -> path of the .npy file saved for that period
#   _enums: path   -> the `possible_values` recorded when an EnumArray was stored
_files: Dict[types.Period, str]
_enums: Dict[str, Type[enums.Enum]]

def __init__(
self,
storage_dir: str,
is_eternal: bool = False,
preserve_storage_dir: bool = False,
) -> None:
# Both maps start empty; the three arguments are stored on the instance as-is.
self._files = {}
self._enums = {}
self.is_eternal = is_eternal
self.preserve_storage_dir = preserve_storage_dir
self.storage_dir = storage_dir

# NOTE(review): rendered diff without +/- markers; the untyped `def` and the
# bare `EnumArray(...)` return are presumably the removed lines.
def _decode_file(self, file):
def _decode_file(self, file: str) -> Any:
# Load the array saved at path `file`. If an enum was recorded for that path
# in self._enums, wrap the loaded array in an EnumArray with it; otherwise
# return the array exactly as numpy.load() gives it.
enum: Optional[Type[enums.Enum]]
enum = self._enums.get(file)

if enum is not None:
return EnumArray(numpy.load(file), enum)
return enums.EnumArray(numpy.load(file), enum)
else:
return numpy.load(file)

# NOTE(review): rendered diff without +/- markers; the untyped `def` is
# presumably the removed line.
def get(self, period):
def get(self, period: types.Period) -> Any:
# Return the value stored for `period`, or None when no file is registered
# for it.
if self.is_eternal:
# When is_eternal is set the requested period is ignored and ETERNITY is
# used as the key instead.
period = periods.period(periods.ETERNITY)
# Pass the key through periods.period(), the same call put() applies before
# storing.
period = periods.period(period)

values = self._files.get(period)
if values is None:
return None

# `values` is the file path registered for this period, not the array itself.
return self._decode_file(values)

# NOTE(review): rendered diff without +/- markers; the untyped `def` and the
# bare `EnumArray` isinstance check are presumably the removed lines.
def put(self, value, period):
def put(self, value: Any, period: types.Period) -> None:
# Save `value` to "<storage_dir>/<str(period)>.npy" and register that path
# under `period` in self._files.
if self.is_eternal:
# When is_eternal is set the given period is ignored and ETERNITY is used.
period = periods.period(periods.ETERNITY)
period = periods.period(period)

filename = str(period)
path = os.path.join(self.storage_dir, filename) + '.npy'
if isinstance(value, enums.EnumArray):
if isinstance(value, enums.EnumArray):
# Remember the enum under the file path so _decode_file() can rebuild the
# EnumArray, then save only a plain ndarray view of the data.
self._enums[path] = value.possible_values
value = value.view(numpy.ndarray)
numpy.save(path, value)
self._files[period] = path

def delete(self, period = None):
def delete(self, period: Optional[types.Period] = None) -> None:
if period is None:
self._files = {}
return
Expand All @@ -65,21 +79,21 @@ def delete(self, period = None):
if not period.contains(period_item)
}

# NOTE(review): rendered diff without +/- markers; the untyped `def` is
# presumably the removed line.
def get_known_periods(self):
def get_known_periods(self) -> KeysView[types.Period]:
# Return the keys view of the period -> path map. This is the dict's own
# view object, not a copy.
return self._files.keys()

# NOTE(review): rendered diff without +/- markers; the untyped `def`, the
# `self._files = files = {}` alias and `files[period] = path` are presumably
# the removed lines.
def restore(self):
self._files = files = {}
def restore(self) -> None:
# Rebuild the period -> path map from the files found in storage_dir.
# Only self._files is reset and refilled; self._enums is not touched here.
self._files = {}
# Restore self._files from content of storage_dir.
for filename in os.listdir(self.storage_dir):
# Anything that is not a .npy file is ignored.
if not filename.endswith('.npy'):
continue
path = os.path.join(self.storage_dir, filename)
# The file name without its extension is parsed back into a period,
# mirroring how put() builds the name from str(period).
filename_core = filename.rsplit('.', 1)[0]
period = periods.period(filename_core)
files[period] = path
self._files[period] = path

# NOTE(review): rendered diff without +/- markers; the untyped `def` is
# presumably the removed line.
def __del__(self):
def __del__(self) -> None:
# On finalisation, delete the whole storage_dir tree unless the instance was
# created with preserve_storage_dir set.
# NOTE(review): shutil.rmtree() is called without ignore_errors, so a missing
# directory would raise inside the finalizer -- confirm this is acceptable.
if self.preserve_storage_dir:
return
shutil.rmtree(self.storage_dir) # Remove the holder temporary files
Expand Down
20 changes: 18 additions & 2 deletions openfisca_core/holders/holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,19 @@ class Holder:
A holder keeps track of a variable's values after they have been calculated or set as an input.
"""

def __init__(self, variable, population):
_disk_storage: types.Storage
_do_not_store: bool
_memory_storage: types.Storage
_on_disk_storable: bool
population: types.Population
simulation: types.Simulation
variable: types.Variable

def __init__(
self,
variable: types.Variable,
population: types.Population,
) -> None:
self.population = population
self.variable = variable
self.simulation = population.simulation
Expand Down Expand Up @@ -59,7 +71,11 @@ def clone(self, population):

return new

def create_disk_storage(self, directory = None, preserve = False):
def create_disk_storage(
self,
directory: Optional[str] = None,
preserve: bool = False,
) -> types.Storage:
if directory is None:
directory = self.simulation.data_storage_dir
storage_dir = os.path.join(directory, self.variable.name)
Expand Down
4 changes: 4 additions & 0 deletions openfisca_core/types/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,7 @@ def start(self) -> Any:
@abc.abstractmethod
def unit(self) -> Any:
"""Abstract method."""

@abc.abstractmethod
def contains(self, other: object) -> bool:
"""Abstract method.

Implementations report whether ``other`` is contained in this object
(presumably another period lying within this one -- confirm against the
concrete implementations).
"""

0 comments on commit d23f09d

Please sign in to comment.