Update H5DataBlockLoader to support multiprocessing
xpai committed Oct 26, 2023
1 parent 0d255f0 commit 7e16dfc
Showing 3 changed files with 41 additions and 29 deletions.
66 changes: 39 additions & 27 deletions fuxictr/pytorch/dataloaders/h5_block_dataloader.py
@@ -20,20 +20,20 @@
  import h5py
  from itertools import chain
  import torch
+ from torch.utils import data
  import logging
  import glob


- class DataLoader(object):
-     def __init__(self, feature_map, data_block_list, batch_size=32, shuffle=False, verbose=0, **kwargs):
-         # data_block_list: path list of data blocks
+ class DataBlockDataset(data.IterableDataset):
+     def __init__(self, feature_map, block_file_list, batch_size=32, shuffle=False, verbose=0):
+         # block_file_list: path list of data blocks
          self.feature_map = feature_map
-         self.data_blocks = data_block_list
+         self.data_blocks = block_file_list
          self.shuffle = shuffle
-         self.batch_size = batch_size
+         # self.batch_size = batch_size
          self.verbose = verbose
          self.num_blocks = len(self.data_blocks)
-         self.num_batches, self.num_samples = self.count_batches_and_samples()

      def load_data_array(self, data_path):
          data_dict = load_h5(data_path, verbose=self.verbose)
@@ -54,14 +54,37 @@ def iter_block(self, data_block):
          indexes = list(range(block_size))
          if self.shuffle:
              np.random.shuffle(indexes)
-         for idx in range(0, block_size, self.batch_size):
-             batch_index = indexes[idx:(idx + self.batch_size)]
-             yield darray[batch_index, :]
+         for idx in indexes:
+             yield darray[idx, :]

      def __iter__(self):
-         if self.shuffle:
-             np.random.shuffle(self.data_blocks)
-         return chain.from_iterable(map(self.iter_block, self.data_blocks))
+         # if self.shuffle:
+         #     np.random.shuffle(self.data_blocks)
+         worker_info = data.get_worker_info()
+         if worker_info is None:  # single-process data loading
+             chunk_list = self.data_blocks
+         else:  # in a worker process
+             worker_id = worker_info.id
+             chunk_list = np.array_split(self.data_blocks, worker_info.num_workers)
+             sub_list = chunk_list[worker_id].tolist()
+         return chain.from_iterable(map(self.iter_block, sub_list))
+
+
+ class DataLoader(data.DataLoader):
+     def __init__(self, feature_map, data_path, batch_size=32, shuffle=False,
+                  num_workers=1, verbose=0, **kwargs):
+         data_blocks = glob.glob(data_path + "/*.h5")
+         assert len(data_blocks) > 0, f"invalid data_path: {data_path}"
+         if len(data_blocks) > 1:
+             data_blocks.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))  # e.g. "part_1.h5"
+         self.data_blocks = data_blocks
+         self.feature_map = feature_map
+         self.batch_size = batch_size
+         self.dataset = DataBlockDataset(feature_map, data_blocks, batch_size=batch_size,
+                                         shuffle=shuffle, verbose=verbose)
+         self.num_batches, self.num_samples = self.count_batches_and_samples()
+         super(DataLoader, self).__init__(dataset=self.dataset, batch_size=batch_size,
+                                          num_workers=num_workers)

      def __len__(self):
          return self.num_batches
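
The multiprocessing support rests on two changes visible above: the block dataset now yields single rows (batching is delegated to the wrapping torch.utils.data.DataLoader, which receives batch_size in the new DataLoader class), and __iter__ uses torch.utils.data.get_worker_info() to hand each worker a disjoint slice of the HDF5 block list, so no two workers read the same block. Below is a minimal, self-contained sketch of that sharding pattern; the class name ShardedIterableDataset and the toy integer "blocks" are illustrative stand-ins for the HDF5 block files, not part of this commit:

import numpy as np
from itertools import chain
from torch.utils import data

class ShardedIterableDataset(data.IterableDataset):
    # Illustrative only: split a list of "blocks" across DataLoader workers.
    def __init__(self, blocks):
        self.blocks = blocks  # in the real loader, paths to part_*.h5 files

    def iter_block(self, block):
        # The real loader opens one HDF5 block here and yields its rows.
        for sample in block:
            yield sample

    def __iter__(self):
        worker_info = data.get_worker_info()
        if worker_info is None:  # num_workers=0: iterate all blocks in-process
            sub_list = self.blocks
        else:  # each worker takes a disjoint chunk of block indices
            chunks = np.array_split(np.arange(len(self.blocks)), worker_info.num_workers)
            sub_list = [self.blocks[i] for i in chunks[worker_info.id]]
        return chain.from_iterable(map(self.iter_block, sub_list))

if __name__ == "__main__":
    loader = data.DataLoader(ShardedIterableDataset([[1, 2, 3], [4, 5], [6, 7, 8]]),
                             batch_size=2, num_workers=2)
    for batch in loader:
        print(batch)  # tensors of up to 2 samples, collated per worker
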
@@ -86,26 +109,15 @@ def __init__(self, feature_map, stage="both", train_data=None, valid_data=None,
          test_gen = None
          self.stage = stage
          if stage in ["both", "train"]:
-             train_blocks = glob.glob(train_data + "/*.h5")
-             assert len(train_blocks) > 0, "invalid data files or paths."
-             if len(train_blocks) > 1:
-                 train_blocks.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))  # "xx_part_1.h5"
-             train_gen = DataLoader(feature_map, train_blocks, batch_size=batch_size, shuffle=shuffle, verbose=verbose, **kwargs)
+             train_gen = DataLoader(feature_map, train_data, batch_size=batch_size, shuffle=shuffle, verbose=verbose, **kwargs)
              logging.info("Train samples: total/{:d}, blocks/{:d}".format(train_gen.num_samples, train_gen.num_blocks))
              if valid_data:
-                 valid_blocks = glob.glob(valid_data + "/*.h5")
-                 if len(valid_blocks) > 1:
-                     valid_blocks.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
-                 valid_gen = DataLoader(feature_map, valid_blocks, batch_size=batch_size, shuffle=False, verbose=verbose, **kwargs)
+                 valid_gen = DataLoader(feature_map, valid_data, batch_size=batch_size, shuffle=False, verbose=verbose, **kwargs)
                  logging.info("Validation samples: total/{:d}, blocks/{:d}".format(valid_gen.num_samples, valid_gen.num_blocks))

          if stage in ["both", "test"]:
              if test_data:
-                 test_blocks = glob.glob(test_data + "/*.h5")
-                 assert len(test_blocks) > 0, "invalid data files or paths."
-                 if len(test_blocks) > 1:
-                     test_blocks.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
-                 test_gen = DataLoader(feature_map, test_blocks, batch_size=batch_size, shuffle=False, verbose=verbose, **kwargs)
+                 test_gen = DataLoader(feature_map, test_data, batch_size=batch_size, shuffle=False, verbose=verbose, **kwargs)
                  logging.info("Test samples: total/{:d}, blocks/{:d}".format(test_gen.num_samples, test_gen.num_blocks))
          self.train_gen, self.valid_gen, self.test_gen = train_gen, valid_gen, test_gen
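
The block-ordering rule that used to be repeated at each call site above (and now lives once inside DataLoader) sorts files by their integer part index rather than lexicographically, so that part_10.h5 does not precede part_2.h5. A small standalone illustration, with made-up file names:

# Made-up file names; the sort key is the one used in the loader above.
blocks = ["train/part_10.h5", "train/part_2.h5", "train/part_1.h5"]
blocks.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
print(blocks)  # ['train/part_1.h5', 'train/part_2.h5', 'train/part_10.h5']
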

@@ -118,4 +130,4 @@ def make_iterator(self):
              return self.test_gen
          else:
              logging.info("Loading data done.")
-             return self.train_gen, self.valid_gen, self.test_gen
\ No newline at end of file
+             return self.train_gen, self.valid_gen, self.test_gen
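
With the reworked loaders, callers hand the data directory straight to DataLoader, which globs the *.h5 blocks itself and reads them with num_workers worker processes. A hedged usage sketch follows; the directory path and the feature_map object are placeholders (an already-built fuxictr FeatureMap and a directory of part_*.h5 blocks are assumed), not anything defined in this commit:

from fuxictr.pytorch.dataloaders.h5_block_dataloader import DataLoader

# Assumptions: feature_map is an already-built fuxictr FeatureMap, and
# "../data/mydataset/train" holds HDF5 blocks named part_0.h5, part_1.h5, ...
train_loader = DataLoader(feature_map, "../data/mydataset/train",
                          batch_size=4096, shuffle=True,
                          num_workers=4, verbose=1)
print(train_loader.num_samples, len(train_loader))  # __len__ returns num_batches
for batch in train_loader:
    pass  # each batch is collated by torch.utils.data.DataLoader across worker shards
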
2 changes: 1 addition & 1 deletion fuxictr/version.py
@@ -1 +1 @@
- __version__="2.1.1"
+ __version__="2.1.2"
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

  setuptools.setup(
      name="fuxictr",
-     version="2.1.1",
+     version="2.1.2",
      author="fuxictr",
      author_email="fuxictr@users.noreply.github.com",
      description="A configurable, tunable, and reproducible library for CTR prediction",
