diff --git a/fuxictr/preprocess/build_dataset.py b/fuxictr/preprocess/build_dataset.py
index 7f528fa..de3ca67 100644
--- a/fuxictr/preprocess/build_dataset.py
+++ b/fuxictr/preprocess/build_dataset.py
@@ -21,6 +21,7 @@
 import numpy as np
 import gc
 import multiprocessing as mp
+import polars as pl
 
 
 def split_train_test(train_ddf=None, valid_ddf=None, test_ddf=None, valid_size=0,
@@ -53,28 +54,28 @@ def save_npz(darray_dict, data_path):
     np.savez(data_path, **darray_dict)
 
 
-def transform_block(feature_encoder, df_block, filename, preprocess=False):
-    if preprocess:
-        df_block = feature_encoder.preprocess(df_block)
+def transform_block(feature_encoder, df_block, filename):
     darray_dict = feature_encoder.transform(df_block)
     save_npz(darray_dict, os.path.join(feature_encoder.data_dir, filename))
 
 
-def transform(feature_encoder, ddf, filename, preprocess=False, block_size=0):
+def transform(feature_encoder, ddf, filename, block_size=0):
     if block_size > 0:
         pool = mp.Pool(mp.cpu_count() // 2)
         block_id = 0
         for idx in range(0, len(ddf), block_size):
-            df_block = ddf[idx: (idx + block_size)]
-            pool.apply_async(transform_block, args=(feature_encoder,
-                                                    df_block,
-                                                    '{}/part_{:05d}.npz'.format(filename, block_id),
-                                                    preprocess))
+            df_block = ddf.iloc[idx:(idx + block_size)]
+            pool.apply_async(
+                transform_block,
+                args=(feature_encoder,
+                      df_block,
+                      '{}/part_{:05d}.npz'.format(filename, block_id))
+            )
             block_id += 1
         pool.close()
         pool.join()
     else:
-        transform_block(feature_encoder, ddf, filename, preprocess)
+        transform_block(feature_encoder, ddf, filename)
 
 
 def build_dataset(feature_encoder, train_data=None, valid_data=None, test_data=None, valid_size=0,
@@ -96,12 +97,12 @@
         valid_ddf = feature_encoder.read_csv(valid_data, **kwargs)
         test_ddf = feature_encoder.read_csv(test_data, **kwargs)
         train_ddf, valid_ddf, test_ddf = split_train_test(train_ddf, valid_ddf, test_ddf,
-                                                              valid_size, test_size, split_type)
+                                                          valid_size, test_size, split_type)
 
     # fit and transform train_ddf
     train_ddf = feature_encoder.preprocess(train_ddf)
     feature_encoder.fit(train_ddf, **kwargs)
-    transform(feature_encoder, train_ddf, 'train', preprocess=False, block_size=data_block_size)
+    transform(feature_encoder, train_ddf, 'train', block_size=data_block_size)
     del train_ddf
     gc.collect()
 
@@ -109,7 +110,8 @@
     if valid_ddf is None and (valid_data is not None):
         valid_ddf = feature_encoder.read_csv(valid_data, **kwargs)
     if valid_ddf is not None:
-        transform(feature_encoder, valid_ddf, 'valid', preprocess=True, block_size=data_block_size)
+        valid_ddf = feature_encoder.preprocess(valid_ddf)
+        transform(feature_encoder, valid_ddf, 'valid', block_size=data_block_size)
         del valid_ddf
         gc.collect()
 
@@ -117,7 +119,8 @@
     if test_ddf is None and (test_data is not None):
         test_ddf = feature_encoder.read_csv(test_data, **kwargs)
     if test_ddf is not None:
-        transform(feature_encoder, test_ddf, 'test', preprocess=True, block_size=data_block_size)
+        test_ddf = feature_encoder.preprocess(test_ddf)
+        transform(feature_encoder, test_ddf, 'test', block_size=data_block_size)
         del test_ddf
         gc.collect()
     logging.info("Transform csv data to npz done.")
diff --git a/fuxictr/preprocess/feature_processor.py b/fuxictr/preprocess/feature_processor.py
index 7f6d147..8072c19 100644
--- a/fuxictr/preprocess/feature_processor.py
+++ b/fuxictr/preprocess/feature_processor.py
@@ -18,12 +18,14 @@
 import numpy as np
 from collections import Counter, OrderedDict
 import pandas as pd
+import polars as pl
 import pickle
 import os
 import logging
 import json
 import re
 import shutil
+import glob
 from pathlib import Path
 import sklearn.preprocessing as sklearn_preprocess
 from fuxictr.features import FeatureMap
@@ -65,11 +67,14 @@ def _complete_feature_cols(self, feature_cols):
                 full_feature_cols.append(col)
         return full_feature_cols
 
-    def read_csv(self, data_path, sep=",", nrows=None, **kwargs):
+    def read_csv(self, data_path, sep=",", n_rows=None, **kwargs):
         logging.info("Reading file: " + data_path)
-        usecols_fn = lambda x: x in self.dtype_dict
-        ddf = pd.read_csv(data_path, sep=sep, usecols=usecols_fn,
-                          dtype=object, memory_map=True, nrows=nrows)
+        file_names = sorted(glob.glob(data_path))
+        assert len(file_names) > 0, f"Invalid data path: {data_path}"
+        # Require python >= 3.8 for use polars to scan multiple csv files
+        file_names = file_names[0]
+        ddf = pl.scan_csv(source=file_names, separator=sep, dtypes=self.dtype_dict,
+                          low_memory=False, n_rows=n_rows)
         return ddf
 
     def preprocess(self, ddf):
@@ -78,50 +83,38 @@
         for col in all_cols:
             name = col["name"]
             if name in ddf.columns:
-                if ddf[name].isnull().values.any():
-                    ddf[name] = self._fill_na_(col, ddf[name])
-                ddf[name] = ddf[name].astype(self.dtype_dict[name])
+                fill_na = "" if col["dtype"] in ["str", str] else 0
+                fill_na = col.get("fill_na", fill_na)
+                ddf = ddf.with_columns(pl.col(name).fill_null(fill_na))
                 if col.get("preprocess"):
-                    preprocess_splits = re.split(r"\(|\)", col["preprocess"])
-                    preprocess_fn = getattr(self, preprocess_splits[0])
-                    if len(preprocess_splits) > 1:
-                        ddf[name] = preprocess_fn(ddf, preprocess_splits[1])
-                    else:
-                        ddf[name] = preprocess_fn(ddf, name)
-                    ddf[name] = ddf[name].astype(self.dtype_dict[name])
+                    preprocess_args = re.split(r"\(|\)", col["preprocess"])
+                    preprocess_fn = getattr(self, preprocess_args[0])
+                    ddf = preprocess_fn(ddf, name, *preprocess_args[1:-1])
+                ddf = ddf.with_columns(pl.col(name).cast(self.dtype_dict[name]))
         active_cols = [col["name"] for col in all_cols if col.get("active") != False]
-        ddf = ddf.loc[:, active_cols]
+        ddf = ddf.select(active_cols)
         return ddf
 
-    def _fill_na_(self, col, series):
-        na_value = col.get("fill_na")
-        if na_value is not None:
-            return series.fillna(na_value)
-        elif col["dtype"] in ["str", str]:
-            return series.fillna("")
-        else:
-            raise RuntimeError("Feature column={} requires to assign fill_na value!".format(col["name"]))
-
     def fit(self, train_ddf, min_categr_count=1, num_buckets=10, **kwargs):
         logging.info("Fit feature processor...")
         for col in self.feature_cols:
             name = col["name"]
             if col["active"]:
                 logging.info("Processing column: {}".format(col))
+                col_series = train_ddf.select(name).collect().to_series().to_pandas()
                 if col["type"] == "meta": # e.g. group_id
-                    self.fit_meta_col(col, train_ddf[name].values)
+                    self.fit_meta_col(col)
                 elif col["type"] == "numeric":
-                    self.fit_numeric_col(col, train_ddf[name].values)
+                    self.fit_numeric_col(col, col_series)
                 elif col["type"] == "categorical":
-
-                    self.fit_categorical_col(col, train_ddf[name].values,
+                    self.fit_categorical_col(col, col_series,
                                              min_categr_count=min_categr_count,
                                              num_buckets=num_buckets)
                 elif col["type"] == "sequence":
-                    self.fit_sequence_col(col, train_ddf[name].values,
+                    self.fit_sequence_col(col, col_series,
                                           min_categr_count=min_categr_count)
                 else:
-                    raise NotImplementedError("feature_col={}".format(feature_col))
+                    raise NotImplementedError("feature type={}".format(col["type"]))
 
         # Expand vocab from pretrained_emb
         os.makedirs(self.data_dir, exist_ok=True)
@@ -166,17 +159,16 @@
         self.feature_map.save(self.json_file)
         logging.info("Set feature processor done.")
 
-    def fit_meta_col(self, col, col_values):
+    def fit_meta_col(self, col):
         name = col["name"]
         feature_type = col["type"]
         self.feature_map.features[name] = {"type": feature_type}
-        # assert col.get("remap") == False, "Meta feature currently only supports `remap=False`, \
-        # since it needs to map train and validation sets together."
         if col.get("remap", True):
+            # No need to fit, update vocab in encode_meta()
             tokenizer = Tokenizer(min_freq=1, remap=True)
             self.processor_dict[name + "::tokenizer"] = tokenizer
 
-    def fit_numeric_col(self, col, col_values):
+    def fit_numeric_col(self, col, col_series):
         name = col["name"]
         feature_type = col["type"]
         feature_source = col.get("source", "")
@@ -186,10 +178,10 @@
             self.feature_map.features[name]["feature_encoder"] = col["feature_encoder"]
         if "normalizer" in col:
             normalizer = Normalizer(col["normalizer"])
-            normalizer.fit(col_values)
+            normalizer.fit(col_series.dropna().values)
             self.processor_dict[name + "::normalizer"] = normalizer
 
-    def fit_categorical_col(self, col, col_values, min_categr_count=1, num_buckets=10):
+    def fit_categorical_col(self, col, col_series, min_categr_count=1, num_buckets=10):
         name = col["name"]
         feature_type = col["type"]
         feature_source = col.get("source", "")
@@ -206,7 +198,7 @@
             tokenizer = Tokenizer(min_freq=min_categr_count,
                                   na_value=col.get("fill_na", ""),
                                   remap=col.get("remap", True))
-            tokenizer.fit_on_texts(col_values)
+            tokenizer.fit_on_texts(col_series)
             if "share_embedding" in col:
                 self.feature_map.features[name]["share_embedding"] = col["share_embedding"]
                 tknzr_name = col["share_embedding"] + "::tokenizer"
@@ -225,20 +217,18 @@
             if category_processor == "quantile_bucket": # transform numeric value to bucket
                 num_buckets = col.get("num_buckets", num_buckets)
                 qtf = sklearn_preprocess.QuantileTransformer(n_quantiles=num_buckets + 1)
-                qtf.fit(col_values)
+                qtf.fit(col_series.values)
                 boundaries = qtf.quantiles_[1:-1]
                 self.feature_map.features[name]["vocab_size"] = num_buckets
                 self.processor_dict[name + "::boundaries"] = boundaries
             elif category_processor == "hash_bucket":
                 num_buckets = col.get("num_buckets", num_buckets)
-                uniques = Counter(col_values)
-                num_buckets = min(num_buckets, len(uniques))
                 self.feature_map.features[name]["vocab_size"] = num_buckets
                 self.processor_dict[name + "::num_buckets"] = num_buckets
             else:
                 raise NotImplementedError("category_processor={} not supported.".format(category_processor))
 
-    def fit_sequence_col(self, col, col_values, min_categr_count=1):
+    def fit_sequence_col(self, col, col_series, min_categr_count=1):
         name = col["name"]
         feature_type = col["type"]
         feature_source = col.get("source", "")
@@ -259,7 +249,7 @@
             tokenizer = Tokenizer(min_freq=min_categr_count, splitter=splitter,
                                   na_value=na_value, max_len=max_len, padding=padding,
                                   remap=col.get("remap", True))
-            tokenizer.fit_on_texts(col_values)
+            tokenizer.fit_on_texts(col_series)
             if "share_embedding" in col:
                 self.feature_map.features[name]["share_embedding"] = col["share_embedding"]
                 tknzr_name = col["share_embedding"] + "::tokenizer"
@@ -275,22 +265,22 @@
                                                  "vocab_size": tokenizer.vocab_size()})
 
     def transform(self, ddf):
-        logging.info("Transform feature columns...")
+        logging.info("Transform feature columns with ID mapping...")
         data_dict = dict()
         for feature, feature_spec in self.feature_map.features.items():
             if feature in ddf.columns:
                 feature_type = feature_spec["type"]
-                col_values = ddf.loc[:, feature].values
+                col_series = ddf[feature]
                 if feature_type == "meta":
                     if feature + "::tokenizer" in self.processor_dict:
                         tokenizer = self.processor_dict[feature + "::tokenizer"]
-                        data_dict[feature] = tokenizer.encode_meta(col_values)
+                        data_dict[feature] = tokenizer.encode_meta(col_series)
                         # Update vocab in tokenizer
                         self.processor_dict[feature + "::tokenizer"] = tokenizer
                     else:
-                        data_dict[feature] = col_values.astype(self.dtype_dict[feature])
+                        data_dict[feature] = col_series.values
                 elif feature_type == "numeric":
-                    col_values = col_values.astype(float)
+                    col_values = col_series.values
                     normalizer = self.processor_dict.get(feature + "::normalizer")
                     if normalizer:
                         col_values = normalizer.transform(col_values)
@@ -298,16 +288,16 @@
                 elif feature_type == "categorical":
                     category_processor = feature_spec.get("category_processor")
                     if category_processor is None:
-                        data_dict[feature] = self.processor_dict.get(feature + "::tokenizer").encode_category(col_values)
+                        data_dict[feature] = self.processor_dict.get(feature + "::tokenizer").encode_category(col_series)
                     elif category_processor == "numeric_bucket":
                         raise NotImplementedError
                     elif category_processor == "hash_bucket":
                         raise NotImplementedError
                 elif feature_type == "sequence":
-                    data_dict[feature] = self.processor_dict.get(feature + "::tokenizer").encode_sequence(col_values)
+                    data_dict[feature] = self.processor_dict.get(feature + "::tokenizer").encode_sequence(col_series)
         for label in self.feature_map.labels:
             if label in ddf.columns:
-                data_dict[label] = ddf.loc[:, label].values
+                data_dict[label] = ddf[label].values
         return data_dict
 
     def load_pickle(self, pickle_file=None):
@@ -335,6 +325,6 @@ def save_vocab(self, vocab_file):
         with open(vocab_file, "w") as fd:
             fd.write(json.dumps(vocab, indent=4))
 
-    def copy_from(self, ddf, src_name):
-        return ddf[src_name]
-
+    def copy_from(self, ddf, name, src_name):
+        ddf = ddf.with_columns(pl.col(src_name).alias(name))
+        return ddf
diff --git a/fuxictr/preprocess/normalizer.py b/fuxictr/preprocess/normalizer.py
index a13d388..b34bf39 100644
--- a/fuxictr/preprocess/normalizer.py
+++ b/fuxictr/preprocess/normalizer.py
@@ -33,8 +33,7 @@ def __init__(self, normalizer):
 
     def fit(self, X):
         if not self.callable:
-            null_index = np.isnan(X)
-            self.normalizer.fit(X[~null_index].reshape(-1, 1))
+            self.normalizer.fit(X.reshape(-1, 1))
 
     def transform(self, X):
         if self.callable:
diff --git a/fuxictr/preprocess/tokenizer.py b/fuxictr/preprocess/tokenizer.py
index 90aed37..c23bd59 100644
--- a/fuxictr/preprocess/tokenizer.py
+++ b/fuxictr/preprocess/tokenizer.py
@@ -16,16 +16,17 @@
 
 from collections import Counter
 import numpy as np
-import pandas as pd
 import h5py
 from tqdm import tqdm
+import polars as pl
 from keras_preprocessing.sequence import pad_sequences
 from concurrent.futures import ProcessPoolExecutor, as_completed
+import multiprocessing as mp
 
 
 class Tokenizer(object):
     def __init__(self, max_features=None, na_value="", min_freq=1, splitter=None, remap=True,
-                 lower=False, max_len=0, padding="pre", num_workers=8):
+                 lower=False, max_len=0, padding="pre"):
         self._max_features = max_features
         self._na_value = na_value
         self._min_freq = min_freq
@@ -34,24 +35,23 @@
         self.vocab = dict()
         self.max_len = max_len
         self.padding = padding
-        self.num_workers = num_workers
         self.remap = remap
 
-    def fit_on_texts(self, texts):
+    def fit_on_texts(self, series):
+        max_len = 0
         word_counts = Counter()
-        if self._splitter is not None: # for sequence
-            max_len = 0
-            with ProcessPoolExecutor(max_workers=self.num_workers) as executor:
-                chunks = np.array_split(texts, self.num_workers)
-                tasks = [executor.submit(count_tokens, chunk, self._splitter) for chunk in chunks]
-                for future in tqdm(as_completed(tasks), total=len(tasks)):
-                    block_word_counts, block_max_len = future.result()
-                    word_counts.update(block_word_counts)
-                    max_len = max(max_len, block_max_len)
-            if self.max_len == 0: # if argument max_len not given
-                self.max_len = max_len
-        else:
-            word_counts = Counter(list(texts))
+        with ProcessPoolExecutor(max_workers=(mp.cpu_count() // 2)) as executor:
+            chunk_size = 1000000
+            tasks = []
+            for idx in range(0, len(series), chunk_size):
+                data_chunk = series.iloc[idx: (idx + chunk_size)]
+                tasks.append(executor.submit(count_tokens, data_chunk, self._splitter))
+            for future in tqdm(as_completed(tasks), total=len(tasks)):
+                chunk_word_counts, chunk_max_len = future.result()
+                word_counts.update(chunk_word_counts)
+                max_len = max(max_len, chunk_max_len)
+        if self.max_len == 0: # if argument max_len not given
+            self.max_len = max_len
         self.build_vocab(word_counts)
 
     def build_vocab(self, word_counts):
@@ -101,30 +101,28 @@ def update_vocab(self, word_list):
         if new_words > 0:
             self.vocab["__OOV__"] = self.vocab_size()
 
-    def encode_meta(self, values):
-        word_counts = Counter(list(values))
+    def encode_meta(self, series):
+        word_counts = dict(series.value_counts())
         if len(self.vocab) == 0:
             self.build_vocab(word_counts)
         else: # for considering meta data in test data
            self.update_vocab(word_counts.keys())
-        meta_values = [self.vocab.get(x, self.vocab["__OOV__"]) for x in values]
-        return np.array(meta_values)
+        series = series.map(lambda x: self.vocab.get(x, self.vocab["__OOV__"]))
+        return series.values
 
-    def encode_category(self, categories):
-        category_indices = [self.vocab.get(x, self.vocab["__OOV__"]) for x in categories]
-        return np.array(category_indices)
+    def encode_category(self, series):
+        series = series.map(lambda x: self.vocab.get(x, self.vocab["__OOV__"]))
+        return series.values
 
-    def encode_sequence(self, texts):
-        sequence_list = []
-        for text in texts:
-            if pd.isnull(text) or text == '':
-                sequence_list.append([])
-            else:
-                sequence_list.append([self.vocab.get(x, self.vocab["__OOV__"]) if x != self._na_value \
-                                      else self.vocab["__PAD__"] for x in text.split(self._splitter)])
-        sequence_list = pad_sequences(sequence_list, maxlen=self.max_len, value=self.vocab["__PAD__"],
-                                      padding=self.padding, truncating=self.padding)
-        return np.array(sequence_list)
+    def encode_sequence(self, series):
+        series = series.map(
+            lambda text: [self.vocab.get(x, self.vocab["__OOV__"]) if x != self._na_value \
+                          else self.vocab["__PAD__"] for x in text.split(self._splitter)]
+        )
+        seqs = pad_sequences(series.to_list(), maxlen=self.max_len,
+                             value=self.vocab["__PAD__"],
+                             padding=self.padding, truncating=self.padding)
+        return np.array(seqs)
 
     def load_pretrained_vocab(self, feature_dtype, pretrain_path, expand_vocab=True):
         if pretrain_path.endswith(".h5"):
@@ -144,12 +142,12 @@
                     vocab_size += 1
 
 
-def count_tokens(texts, splitter):
-    word_counts = Counter()
+def count_tokens(series, splitter=None):
     max_len = 0
-    for text in texts:
-        text_split = text.split(splitter)
-        max_len = max(max_len, len(text_split))
-        for token in text_split:
-            word_counts[token] += 1
-    return word_counts, max_len
+    if splitter is not None: # for sequence
+        series = series.map(lambda text: text.split(splitter))
+        max_len = series.str.len().max()
+        word_counts = series.explode().value_counts()
+    else:
+        word_counts = series.value_counts()
+    return dict(word_counts), max_len
diff --git a/fuxictr/pytorch/dataloaders/npz_block_dataloader.py b/fuxictr/pytorch/dataloaders/npz_block_dataloader.py
index 37bf697..eb81592 100644
--- a/fuxictr/pytorch/dataloaders/npz_block_dataloader.py
+++ b/fuxictr/pytorch/dataloaders/npz_block_dataloader.py
@@ -74,6 +74,8 @@ def __init__(self, feature_map, data_path, batch_size=32, shuffle=False,
         datapipe = BlockDataPipe(data_blocks, feature_map)
         if shuffle:
             datapipe = datapipe.shuffle(buffer_size=buffer_size)
+        else:
+            num_workers = 1 # multiple workers cannot keep the order of data reading
         super(NpzBlockDataLoader, self).__init__(dataset=datapipe, batch_size=batch_size,
                                                  num_workers=num_workers)
 
diff --git a/fuxictr/version.py b/fuxictr/version.py
index 6daf179..1d42921 100644
--- a/fuxictr/version.py
+++ b/fuxictr/version.py
@@ -1 +1 @@
-__version__="2.2.1"
+__version__="2.2.2"
diff --git a/setup.py b/setup.py
index 711fb38..df81f7b 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 
 setuptools.setup(
     name="fuxictr",
-    version="2.2.1",
+    version="2.2.2",
     author="RECZOO",
     author_email="reczoo@users.noreply.github.com",
     description="A configurable, tunable, and reproducible library for CTR prediction",
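
The sketches below illustrate the main changes in this patch with small, self-contained snippets; any name that does not appear in the diff is made up for illustration.

transform_block() and transform() no longer take a preprocess flag: build_dataset() now calls feature_encoder.preprocess() explicitly on the train, valid and test frames before handing them to transform(), which still slices the frame into blocks and writes one part_xxxxx.npz file per block through a multiprocessing pool. A minimal sketch of just the block slicing and npz naming, with the pool and the actual feature encoding omitted and the helper name and arrays invented for the example:

import os
import numpy as np

def save_npz_blocks(darray_dict, out_dir, block_size):
    # Slice already-encoded arrays into blocks and write part_00000.npz, part_00001.npz, ...
    os.makedirs(out_dir, exist_ok=True)
    num_rows = len(next(iter(darray_dict.values())))
    for block_id, start in enumerate(range(0, num_rows, block_size)):
        block = {k: v[start:start + block_size] for k, v in darray_dict.items()}
        np.savez(os.path.join(out_dir, "part_{:05d}.npz".format(block_id)), **block)

# Toy encoded arrays; in the library they come from feature_encoder.transform(df_block).
save_npz_blocks({"user_id": np.array([1, 2, 3, 4, 5]),
                 "label": np.array([1.0, 0.0, 1.0, 0.0, 1.0])},
                out_dir="train_demo", block_size=2)
print(sorted(os.listdir("train_demo")))  # ['part_00000.npz', 'part_00001.npz', 'part_00002.npz']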
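
On the feature_processor side, read_csv() now returns a lazy polars scan instead of an eager pandas DataFrame, preprocess() fills nulls, applies optional per-column preprocess functions, casts dtypes and keeps only the active columns through polars expressions, and fit() collects one column at a time back into pandas for the tokenizers and normalizers. A minimal sketch of that lazy flow, assuming a toy file with illustrative column names and dtypes:

import polars as pl

# Toy CSV written here only so the sketch runs end to end.
with open("toy.csv", "w") as f:
    f.write("label,user_id,tags\n1,u1,a^b\n0,u2,\n1,u1,c\n")

dtype_dict = {"label": pl.Float32, "user_id": pl.Utf8, "tags": pl.Utf8}

# Lazy scan, as read_csv() now does: nothing is materialized yet.
ddf = pl.scan_csv("toy.csv", separator=",", dtypes=dtype_dict, low_memory=False)

# preprocess()-style expressions: fill nulls, cast, keep only the active columns.
ddf = ddf.with_columns(pl.col("tags").fill_null(""))
ddf = ddf.with_columns(pl.col("label").cast(pl.Float32))
ddf = ddf.select(["label", "user_id", "tags"])

# fit() pulls one column at a time back into pandas for the tokenizers.
user_series = ddf.select("user_id").collect().to_series().to_pandas()
print(user_series.tolist())  # ['u1', 'u2', 'u1']

Series.to_pandas() relies on pyarrow, and because the scan stays lazy, only the selected column is materialized when fit() collects it.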
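
Normalizer.fit() no longer masks NaNs itself: fit_numeric_col() passes col_series.dropna().values, so missing values are dropped on the pandas side before the sklearn scaler is fitted. A minimal sketch of that contract, with StandardScaler standing in for whatever col["normalizer"] names:

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

col_series = pd.Series([1.0, np.nan, 3.0, 5.0])

scaler = StandardScaler()
# NaNs are dropped before fitting, as fit_numeric_col() now does.
scaler.fit(col_series.dropna().values.reshape(-1, 1))

# sklearn scalers ignore NaN in fit and propagate it in transform,
# so the full column can still be transformed later.
print(scaler.transform(col_series.values.reshape(-1, 1)).ravel())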
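
Tokenizer.fit_on_texts() now receives a pandas Series, slices it into fixed-size chunks (1,000,000 rows in the library) and counts tokens per chunk in a ProcessPoolExecutor, while count_tokens() replaces the Python loop with vectorized pandas calls (map, str.len, explode, value_counts). A small sketch of the per-chunk counting and merge, with the process pool left out, a tiny chunk size, and '^' assumed as the sequence splitter:

from collections import Counter
import pandas as pd

def count_tokens(series, splitter=None):
    # Same shape as the new count_tokens(): vectorized token counting for one chunk.
    max_len = 0
    if splitter is not None:  # sequence feature
        series = series.map(lambda text: text.split(splitter))
        max_len = int(series.str.len().max())
        word_counts = series.explode().value_counts()
    else:  # categorical feature
        word_counts = series.value_counts()
    return dict(word_counts), max_len

series = pd.Series(["a^b^c", "b^c", "a"])
chunk_size = 2  # demo value; fit_on_texts() uses 1000000
word_counts, max_len = Counter(), 0
for start in range(0, len(series), chunk_size):
    chunk_counts, chunk_max_len = count_tokens(series.iloc[start:start + chunk_size], splitter="^")
    word_counts.update(chunk_counts)
    max_len = max(max_len, chunk_max_len)
print(dict(word_counts), max_len)  # a, b and c each occur twice; max_len is 3

Because encode_sequence() no longer special-cases null or empty rows, it relies on preprocess() having already filled nulls before the column reaches the tokenizer.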