From 30c9b7ff57c45b32a724905b16df227f924e5d70 Mon Sep 17 00:00:00 2001
From: xpai
Date: Thu, 26 Oct 2023 20:22:12 +0800
Subject: [PATCH] Update H5DataBlockLoader to support multiprocessing

---
 .../dataloaders/h5_block_dataloader.py | 66 +++++++++++--------
 fuxictr/version.py                     |  2 +-
 setup.py                               |  2 +-
 3 files changed, 40 insertions(+), 30 deletions(-)

diff --git a/fuxictr/pytorch/dataloaders/h5_block_dataloader.py b/fuxictr/pytorch/dataloaders/h5_block_dataloader.py
index 03ce37b..56e847a 100644
--- a/fuxictr/pytorch/dataloaders/h5_block_dataloader.py
+++ b/fuxictr/pytorch/dataloaders/h5_block_dataloader.py
@@ -20,20 +20,18 @@
 import h5py
 from itertools import chain
 import torch
+from torch.utils import data
 import logging
 import glob
 
 
-class DataLoader(object):
-    def __init__(self, feature_map, data_block_list, batch_size=32, shuffle=False, verbose=0, **kwargs):
-        # data_block_list: path list of data blocks
+class DataBlockDataset(data.IterableDataset):
+    def __init__(self, feature_map, block_file_list, shuffle=False, verbose=0):
+        # block_file_list: path list of data blocks
         self.feature_map = feature_map
-        self.data_blocks = data_block_list
+        self.data_blocks = block_file_list
         self.shuffle = shuffle
-        self.batch_size = batch_size
         self.verbose = verbose
-        self.num_blocks = len(self.data_blocks)
-        self.num_batches, self.num_samples = self.count_batches_and_samples()
 
     def load_data_array(self, data_path):
         data_dict = load_h5(data_path, verbose=self.verbose)
@@ -54,14 +52,37 @@ def iter_block(self, data_block):
         indexes = list(range(block_size))
         if self.shuffle:
             np.random.shuffle(indexes)
-        for idx in range(0, block_size, self.batch_size):
-            batch_index = indexes[idx:(idx + self.batch_size)]
-            yield darray[batch_index, :]
+        for idx in indexes:
+            yield darray[idx, :]
 
     def __iter__(self):
-        if self.shuffle:
-            np.random.shuffle(self.data_blocks)
-        return chain.from_iterable(map(self.iter_block, self.data_blocks))
+        worker_info = data.get_worker_info()
+        if worker_info is None:  # single-process data loading
+            sub_list = self.data_blocks
+        else:  # in a worker process
+            worker_id = worker_info.id
+            chunk_list = np.array_split(self.data_blocks, worker_info.num_workers)
+            sub_list = chunk_list[worker_id].tolist()
+        if self.shuffle:
+            np.random.shuffle(sub_list)
+        return chain.from_iterable(map(self.iter_block, sub_list))
+
+
+class DataLoader(data.DataLoader):
+    def __init__(self, feature_map, data_path, batch_size=32, shuffle=False,
+                 num_workers=1, verbose=0, **kwargs):
+        data_blocks = glob.glob(data_path + "/*.h5")
+        assert len(data_blocks) > 0, f"invalid data_path: {data_path}"
+        if len(data_blocks) > 1:
+            data_blocks.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))  # e.g. "part_1.h5"
+        self.data_blocks = data_blocks
+        self.num_blocks = len(self.data_blocks)
+        self.feature_map = feature_map
+        self.batch_size = batch_size
+        self.num_batches, self.num_samples = self.count_batches_and_samples()
+        self.dataset = DataBlockDataset(feature_map, data_blocks, shuffle=shuffle, verbose=verbose)
+        super(DataLoader, self).__init__(dataset=self.dataset, batch_size=batch_size,
+                                         num_workers=num_workers)
 
     def __len__(self):
         return self.num_batches
@@ -86,26 +107,15 @@ def __init__(self, feature_map, stage="both", train_data=None, valid_data=None,
         test_gen = None
         self.stage = stage
         if stage in ["both", "train"]:
-            train_blocks = glob.glob(train_data + "/*.h5")
-            assert len(train_blocks) > 0, "invalid data files or paths."
-            if len(train_blocks) > 1:
-                train_blocks.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))  # "xx_part_1.h5"
-            train_gen = DataLoader(feature_map, train_blocks, batch_size=batch_size, shuffle=shuffle, verbose=verbose, **kwargs)
+            train_gen = DataLoader(feature_map, train_data, batch_size=batch_size, shuffle=shuffle, verbose=verbose, **kwargs)
             logging.info("Train samples: total/{:d}, blocks/{:d}".format(train_gen.num_samples, train_gen.num_blocks))
             if valid_data:
-                valid_blocks = glob.glob(valid_data + "/*.h5")
-                if len(valid_blocks) > 1:
-                    valid_blocks.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
-                valid_gen = DataLoader(feature_map, valid_blocks, batch_size=batch_size, shuffle=False, verbose=verbose, **kwargs)
+                valid_gen = DataLoader(feature_map, valid_data, batch_size=batch_size, shuffle=False, verbose=verbose, **kwargs)
                 logging.info("Validation samples: total/{:d}, blocks/{:d}".format(valid_gen.num_samples, valid_gen.num_blocks))
         if stage in ["both", "test"]:
             if test_data:
-                test_blocks = glob.glob(test_data + "/*.h5")
-                assert len(test_blocks) > 0, "invalid data files or paths."
-                if len(test_blocks) > 1:
-                    test_blocks.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
-                test_gen = DataLoader(feature_map, test_blocks, batch_size=batch_size, shuffle=False, verbose=verbose, **kwargs)
+                test_gen = DataLoader(feature_map, test_data, batch_size=batch_size, shuffle=False, verbose=verbose, **kwargs)
                 logging.info("Test samples: total/{:d}, blocks/{:d}".format(test_gen.num_samples, test_gen.num_blocks))
         self.train_gen, self.valid_gen, self.test_gen = train_gen, valid_gen, test_gen
@@ -118,4 +128,4 @@ def make_iterator(self):
             return self.test_gen
         else:
             logging.info("Loading data done.")
-            return self.train_gen, self.valid_gen, self.test_gen
\ No newline at end of file
+            return self.train_gen, self.valid_gen, self.test_gen
diff --git a/fuxictr/version.py b/fuxictr/version.py
index d73e40a..00f54dd 100644
--- a/fuxictr/version.py
+++ b/fuxictr/version.py
@@ -1 +1 @@
-__version__="2.1.1"
+__version__="2.1.2"
diff --git a/setup.py b/setup.py
index 594f32d..b7ee51d 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 
 setuptools.setup(
     name="fuxictr",
-    version="2.1.1",
+    version="2.1.2",
     author="fuxictr",
     author_email="fuxictr@users.noreply.github.com",
     description="A configurable, tunable, and reproducible library for CTR prediction",
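
Note on how the blocks are partitioned: DataBlockDataset.__iter__ re-derives the worker
context from torch.utils.data.get_worker_info() and keeps one np.array_split chunk of the
block-file list per worker, so no two loader processes read the same h5 block. A minimal
sketch of that partitioning follows; the file names and worker count are made-up examples,
not values from the repository.

import numpy as np

# Hypothetical block files; in the loader these come from glob.glob(data_path + "/*.h5").
data_blocks = [f"part_{i}.h5" for i in range(10)]
num_workers = 4

for worker_id in range(num_workers):
    # Same call as in DataBlockDataset.__iter__: each worker keeps only its own chunk.
    sub_list = np.array_split(data_blocks, num_workers)[worker_id].tolist()
    print(worker_id, sub_list)

# 0 ['part_0.h5', 'part_1.h5', 'part_2.h5']
# 1 ['part_3.h5', 'part_4.h5', 'part_5.h5']
# 2 ['part_6.h5', 'part_7.h5']
# 3 ['part_8.h5', 'part_9.h5']

If num_workers exceeds the number of blocks, np.array_split hands the extra workers empty
sub-lists, so they simply yield nothing.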
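Usage sketch for the updated loader, assuming feature_map is an already-built FeatureMap
object and "../data/train_h5" is a directory holding part_*.h5 blocks; both names are
illustrative assumptions, not taken from this patch.

from fuxictr.pytorch.dataloaders.h5_block_dataloader import DataLoader

train_loader = DataLoader(feature_map,                    # assumed: a prebuilt FeatureMap
                          data_path="../data/train_h5",   # assumed: directory of *.h5 blocks
                          batch_size=1024,
                          shuffle=True,
                          num_workers=4)                   # blocks split across 4 worker processes
logging.info("batches: {:d}, samples: {:d}".format(len(train_loader), train_loader.num_samples))

for batch in train_loader:
    pass  # each batch is the default collation of rows yielded by DataBlockDataset

Because torch batches each worker's stream separately for an IterableDataset, a batch can
span adjacent blocks handled by the same worker but never mixes rows from different workers.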