diff --git a/allennlp/common/mmap.py b/allennlp/common/mmap.py new file mode 100644 index 00000000000..ef81c756c24 --- /dev/null +++ b/allennlp/common/mmap.py @@ -0,0 +1,312 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from functools import lru_cache +import os +import shutil +import struct +import numpy as np +import torch +from allennlp.data.fields import DataArray + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: np.float, + 7: np.double, + 8: np.uint16, +} + + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return f"{prefix_path}.idx" + + +def data_file_path(prefix_path): + return f"{prefix_path}.bin" + + +def _warmup_mmap_file(path): + with open(path, "rb") as stream: + while stream.read(100 * 1024 * 1024): + pass + + +class MMapCacheReader: + class Index(object): + _HDR_MAGIC = b"MMIDIDX\x00\x00" + + @classmethod + def writer(cls, path, dtype): + class _Writer(object): + def __enter__(self): + self._file = open(path, "wb") + + self._file.write(cls._HDR_MAGIC) + self._file.write(struct.pack(" None: + self.cache_path = cache_path + self._builder = None + self._cache = None + + if os.path.exists(self.cache_path)): + if self.is_finalized(self.cache_path): + #scenario 2, we can read. + self._cache = MMapCacheReader(self.cache_path) + else: + #scenario 3, another training process is currently writing to it or was interrupted while it was writing. + pass + else: + self._builder = MMapCacheBuilder(self.cache_path) + #scenario 1, we need to write to it. 
# Methods of the cache class whose header and ``__init__`` precede this span
# (they take ``self`` / ``cls``).
# Annotations are string forward references: the typing and allennlp names they
# mention are not imported in the visible part of this module.
def get_instances(
    self,
    data_path: str,
) -> "Optional[Iterable[Dict[str, DataArray]]]":
    """Yield every cached instance, in index order.

    ``data_path`` is accepted but not used here.

    This is a generator function, so the state check below runs when the
    result is first iterated, not when the method is called.

    Raises:
        AssertionError: if no cache reader is attached (``self._cache`` is
            None).
    """
    # Compare against None instead of relying on truthiness: an object that
    # defines ``__len__`` is falsy when empty, and an empty cache is still a
    # readable cache.
    assert self._cache is not None, "cache is not open for reading"
    # Index explicitly: only ``len()`` and integer indexing are relied on.
    for i in range(len(self._cache)):
        yield self._cache[i]


def set_instances(
    self,
    instances: "Iterable[Dict[str, DataArray]]",
) -> "Iterable[Dict[str, DataArray]]":
    """Feed every instance to the cache builder, then return ``instances``.

    NOTE: the very same iterable object is returned after it has been
    consumed once, so a one-shot iterator is exhausted by the time the
    caller receives it.

    Raises:
        AssertionError: if no builder is attached (``self._builder`` is
            None).
    """
    assert self._builder is not None, "cache is not open for writing"
    for instance in instances:
        self._builder.add_instance(instance)
    return instances


def get_vocabulary(self) -> "Optional[Vocabulary]":
    """Placeholder: not implemented yet, always returns None."""
    return None


def set_vocabulary(self, vocab: "Vocabulary") -> None:
    """Placeholder: not implemented yet, ``vocab`` is ignored."""


def finalize(self) -> None:
    """Placeholder: not implemented yet, does nothing."""


@classmethod
def hash_config(cls, config: "Params") -> str:
    """Placeholder: not implemented yet.

    TODO: currently returns None despite the ``str`` annotation.
    """
    return None


@classmethod
def is_finalized(cls, path):
    """Placeholder: reports every ``path`` as finalized without inspecting it."""
    return True


# Similar to the DatasetReader class, the Cache class will also have
# getters and setters for WorkerInfo and DistributedInfo.