From 5e109cc8873e044a0900eaf2e8b1d074cd3f7aac Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Date: Fri, 18 Oct 2024 00:38:16 -0700 Subject: [PATCH] Use a context-manager when opening files (#10895) * Use a context-manager when opening files Signed-off-by: Alexandros Koumparoulis * Apply isort and black reformatting Signed-off-by: akoumpa * Apply isort and black reformatting Signed-off-by: artbataev --------- Signed-off-by: Alexandros Koumparoulis Signed-off-by: akoumpa Signed-off-by: artbataev Co-authored-by: akoumpa Co-authored-by: artbataev --- .../language_modeling/text_memmap_dataset.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py b/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py index 4882708f698f..dc4fb8ececc5 100644 --- a/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py @@ -260,7 +260,8 @@ def load_file(self, fn, index_mapping_dir: Optional[str] = None): raise RuntimeError(f"Missing header, expected {self._header_lines} header lines") # load meta info - idx_info_dict = pickle.load(open(idx_fn + ".info", "rb")) + with open(idx_fn + ".info", "rb") as fp: + idx_info_dict = pickle.load(fp) # test for mismatch in expected newline_int if "newline_int" in idx_info_dict: newline_int = idx_info_dict["newline_int"] @@ -378,9 +379,7 @@ def __init__( self._data_sep = data_sep def _build_data_from_text(self, text: str): - """ - - """ + """ """ _build_data_from_text = super()._build_data_from_text data = {} text_fields = text.split(self._data_sep) @@ -513,7 +512,11 @@ def _build_memmap_index_files(newline_int, build_index_fn, fn, index_mapping_dir def build_index_files( - dataset_paths, newline_int, workers=None, build_index_fn=_build_index_from_memdata, index_mapping_dir: str = None, + dataset_paths, + newline_int, + workers=None, + build_index_fn=_build_index_from_memdata, + index_mapping_dir: str = None, ): """Auxiliary method to build multiple index files""" if len(dataset_paths) < 1: @@ -528,7 +531,12 @@ def build_index_files( ctx = mp.get_context("fork") with ctx.Pool(workers) as p: build_status = p.map( - partial(_build_memmap_index_files, newline_int, build_index_fn, index_mapping_dir=index_mapping_dir,), + partial( + _build_memmap_index_files, + newline_int, + build_index_fn, + index_mapping_dir=index_mapping_dir, + ), dataset_paths, )