diff --git a/benchmark/run_expid.py b/benchmark/run_expid.py
index 391f628..6e0b4ab 100644
--- a/benchmark/run_expid.py
+++ b/benchmark/run_expid.py
@@ -74,10 +74,10 @@
     gc.collect()

     logging.info('******** Test evaluation ********')
-    test_gen = H5DataLoader(feature_map, stage='test', **params).make_iterator()
     test_result = {}
-    if test_gen:
-        test_result = model.evaluate(test_gen)
+    if params["test_data"]:
+        test_gen = H5DataLoader(feature_map, stage='test', **params).make_iterator()
+        test_result = model.evaluate(test_gen)

     result_filename = Path(args['config']).name.replace(".yaml", "") + '.csv'
     with open(result_filename, 'a+') as fw:
diff --git a/fuxictr/preprocess/feature_processor.py b/fuxictr/preprocess/feature_processor.py
index 1183f95..4cd317d 100644
--- a/fuxictr/preprocess/feature_processor.py
+++ b/fuxictr/preprocess/feature_processor.py
@@ -176,6 +176,8 @@ def fit_categorical_col(self, col, col_values, min_categr_count=1, num_buckets=1
             self.feature_map.features[name]["embedding_dim"] = col["embedding_dim"]
         if "emb_output_dim" in col:
             self.feature_map.features[name]["emb_output_dim"] = col["emb_output_dim"]
+        if "pretrain_dim" in col:
+            self.feature_map.features[name]["pretrain_dim"] = col["pretrain_dim"]
         if "category_processor" not in col:
             tokenizer = Tokenizer(min_freq=min_categr_count,
                                   na_value=col.get("fill_na", ""),
@@ -194,6 +196,7 @@ def fit_categorical_col(self, col, col_values, min_categr_count=1, num_buckets=1
                 logging.info("Loading pretrained embedding: " + name)
                 self.feature_map.features[name]["pretrained_emb"] = "pretrained_emb.h5"
                 self.feature_map.features[name]["freeze_emb"] = col.get("freeze_emb", True)
+                self.feature_map.features[name]["pretrain_usage"] = col.get("pretrain_usage", "init")
                 tokenizer.load_pretrained_embedding(name,
                                                     self.dtype_dict[name],
                                                     col["pretrained_emb"],
@@ -236,6 +239,8 @@ def fit_sequence_col(self, col, col_values, min_categr_count=1):
             self.feature_map.features[name]["embedding_dim"] = col["embedding_dim"]
         if "emb_output_dim" in col:
             self.feature_map.features[name]["emb_output_dim"] = col["emb_output_dim"]
+        if "pretrain_dim" in col:
+            self.feature_map.features[name]["pretrain_dim"] = col["pretrain_dim"]
         splitter = col.get("splitter")
         na_value = col.get("fill_na", "")
         max_len = col.get("max_len", 0)
@@ -257,6 +262,7 @@ def fit_sequence_col(self, col, col_values, min_categr_count=1):
             logging.info("Loading pretrained embedding: " + name)
             self.feature_map.features[name]["pretrained_emb"] = "pretrained_emb.h5"
             self.feature_map.features[name]["freeze_emb"] = col.get("freeze_emb", True)
+            self.feature_map.features[name]["pretrain_usage"] = col.get("pretrain_usage", "init")
             tokenizer.load_pretrained_embedding(name,
                                                 self.dtype_dict[name],
                                                 col["pretrained_emb"],
diff --git a/fuxictr/pytorch/layers/embeddings/__init__.py b/fuxictr/pytorch/layers/embeddings/__init__.py
index b73c62d..15a3633 100644
--- a/fuxictr/pytorch/layers/embeddings/__init__.py
+++ b/fuxictr/pytorch/layers/embeddings/__init__.py
@@ -1,2 +1,2 @@
 from .feature_embedding import *
-
+from .pretrained_embedding import *
diff --git a/fuxictr/pytorch/layers/embeddings/feature_embedding.py b/fuxictr/pytorch/layers/embeddings/feature_embedding.py
index 9a741ab..0032c10 100644
--- a/fuxictr/pytorch/layers/embeddings/feature_embedding.py
+++ b/fuxictr/pytorch/layers/embeddings/feature_embedding.py
@@ -17,10 +17,10 @@


 import torch
 from torch import nn
-import h5py
 import os
 import numpy as np
 from collections import OrderedDict
+from .pretrained_embedding import PretrainedEmbedding
 from fuxictr.pytorch.torch_utils import get_initializer
 from fuxictr.pytorch import layers
@@ -69,11 +69,11 @@ def __init__(self,
         for feature, feature_spec in self._feature_map.features.items():
             if self.is_required(feature):
                 if not (use_pretrain and use_sharing) and embedding_dim == 1:
-                    feat_emb_dim = 1 # in case for LR
+                    feat_dim = 1 # in case for LR
                     if feature_spec["type"] == "sequence":
                         self.feature_encoders[feature] = layers.MaskedSumPooling()
                 else:
-                    feat_emb_dim = feature_spec.get("embedding_dim", embedding_dim)
+                    feat_dim = feature_spec.get("embedding_dim", embedding_dim)
                     if feature_spec.get("feature_encoder", None):
                         self.feature_encoders[feature] = self.get_feature_encoder(feature_spec["feature_encoder"])

@@ -83,31 +83,24 @@ def __init__(self,
                     continue

                 if feature_spec["type"] == "numeric":
-                    self.embedding_layers[feature] = nn.Linear(1, feat_emb_dim, bias=False)
-                elif feature_spec["type"] == "categorical":
-                    padding_idx = feature_spec.get("padding_idx", None)
-                    embedding_matrix = nn.Embedding(feature_spec["vocab_size"],
-                                                    feat_emb_dim,
-                                                    padding_idx=padding_idx)
-                    if use_pretrain and "pretrained_emb" in feature_spec:
-                        embedding_matrix = self.load_pretrained_embedding(embedding_matrix,
-                                                                          feature_map,
-                                                                          feature,
-                                                                          freeze=feature_spec["freeze_emb"],
-                                                                          padding_idx=padding_idx)
-                    self.embedding_layers[feature] = embedding_matrix
-                elif feature_spec["type"] == "sequence":
-                    padding_idx = feature_spec.get("padding_idx", None)
-                    embedding_matrix = nn.Embedding(feature_spec["vocab_size"],
-                                                    feat_emb_dim,
-                                                    padding_idx=padding_idx)
+                    self.embedding_layers[feature] = nn.Linear(1, feat_dim, bias=False)
+                elif feature_spec["type"] in ["categorical", "sequence"]:
                     if use_pretrain and "pretrained_emb" in feature_spec:
-                        embedding_matrix = self.load_pretrained_embedding(embedding_matrix,
-                                                                          feature_map,
-                                                                          feature,
-                                                                          freeze=feature_spec["freeze_emb"],
-                                                                          padding_idx=padding_idx)
-                    self.embedding_layers[feature] = embedding_matrix
+                        pretrained_path = os.path.join(feature_map.data_dir,
+                                                       feature_spec["pretrained_emb"])
+                        pretrain_dim = feature_spec.get("pretrain_dim", feat_dim)
+                        pretrain_usage = feature_spec.get("pretrain_usage", "init")
+                        self.embedding_layers[feature] = PretrainedEmbedding(feature,
+                                                                             feature_spec,
+                                                                             pretrained_path,
+                                                                             feat_dim,
+                                                                             pretrain_dim,
+                                                                             pretrain_usage)
+                    else:
+                        padding_idx = feature_spec.get("padding_idx", None)
+                        self.embedding_layers[feature] = nn.Embedding(feature_spec["vocab_size"],
+                                                                      feat_dim,
+                                                                      padding_idx=padding_idx)
         self.reset_parameters()

     def get_feature_encoder(self, encoder):
@@ -148,24 +141,6 @@ def is_required(self, feature):
         else:
             return True

-    def get_pretrained_embedding(self, pretrained_path, feature_name):
-        with h5py.File(pretrained_path, 'r') as hf:
-            embeddings = hf[feature_name][:]
-        return embeddings
-
-    def load_pretrained_embedding(self, embedding_matrix, feature_map, feature_name, freeze=False, padding_idx=None):
-        pretrained_path = os.path.join(feature_map.data_dir, feature_map.features[feature_name]["pretrained_emb"])
-        embeddings = self.get_pretrained_embedding(pretrained_path, feature_name)
-        if padding_idx is not None:
-            embeddings[padding_idx] = np.zeros(embeddings.shape[-1])
-        assert embeddings.shape[-1] == embedding_matrix.embedding_dim, \
-            "{}\'s embedding_dim is not correctly set to match its pretrained_emb shape".format(feature_name)
-        embeddings = torch.from_numpy(embeddings).float()
-        embedding_matrix.weight = torch.nn.Parameter(embeddings)
-        if freeze:
-            embedding_matrix.weight.requires_grad = False
-        return embedding_matrix
-
     def dict2tensor(self, embedding_dict, feature_list=[], feature_source=[], feature_type=[], flatten_emb=False):
         if type(feature_source) != list:
             feature_source = [feature_source]
@@ -214,7 +189,3 @@ def forward(self, inputs, feature_source=[], feature_type=[]):
                 embeddings = self.feature_encoders[feature](embeddings)
             feature_emb_dict[feature] = embeddings
         return feature_emb_dict
-
-
-
-
diff --git a/fuxictr/pytorch/layers/embeddings/pretrained_embedding.py b/fuxictr/pytorch/layers/embeddings/pretrained_embedding.py
new file mode 100644
index 0000000..59adfa7
--- /dev/null
+++ b/fuxictr/pytorch/layers/embeddings/pretrained_embedding.py
@@ -0,0 +1,91 @@
+# =========================================================================
+# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+
+import torch
+from torch import nn
+import h5py
+import os
+import numpy as np
+
+
+class PretrainedEmbedding(nn.Module):
+    def __init__(self,
+                 feature_name,
+                 feature_spec,
+                 pretrained_path,
+                 embedding_dim,
+                 pretrain_dim,
+                 pretrain_usage="init"):
+        """
+        Fuse pretrained embedding with ID embedding
+        :param: pretrain_usage: init/sum/concat
+        """
+        super().__init__()
+        assert pretrain_usage in ["init", "sum", "concat"]
+        self.pretrain_usage = pretrain_usage
+        padding_idx = feature_spec.get("padding_idx", None)
+        embedding_matrix = nn.Embedding(feature_spec["vocab_size"],
+                                        pretrain_dim,
+                                        padding_idx=padding_idx)
+        self.pretrain_embedding = self.load_pretrained_embedding(embedding_matrix,
+                                                                 pretrained_path,
+                                                                 feature_name,
+                                                                 freeze=feature_spec["freeze_emb"],
+                                                                 padding_idx=padding_idx)
+        if pretrain_usage != "init":
+            self.id_embedding = nn.Embedding(feature_spec["vocab_size"],
+                                             embedding_dim,
+                                             padding_idx=padding_idx)
+        if pretrain_usage == "sum" and embedding_dim != pretrain_dim:
+            self.proj_1 = nn.Linear(pretrain_dim, embedding_dim)
+        else:
+            self.proj_1 = None
+        if pretrain_usage == "concat":
+            self.proj_2 = nn.Linear(pretrain_dim + embedding_dim, embedding_dim)
+
+    def get_pretrained_embedding(self, pretrained_path, feature_name):
+        with h5py.File(pretrained_path, 'r') as hf:
+            embeddings = hf[feature_name][:]
+        return embeddings
+
+    def load_pretrained_embedding(self, embedding_matrix, pretrained_path, feature_name, freeze=False, padding_idx=None):
+        embeddings = self.get_pretrained_embedding(pretrained_path, feature_name)
+        if padding_idx is not None:
+            embeddings[padding_idx] = np.zeros(embeddings.shape[-1])
+        assert embeddings.shape[-1] == embedding_matrix.embedding_dim, \
+            "{}\'s pretrain_dim is not correct.".format(feature_name)
+        embeddings = torch.from_numpy(embeddings).float()
+        embedding_matrix.weight = torch.nn.Parameter(embeddings)
+        if freeze:
+            embedding_matrix.weight.requires_grad = False
+        return embedding_matrix
+
+    def forward(self, inputs):
+        pretrain_emb = self.pretrain_embedding(inputs)
+        if self.pretrain_usage == "init":
+            feature_emb = pretrain_emb
+        else:
+            id_emb = self.id_embedding(inputs)
+            if self.pretrain_usage == "sum":
+                if self.proj_1 is not None:
+                    feature_emb = self.proj_1(pretrain_emb) + id_emb
+                else:
+                    feature_emb = pretrain_emb + id_emb
+            if self.pretrain_usage == "concat":
+                feature_emb = torch.cat([pretrain_emb, id_emb], dim=-1)
+                feature_emb = self.proj_2(feature_emb)
+        return feature_emb
diff --git a/fuxictr/version.py b/fuxictr/version.py
index 9c83dcf..7f0f6d7 100644
--- a/fuxictr/version.py
+++ b/fuxictr/version.py
@@ -1 +1 @@
-__version__="2.0.4"
+__version__="2.1.0"
diff --git a/setup.py b/setup.py
index 4664f40..43f64f9 100644
--- a/setup.py
+++ b/setup.py
@@ -5,9 +5,9 @@

 setuptools.setup(
     name="fuxictr",
-    version="2.0.4",
-    author="xue-pai",
-    author_email="xue-pai@users.noreply.github.com",
+    version="2.1.0",
+    author="fuxictr",
+    author_email="fuxictr@users.noreply.github.com",
     description="A configurable, tunable, and reproducible library for CTR prediction",
     long_description=long_description,
     long_description_content_type="text/markdown",
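For reference (not part of the patch): a minimal, hypothetical sketch of exercising the new PretrainedEmbedding module on its own. It assumes a dummy pretrained_emb.h5 whose dataset is keyed by the feature name, matching how get_pretrained_embedding reads it back; the file name, feature name, and dimensions below are made up for illustration, while the constructor arguments and feature_spec keys (vocab_size, freeze_emb, optional padding_idx) follow the code added above.

# Hypothetical usage sketch (not part of the diff): build a tiny pretrained table,
# then fuse it with a trainable ID embedding via pretrain_usage="concat".
import h5py
import numpy as np
import torch

from fuxictr.pytorch.layers.embeddings.pretrained_embedding import PretrainedEmbedding

vocab_size, pretrain_dim, embedding_dim = 10, 8, 4

# Dummy pretrained embeddings, stored under the feature name so that
# get_pretrained_embedding() can look them up via hf[feature_name].
with h5py.File("pretrained_emb.h5", "w") as hf:
    hf.create_dataset("item_id", data=np.random.randn(vocab_size, pretrain_dim).astype(np.float32))

feature_spec = {"vocab_size": vocab_size, "freeze_emb": True}  # padding_idx is optional
emb_layer = PretrainedEmbedding("item_id", feature_spec, "pretrained_emb.h5",
                                embedding_dim, pretrain_dim, pretrain_usage="concat")

out = emb_layer(torch.randint(0, vocab_size, (2, 3)))  # a batch of token ids
print(out.shape)  # torch.Size([2, 3, 4]): the (8 + 4)-dim concat is projected back by proj_2

With pretrain_usage="concat", the frozen pretrained table is concatenated with the trainable ID embedding and projected back to embedding_dim by proj_2; with "sum" the two are added (through proj_1 when the dimensions differ), and with "init" the pretrained weights simply initialize the embedding table.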