
Commit 26c4423

Updates:
1. Allow loading the pretrained h5 file directly in PretrainedEmbedding (skipping the key-mapping step in preprocessing)
2. Allow data_path to be a directory path
xpai committed Oct 25, 2023
1 parent 60dc5f8 commit 26c4423
Showing 9 changed files with 69 additions and 82 deletions.
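
For context, the pretrained h5 file is now consumed in its raw form by PretrainedEmbedding: a "key" dataset holding the token ids and a "value" dataset holding the corresponding vectors, with no per-feature remapping during preprocessing. A minimal sketch of producing such a file (file name, ids, and dimension are hypothetical):

import h5py
import numpy as np

keys = np.array([b"item_1", b"item_2", b"item_3"])          # hypothetical token ids
values = np.random.normal(size=(3, 64)).astype(np.float32)  # hypothetical 64-d vectors, row i pairs with keys[i]
with h5py.File("item_pretrain.h5", "w") as hf:
    hf.create_dataset("key", data=keys)
    hf.create_dataset("value", data=values)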
14 changes: 4 additions & 10 deletions fuxictr/preprocess/build_dataset.py
@@ -124,13 +124,7 @@ def build_dataset(feature_encoder, train_data=None, valid_data=None, test_data=N
logging.info("Transform csv data to h5 done.")

# Return processed data splits
if data_block_size > 0:
return os.path.join(feature_encoder.data_dir, "train/*.h5"), \
os.path.join(feature_encoder.data_dir, "valid/*.h5"), \
os.path.join(feature_encoder.data_dir, "test/*.h5") if (
test_data or test_size > 0) else None
else:
return os.path.join(feature_encoder.data_dir, 'train.h5'), \
os.path.join(feature_encoder.data_dir, 'valid.h5'), \
os.path.join(feature_encoder.data_dir, 'test.h5') if (
test_data or test_size > 0) else None
return os.path.join(feature_encoder.data_dir, "train"), \
os.path.join(feature_encoder.data_dir, "valid"), \
os.path.join(feature_encoder.data_dir, "test") if (
test_data or test_size > 0) else None
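
build_dataset now returns the split directories rather than concrete .h5 paths; downstream, the block dataloader globs *.h5 inside each directory and the plain h5 dataloader appends ".h5" when it is missing (both changes appear in the hunks below). A small sketch of how a returned path resolves, with a hypothetical data_dir:

import glob
import os

data_dir = "data/tiny_h5"                     # hypothetical dataset directory
train_path = os.path.join(data_dir, "train")  # what build_dataset now returns

train_blocks = glob.glob(train_path + "/*.h5")  # block dataloader: expand directory into block files
single_path = train_path if train_path.endswith(".h5") else train_path + ".h5"  # plain dataloader fallback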
11 changes: 5 additions & 6 deletions fuxictr/preprocess/feature_processor.py
@@ -23,6 +23,7 @@
import logging
import json
import re
import shutil
import sklearn.preprocessing as sklearn_preprocess
from fuxictr.features import FeatureMap
from .tokenizer import Tokenizer
@@ -127,15 +128,13 @@ def fit(self, train_ddf, min_categr_count=1, num_buckets=10, **kwargs):
logging.info("Loading pretrained embedding: " + name)
if "pretrain_dim" in col:
self.feature_map.features[name]["pretrain_dim"] = col["pretrain_dim"]
self.feature_map.features[name]["pretrained_emb"] = "pretrained_emb.h5"
shutil.copytree(col["pretrained_emb"],
os.path.join(self.data_dir, os.path.basename(col["pretrained_emb"])))
self.feature_map.features[name]["pretrained_emb"] = os.path.basename(col["pretrained_emb"])
self.feature_map.features[name]["freeze_emb"] = col.get("freeze_emb", True)
self.feature_map.features[name]["pretrain_usage"] = col.get("pretrain_usage", "init")
tokenizer = self.processor_dict[name + "::tokenizer"]
tokenizer.load_pretrained_embedding(name,
self.dtype_dict[name],
col["pretrained_emb"],
os.path.join(self.data_dir, "pretrained_emb.h5"),
freeze_emb=col.get("freeze_emb", True))
tokenizer.load_pretrained_vocab(self.dtype_dict[name], col["pretrained_emb"])
self.processor_dict[name + "::tokenizer"] = tokenizer
self.feature_map.features[name]["vocab_size"] = tokenizer.vocab_size()

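The fields read in this hunk come from the feature column config; a hypothetical column spec showing the pretrained-embedding options involved (name and path are placeholders):

col = {
    "name": "item_id",
    "active": True,
    "dtype": "str",
    "type": "categorical",
    "pretrained_emb": "../pretrain/item_emb.h5",  # hypothetical path to the key/value h5 file
    "pretrain_dim": 64,                           # optional; defaults to the model embedding_dim
    "freeze_emb": True,                           # optional; defaults to True
    "pretrain_usage": "init",                     # optional; defaults to "init"
}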
44 changes: 7 additions & 37 deletions fuxictr/preprocess/tokenizer.py
@@ -15,15 +15,10 @@
# =========================================================================

from collections import Counter
import itertools
import numpy as np
import pandas as pd
import h5py
import pickle
import os
from tqdm import tqdm
import logging
import sklearn.preprocessing as sklearn_preprocess
from keras_preprocessing.sequence import pad_sequences
from concurrent.futures import ProcessPoolExecutor, as_completed

@@ -106,12 +101,6 @@ def update_vocab(self, word_list):
if new_words > 0:
self.vocab["__OOV__"] = self.vocab_size()

def expand_pretrain_vocab(self, word_list):
# Do not update OOV index here
for word in word_list:
if word not in self.vocab:
self.vocab[word] = self.vocab_size()

def encode_meta(self, values):
word_counts = Counter(list(values))
if len(self.vocab) == 0:
@@ -137,35 +126,16 @@ def encode_sequence(self, texts):
padding=self.padding, truncating=self.padding)
return np.array(sequence_list)

def load_pretrained_embedding(self, feature_name, feature_dtype, pretrain_path,
output_path, freeze_emb=True, expand_pretrain_vocab=True):
def load_pretrained_vocab(self, feature_dtype, pretrain_path, expand_vocab=True):
with h5py.File(pretrain_path, 'r') as hf:
keys = hf["key"][:]
keys = keys.astype(feature_dtype) # in case mismatch of dtype between int and str
pretrained_vocab = dict(zip(keys, range(len(keys))))
pretrained_emb = hf["value"][:]
# update vocab with pretrained keys, in case new token ids appear in validation or test set
if expand_pretrain_vocab:
self.expand_pretrain_vocab(pretrained_vocab.keys())

logging.info("{}\'s pretrained_emb shape: {}".format(feature_name, pretrained_emb.shape))
embedding_dim = pretrained_emb.shape[1]
if freeze_emb:
embedding_matrix = np.zeros((self.vocab_size(), embedding_dim))
else:
embedding_matrix = np.random.normal(loc=0, scale=1.e-4, size=(self.vocab_size(), embedding_dim))
embedding_matrix[self.vocab["__PAD__"], :] = 0. # set as zero embedding for PAD
for word in pretrained_vocab.keys():
if word in self.vocab:
embedding_matrix[self.vocab[word]] = pretrained_emb[pretrained_vocab[word]]
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with h5py.File(output_path, 'a') as hf: # Add different embeddings to a single h5 file
hf.create_dataset(feature_name, data=embedding_matrix)

# def load_vocab(self, vocab_file):
# with open(vocab_file, 'r') as fid:
# word_counts = json.load(fid)
# self.build_vocab(word_counts)
# Update vocab with pretrained keys in case new tokens appear in validation or test set
# Do not update OOV index here since it is used in PretrainedEmbedding
if expand_vocab:
for word in keys:
if word not in self.vocab:
self.vocab[word] = self.vocab_size()


def count_tokens(texts, splitter):
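With the embedding-matrix construction moved into PretrainedEmbedding, load_pretrained_vocab only extends the tokenizer vocabulary with pretrained keys unseen in training, leaving existing indices and the __OOV__ index untouched. A toy illustration of the expansion (values hypothetical; len(vocab) stands in for self.vocab_size()):

vocab = {"__PAD__": 0, "itemA": 1, "itemB": 2, "__OOV__": 3}
pretrained_keys = ["itemB", "itemC"]    # hypothetical keys from the h5 "key" dataset

for word in pretrained_keys:
    if word not in vocab:
        vocab[word] = len(vocab)        # appended after the current vocab; __OOV__ stays at 3

# vocab -> {"__PAD__": 0, "itemA": 1, "itemB": 2, "__OOV__": 3, "itemC": 4}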
6 changes: 3 additions & 3 deletions fuxictr/pytorch/dataloaders/h5_block_dataloader.py
@@ -86,22 +86,22 @@ def __init__(self, feature_map, stage="both", train_data=None, valid_data=None,
test_gen = None
self.stage = stage
if stage in ["both", "train"]:
train_blocks = glob.glob(train_data)
train_blocks = glob.glob(train_data + "/*.h5")
assert len(train_blocks) > 0, "invalid data files or paths."
if len(train_blocks) > 1:
train_blocks.sort(key=lambda x: int(x.split("_")[-1].split(".")[0])) # "xx_part_1.h5"
train_gen = DataLoader(feature_map, train_blocks, batch_size=batch_size, shuffle=shuffle, verbose=verbose, **kwargs)
logging.info("Train samples: total/{:d}, blocks/{:d}".format(train_gen.num_samples, train_gen.num_blocks))
if valid_data:
valid_blocks = glob.glob(valid_data)
valid_blocks = glob.glob(valid_data + "/*.h5")
if len(valid_blocks) > 1:
valid_blocks.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
valid_gen = DataLoader(feature_map, valid_blocks, batch_size=batch_size, shuffle=False, verbose=verbose, **kwargs)
logging.info("Validation samples: total/{:d}, blocks/{:d}".format(valid_gen.num_samples, valid_gen.num_blocks))

if stage in ["both", "test"]:
if test_data:
test_blocks = glob.glob(test_data)
test_blocks = glob.glob(test_data + "/*.h5")
assert len(test_blocks) > 0, "invalid data files or paths."
if len(test_blocks) > 1:
test_blocks.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
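The sort key above assumes block files named like "xx_part_N.h5", so blocks order numerically rather than lexicographically; a quick sketch with hypothetical file names:

blocks = ["train/train_part_10.h5", "train/train_part_2.h5", "train/train_part_1.h5"]
blocks.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
# -> ["train/train_part_1.h5", "train/train_part_2.h5", "train/train_part_10.h5"]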
2 changes: 2 additions & 0 deletions fuxictr/pytorch/dataloaders/h5_dataloader.py
@@ -49,6 +49,8 @@ def load_data_array(self, data_path):

class DataLoader(data.DataLoader):
def __init__(self, feature_map, data_path, batch_size=32, shuffle=False, num_workers=1, **kwargs):
if not data_path.endswith(".h5"):
data_path += ".h5"
self.dataset = Dataset(feature_map, data_path)
super(DataLoader, self).__init__(dataset=self.dataset, batch_size=batch_size,
shuffle=shuffle, num_workers=num_workers)
9 changes: 6 additions & 3 deletions fuxictr/pytorch/layers/embeddings/feature_embedding.py
@@ -86,13 +86,16 @@ def __init__(self,
self.embedding_layers[feature] = nn.Linear(1, feat_dim, bias=False)
elif feature_spec["type"] in ["categorical", "sequence"]:
if use_pretrain and "pretrained_emb" in feature_spec:
pretrained_path = os.path.join(feature_map.data_dir,
feature_spec["pretrained_emb"])
pretrain_path = os.path.join(feature_map.data_dir,
feature_spec["pretrained_emb"])
vocab_path = os.path.join(feature_map.data_dir,
"feature_vocab.json")
pretrain_dim = feature_spec.get("pretrain_dim", feat_dim)
pretrain_usage = feature_spec.get("pretrain_usage", "init")
self.embedding_layers[feature] = PretrainedEmbedding(feature,
feature_spec,
pretrained_path,
pretrain_path,
vocab_path,
feat_dim,
pretrain_dim,
pretrain_usage)
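The new vocab_path points at a feature_vocab.json that is assumed to be written during preprocessing and to map each feature's tokens to their indices (the writer is not shown in this diff). A hypothetical view of what load_feature_vocab expects after json.load():

feature_vocab = {
    "item_id": {"__PAD__": 0, "itemA": 1, "itemB": 2, "__OOV__": 3, "itemC": 4},
    "user_id": {"__PAD__": 0, "u1": 1, "u2": 2, "__OOV__": 3},
}
item_vocab = feature_vocab["item_id"]  # what load_feature_vocab(vocab_path, "item_id") would return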
61 changes: 40 additions & 21 deletions fuxictr/pytorch/layers/embeddings/pretrained_embedding.py
@@ -19,14 +19,18 @@
from torch import nn
import h5py
import os
import io
import json
import numpy as np
import logging


class PretrainedEmbedding(nn.Module):
def __init__(self,
feature_name,
feature_spec,
pretrained_path,
pretrain_path,
vocab_path,
embedding_dim,
pretrain_dim,
pretrain_usage="init"):
@@ -40,11 +44,10 @@ def __init__(self,
padding_idx = feature_spec.get("padding_idx", None)
self.oov_idx = feature_spec["oov_idx"]
self.freeze_emb = feature_spec["freeze_emb"]
embedding_matrix = nn.Embedding(feature_spec["vocab_size"],
pretrain_dim,
padding_idx=padding_idx)
self.pretrain_embedding = self.load_pretrained_embedding(embedding_matrix,
pretrained_path,
self.pretrain_embedding = self.load_pretrained_embedding(feature_spec["vocab_size"],
pretrain_dim,
pretrain_path,
vocab_path,
feature_name,
freeze=self.freeze_emb,
padding_idx=padding_idx)
@@ -63,23 +66,39 @@ def reset_parameters(self, embedding_initializer):
nn.init.zeros_(self.id_embedding.weight) # set oov token embeddings to zeros
embedding_initializer(self.id_embedding.weight[1:self.oov_idx, :])

def get_pretrained_embedding(self, pretrained_path, feature_name):
with h5py.File(pretrained_path, 'r') as hf:
embeddings = hf[feature_name][:]
return embeddings
def get_pretrained_embedding(self, pretrain_path):
with h5py.File(pretrain_path, 'r') as hf:
keys = hf["key"][:]
embeddings = hf["value"][:]
logging.info("Loading pretrained_emb: {}, shape: {}".format(pretrain_path, embeddings.shape))
return keys, embeddings

def load_pretrained_embedding(self, embedding_matrix, pretrained_path, feature_name,
freeze=False, padding_idx=None):
embeddings = self.get_pretrained_embedding(pretrained_path, feature_name)
if padding_idx is not None:
embeddings[padding_idx] = np.zeros(embeddings.shape[-1])
assert embeddings.shape[-1] == embedding_matrix.embedding_dim, \
"{}\'s pretrain_dim is not correct.".format(feature_name)
embeddings = torch.from_numpy(embeddings).float()
embedding_matrix.weight = torch.nn.Parameter(embeddings)
def load_feature_vocab(self, vocab_path, feature_name):
with io.open(vocab_path, "r", encoding="utf-8") as fd:
vocab = json.load(fd)
return vocab[feature_name]

def load_pretrained_embedding(self, vocab_size, pretrain_dim, pretrain_path, vocab_path,
feature_name, freeze=False, padding_idx=None):
embedding_layer = nn.Embedding(vocab_size,
pretrain_dim,
padding_idx=padding_idx)
if freeze:
embedding_matrix = np.zeros((vocab_size, pretrain_dim))
else:
embedding_matrix = np.random.normal(loc=0, scale=1.e-4, size=(vocab_size, pretrain_dim))
if padding_idx:
embedding_matrix[padding_idx, :] = np.zeros(pretrain_dim) # set as zero for PAD
keys, embeddings = self.get_pretrained_embedding(pretrain_path)
assert embeddings.shape[-1] == pretrain_dim, f"pretrain_dim={pretrain_dim} not correct."
vocab = self.load_feature_vocab(vocab_path, feature_name)
for idx, word in enumerate(keys):
if word in vocab:
embedding_matrix[vocab[word]] = embeddings[idx]
embedding_layer.weight = torch.nn.Parameter(torch.from_numpy(embedding_matrix).float())
if freeze:
embedding_matrix.weight.requires_grad = False
return embedding_matrix
embedding_layer.weight.requires_grad = False
return embedding_layer

def forward(self, inputs):
mask = (inputs <= self.oov_idx).float()
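Putting the pieces together, load_pretrained_embedding now assembles the matrix at model-construction time: rows whose tokens appear in the pretrained key set take the pretrained vectors, and the remaining rows stay zero (freeze_emb=True) or small random noise (trainable). A toy walk-through under the same hypothetical names as above:

import numpy as np

vocab = {"__PAD__": 0, "itemA": 1, "itemB": 2, "__OOV__": 3, "itemC": 4}  # hypothetical feature vocab
keys = ["itemB", "itemC"]                                                 # hypothetical pretrained keys
embeddings = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)         # hypothetical 2-d pretrained vectors

embedding_matrix = np.zeros((len(vocab), embeddings.shape[-1]))           # freeze_emb=True case
for idx, word in enumerate(keys):
    if word in vocab:
        embedding_matrix[vocab[word]] = embeddings[idx]
# rows 2 ("itemB") and 4 ("itemC") hold pretrained vectors; all other rows remain zero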
2 changes: 1 addition & 1 deletion fuxictr/version.py
@@ -1 +1 @@
__version__="2.1.0"
__version__="2.1.1"
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setuptools.setup(
name="fuxictr",
version="2.1.0",
version="2.1.1",
author="fuxictr",
author_email="fuxictr@users.noreply.github.com",
description="A configurable, tunable, and reproducible library for CTR prediction",
