Update H5DataBlockLoader to support multiprocessing
xpai committed Oct 29, 2023
1 parent 0d255f0 commit bec8895
Showing 4 changed files with 41 additions and 31 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -13,7 +13,7 @@
 <hr/>
 
 <div align="center">
-<a href="https://github.com/xue-pai/FuxiCTR/stargazers"><img src="https://reporoster.com/stars/xue-pai/FuxiCTR" /><a/>
+<a href="https://github.com/xue-pai/FuxiCTR/stargazers"><img src="http://bytecrank.com/nastyox/reporoster/php/stargazersSVG.php?user=xue-pai&repo=FuxiCTR" width="600"/><a/>
 </div>
 
 Click-through rate (CTR) prediction is a critical task for many industrial applications such as online advertising, recommender systems, and sponsored search. FuxiCTR provides an open-source library for CTR prediction, with key features in configurability, tunability, and reproducibility. We hope this project could benefit both researchers and practitioners with the goal of open benchmarking for CTR prediction tasks.
66 changes: 38 additions & 28 deletions fuxictr/pytorch/dataloaders/h5_block_dataloader.py
@@ -20,20 +20,18 @@
 import h5py
 from itertools import chain
 import torch
+from torch.utils import data
 import logging
 import glob
 
 
-class DataLoader(object):
-    def __init__(self, feature_map, data_block_list, batch_size=32, shuffle=False, verbose=0, **kwargs):
-        # data_block_list: path list of data blocks
+class DataBlockDataset(data.IterableDataset):
+    def __init__(self, feature_map, block_file_list, shuffle=False, verbose=0):
+        # block_file_list: path list of data blocks
         self.feature_map = feature_map
-        self.data_blocks = data_block_list
+        self.data_blocks = block_file_list
         self.shuffle = shuffle
-        self.batch_size = batch_size
         self.verbose = verbose
-        self.num_blocks = len(self.data_blocks)
-        self.num_batches, self.num_samples = self.count_batches_and_samples()
 
     def load_data_array(self, data_path):
         data_dict = load_h5(data_path, verbose=self.verbose)
@@ -54,14 +52,37 @@ def iter_block(self, data_block):
         indexes = list(range(block_size))
         if self.shuffle:
             np.random.shuffle(indexes)
-        for idx in range(0, block_size, self.batch_size):
-            batch_index = indexes[idx:(idx + self.batch_size)]
-            yield darray[batch_index, :]
+        for idx in indexes:
+            yield darray[idx, :]
 
     def __iter__(self):
-        if self.shuffle:
-            np.random.shuffle(self.data_blocks)
-        return chain.from_iterable(map(self.iter_block, self.data_blocks))
+        worker_info = data.get_worker_info()
+        if worker_info is None:  # single-process data loading
+            sub_list = self.data_blocks
+        else:  # in a worker process
+            worker_id = worker_info.id
+            chunk_list = np.array_split(self.data_blocks, worker_info.num_workers)
+            sub_list = chunk_list[worker_id].tolist()
+        if self.shuffle:
+            np.random.shuffle(sub_list)
+        return chain.from_iterable(map(self.iter_block, sub_list))
+
+
+class DataLoader(data.DataLoader):
+    def __init__(self, feature_map, data_path, batch_size=32, shuffle=False,
+                 num_workers=1, verbose=0, **kwargs):
+        data_blocks = glob.glob(data_path + "/*.h5")
+        assert len(data_blocks) > 0, f"invalid data_path: {data_path}"
+        if len(data_blocks) > 1:
+            data_blocks.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))  # e.g. "part_1.h5"
+        self.data_blocks = data_blocks
+        self.num_blocks = len(self.data_blocks)
+        self.feature_map = feature_map
+        self.batch_size = batch_size
+        self.num_batches, self.num_samples = self.count_batches_and_samples()
+        self.dataset = DataBlockDataset(feature_map, data_blocks, shuffle=shuffle, verbose=verbose)
+        super(DataLoader, self).__init__(dataset=self.dataset, batch_size=batch_size,
+                                         num_workers=num_workers)
 
     def __len__(self):
         return self.num_batches
@@ -86,26 +107,15 @@ def __init__(self, feature_map, stage="both", train_data=None, valid_data=None,
         test_gen = None
         self.stage = stage
         if stage in ["both", "train"]:
-            train_blocks = glob.glob(train_data + "/*.h5")
-            assert len(train_blocks) > 0, "invalid data files or paths."
-            if len(train_blocks) > 1:
-                train_blocks.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))  # "xx_part_1.h5"
-            train_gen = DataLoader(feature_map, train_blocks, batch_size=batch_size, shuffle=shuffle, verbose=verbose, **kwargs)
+            train_gen = DataLoader(feature_map, train_data, batch_size=batch_size, shuffle=shuffle, verbose=verbose, **kwargs)
             logging.info("Train samples: total/{:d}, blocks/{:d}".format(train_gen.num_samples, train_gen.num_blocks))
             if valid_data:
-                valid_blocks = glob.glob(valid_data + "/*.h5")
-                if len(valid_blocks) > 1:
-                    valid_blocks.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
-                valid_gen = DataLoader(feature_map, valid_blocks, batch_size=batch_size, shuffle=False, verbose=verbose, **kwargs)
+                valid_gen = DataLoader(feature_map, valid_data, batch_size=batch_size, shuffle=False, verbose=verbose, **kwargs)
                 logging.info("Validation samples: total/{:d}, blocks/{:d}".format(valid_gen.num_samples, valid_gen.num_blocks))
 
         if stage in ["both", "test"]:
             if test_data:
-                test_blocks = glob.glob(test_data + "/*.h5")
-                assert len(test_blocks) > 0, "invalid data files or paths."
-                if len(test_blocks) > 1:
-                    test_blocks.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
-                test_gen = DataLoader(feature_map, test_blocks, batch_size=batch_size, shuffle=False, verbose=verbose, **kwargs)
+                test_gen = DataLoader(feature_map, test_data, batch_size=batch_size, shuffle=False, verbose=verbose, **kwargs)
                 logging.info("Test samples: total/{:d}, blocks/{:d}".format(test_gen.num_samples, test_gen.num_blocks))
         self.train_gen, self.valid_gen, self.test_gen = train_gen, valid_gen, test_gen

@@ -118,4 +128,4 @@ def make_iterator(self):
             return self.test_gen
         else:
            logging.info("Loading data done.")
-            return self.train_gen, self.valid_gen, self.test_gen
\ No newline at end of file
+            return self.train_gen, self.valid_gen, self.test_gen
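
The core of this change is the standard PyTorch recipe for sharding an IterableDataset across DataLoader workers: each worker process re-executes __iter__, so each must claim a disjoint slice of the block list via data.get_worker_info(), otherwise an epoch would repeat every sample once per worker. Below is a minimal, self-contained sketch of that pattern; the in-memory integer blocks stand in for the .h5 block files, and ToyBlockDataset is an illustrative name, not FuxiCTR API.

import numpy as np
from torch.utils import data


class ToyBlockDataset(data.IterableDataset):
    """Yield samples block by block, sharding whole blocks across workers."""

    def __init__(self, blocks):
        self.blocks = blocks  # list of per-block sample lists

    def __iter__(self):
        worker_info = data.get_worker_info()
        if worker_info is None:  # single-process loading: iterate every block
            sub_list = self.blocks
        else:  # worker process: take a disjoint chunk of block indices
            chunks = np.array_split(range(len(self.blocks)), worker_info.num_workers)
            sub_list = [self.blocks[i] for i in chunks[worker_info.id]]
        for block in sub_list:
            yield from block


if __name__ == "__main__":
    blocks = [[i * 3 + j for j in range(3)] for i in range(4)]  # four blocks of three samples
    loader = data.DataLoader(ToyBlockDataset(blocks), batch_size=4, num_workers=2)
    for batch in loader:
        print(batch)  # every sample appears exactly once across both workers

Dropping the get_worker_info() branch would make both workers iterate all four blocks, returning every sample twice per epoch. Splitting at block granularity rather than sample granularity also keeps each .h5 file read by exactly one worker.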
2 changes: 1 addition & 1 deletion fuxictr/version.py
@@ -1 +1 @@
-__version__="2.1.1"
+__version__="2.1.2"
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@
 
 setuptools.setup(
     name="fuxictr",
-    version="2.1.1",
+    version="2.1.2",
     author="fuxictr",
     author_email="fuxictr@users.noreply.github.com",
     description="A configurable, tunable, and reproducible library for CTR prediction",
