diff --git a/LICENSE b/LICENSE
index 4720bde..5b0a09e 100644
--- a/LICENSE
+++ b/LICENSE
@@ -218,7 +218,6 @@
 For layers.ScaledDotProductAttention:
 Copyright (C) 2018 pengshuang@Github
 
-
 -----------------------------------------------------------------------
 
 For models.DCNv2.CrossNetMix:
@@ -235,3 +234,19 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
+
+ -----------------------------------------------------------------------
+
+ For model_zoo.GDCN:
+ Copyright (C) 2023 sdilbaz@github
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index dec29a9..6a7fc43 100644
--- a/README.md
+++ b/README.md
@@ -72,23 +72,24 @@ Click-through rate (CTR) prediction is a critical task for many industrial appli
 | 35 | KDD'21 | [AOANet](./model_zoo/AOANet) | [Architecture and Operation Adaptive Network for Online Recommendations](https://dl.acm.org/doi/10.1145/3447548.3467133) :triangular_flag_on_post:**Didi Chuxing** | [:arrow_upper_right:](https://github.com/openbenchmark/BARS/tree/main/ctr_prediction/benchmarks/AOANet) | `torch` |
 | 36 | AAAI'23 | [FinalMLP](./model_zoo/FinalMLP) | [FinalMLP: An Enhanced Two-Stream MLP Model for CTR Prediction](https://arxiv.org/abs/2304.00902) :triangular_flag_on_post:**Huawei** | [:arrow_upper_right:](https://github.com/openbenchmark/BARS/tree/main/ctr_prediction/benchmarks/FinalMLP) | `torch` |
 | 37 | SIGIR'23 | [FINAL](./model_zoo/FINAL) | FINAL: Factorized Interaction Layer for CTR Prediction :triangular_flag_on_post:**Huawei** | [:arrow_upper_right:](https://github.com/openbenchmark/BARS/tree/main/ctr_prediction/benchmarks/FINAL) | `torch` |
+| 38 | CIKM'23 | [GDCN](./model_zoo/GDCN) | [Towards Deeper, Lighter and Interpretable Cross Network for CTR Prediction](https://dl.acm.org/doi/pdf/10.1145/3583780.3615089) :triangular_flag_on_post:**Microsoft** | | `torch` |
 |:open_file_folder: **Behavior Sequence Modeling**|
-| 38 | KDD'18 | [DIN](./model_zoo/DIN) | [Deep Interest Network for Click-Through Rate Prediction](https://www.kdd.org/kdd2018/accepted-papers/view/deep-interest-network-for-click-through-rate-prediction) :triangular_flag_on_post:**Alibaba** | [:arrow_upper_right:](https://github.com/openbenchmark/BARS/tree/main/ctr_prediction/benchmarks/DIN) | `torch` |
-| 39 | AAAI'19 | [DIEN](./model_zoo/DIEN) | [Deep Interest Evolution Network for Click-Through Rate Prediction](https://arxiv.org/abs/1809.03672) :triangular_flag_on_post:**Alibaba** | [:arrow_upper_right:](https://github.com/openbenchmark/BARS/tree/main/ctr_prediction/benchmarks/DIEN) | `torch` |
-| 40 | DLP-KDD'19 | [BST](./model_zoo/BST) | [Behavior Sequence Transformer for E-commerce Recommendation in Alibaba](https://arxiv.org/abs/1905.06874) :triangular_flag_on_post:**Alibaba** | [:arrow_upper_right:](https://github.com/openbenchmark/BARS/tree/main/ctr_prediction/benchmarks/BST) | `torch` |
-| 41 | CIKM'20 | [DMIN](./model_zoo/DMIN) | [Deep Multi-Interest Network for Click-through Rate Prediction](https://dl.acm.org/doi/10.1145/3340531.3412092) :triangular_flag_on_post:**Alibaba** | | `torch` |
-| 42 | AAAI'20 | [DMR](./model_zoo/DMR) | [Deep Match to Rank Model for Personalized Click-Through Rate Prediction](https://ojs.aaai.org/index.php/AAAI/article/view/5346) :triangular_flag_on_post:**Alibaba** | | `torch` |
-| 43 | Arxiv'21 | [ETA](./model_zoo/ETA) | [End-to-End User Behavior Retrieval in Click-Through RatePrediction Model](https://arxiv.org/abs/2108.04468) :triangular_flag_on_post:**Alibaba** | | `torch` |
-| 44 | CIKM'22 | [SDIM](./model_zoo/SDIM) | [Sampling Is All You Need on Modeling Long-Term User Behaviors for CTR Prediction](https://arxiv.org/abs/2205.10249) :triangular_flag_on_post:**Meituan** | | `torch` |
+| 39 | KDD'18 | [DIN](./model_zoo/DIN) | [Deep Interest Network for Click-Through Rate Prediction](https://www.kdd.org/kdd2018/accepted-papers/view/deep-interest-network-for-click-through-rate-prediction) :triangular_flag_on_post:**Alibaba** | [:arrow_upper_right:](https://github.com/openbenchmark/BARS/tree/main/ctr_prediction/benchmarks/DIN) | `torch` |
+| 40 | AAAI'19 | [DIEN](./model_zoo/DIEN) | [Deep Interest Evolution Network for Click-Through Rate Prediction](https://arxiv.org/abs/1809.03672) :triangular_flag_on_post:**Alibaba** | [:arrow_upper_right:](https://github.com/openbenchmark/BARS/tree/main/ctr_prediction/benchmarks/DIEN) | `torch` |
+| 41 | DLP-KDD'19 | [BST](./model_zoo/BST) | [Behavior Sequence Transformer for E-commerce Recommendation in Alibaba](https://arxiv.org/abs/1905.06874) :triangular_flag_on_post:**Alibaba** | [:arrow_upper_right:](https://github.com/openbenchmark/BARS/tree/main/ctr_prediction/benchmarks/BST) | `torch` |
+| 42 | CIKM'20 | [DMIN](./model_zoo/DMIN) | [Deep Multi-Interest Network for Click-through Rate Prediction](https://dl.acm.org/doi/10.1145/3340531.3412092) :triangular_flag_on_post:**Alibaba** | | `torch` |
+| 43 | AAAI'20 | [DMR](./model_zoo/DMR) | [Deep Match to Rank Model for Personalized Click-Through Rate Prediction](https://ojs.aaai.org/index.php/AAAI/article/view/5346) :triangular_flag_on_post:**Alibaba** | | `torch` |
+| 44 | Arxiv'21 | [ETA](./model_zoo/ETA) | [End-to-End User Behavior Retrieval in Click-Through Rate Prediction Model](https://arxiv.org/abs/2108.04468) :triangular_flag_on_post:**Alibaba** | | `torch` |
+| 45 | CIKM'22 | [SDIM](./model_zoo/SDIM) | [Sampling Is All You Need on Modeling Long-Term User Behaviors for CTR Prediction](https://arxiv.org/abs/2205.10249) :triangular_flag_on_post:**Meituan** | | `torch` |
 |:open_file_folder: **Dynamic Weight Network**|
-| 45 | NeurIPS'22 | [APG](./model_zoo/APG) | [APG: Adaptive Parameter Generation Network for Click-Through Rate Prediction](https://arxiv.org/abs/2203.16218) :triangular_flag_on_post:**Alibaba** | | `torch` |
-| 46 | Arxiv'23 | [PPNet](./model_zoo/PEPNet) | [PEPNet: Parameter and Embedding Personalized Network for Infusing with Personalized Prior Information](https://arxiv.org/abs/2302.01115) :triangular_flag_on_post:**KuaiShou** | | `torch` |
+| 46 | NeurIPS'22 | [APG](./model_zoo/APG) | [APG: Adaptive Parameter Generation Network for Click-Through Rate Prediction](https://arxiv.org/abs/2203.16218) :triangular_flag_on_post:**Alibaba** | | `torch` |
+| 47 | Arxiv'23 | [PPNet](./model_zoo/PEPNet) | [PEPNet: Parameter and Embedding Personalized Network for Infusing with Personalized Prior Information](https://arxiv.org/abs/2302.01115) :triangular_flag_on_post:**KuaiShou** | | `torch` |
 |:open_file_folder: **Multi-Task Modeling**|
-| 47 | MachineLearn'97 | [SharedBottom](./model_zoo/multitask/SharedBottom) | [Multitask Learning](https://link.springer.com/article/10.1023/A:1007379606734) | | `torch` |
-| 48 | KDD'18 | [MMoE](./model_zoo/multitask/MMOE) | [Modeling Task Relationships in Multi-task Learning with Multi-Gate Mixture-of-Experts](https://dl.acm.org/doi/pdf/10.1145/3219819.3220007) :triangular_flag_on_post:**Google** | | `torch` |
-| 49 | KDD'18 | [PLE](./model_zoo/multitask/PLE) | [Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations](https://dl.acm.org/doi/10.1145/3383313.3412236) :triangular_flag_on_post:**Tencent** | | `torch` |
+| 48 | MachineLearn'97 | [SharedBottom](./model_zoo/multitask/SharedBottom) | [Multitask Learning](https://link.springer.com/article/10.1023/A:1007379606734) | | `torch` |
+| 49 | KDD'18 | [MMoE](./model_zoo/multitask/MMOE) | [Modeling Task Relationships in Multi-task Learning with Multi-Gate Mixture-of-Experts](https://dl.acm.org/doi/pdf/10.1145/3219819.3220007) :triangular_flag_on_post:**Google** | | `torch` |
+| 50 | RecSys'20 | [PLE](./model_zoo/multitask/PLE) | [Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations](https://dl.acm.org/doi/10.1145/3383313.3412236) :triangular_flag_on_post:**Tencent** | | `torch` |
 |:open_file_folder: **Multi-Domain Modeling**|
-| 50 | Arxiv'23 | PEPNet | [PEPNet: Parameter and Embedding Personalized Network for Infusing with Personalized Prior Information](https://arxiv.org/abs/2302.01115) :triangular_flag_on_post:**KuaiShou** | | `torch` |
+| 51 | Arxiv'23 | PEPNet | [PEPNet: Parameter and Embedding Personalized Network for Infusing with Personalized Prior Information](https://arxiv.org/abs/2302.01115) :triangular_flag_on_post:**KuaiShou** | | `torch` |
+
 :point_right: See [reusable dataset splits for CTR prediction](https://openbenchmark.github.io/BARS/datasets/README.html).
diff --git a/model_zoo/GDCN/README.md b/model_zoo/GDCN/README.md
new file mode 100644
index 0000000..b28adac
--- /dev/null
+++ b/model_zoo/GDCN/README.md
@@ -0,0 +1,93 @@
+## GDCN
+
+| [Overview](#Overview) | [Configuration](#Configuration) | [Implementation](#Implementation) | [Discussion](#Discussion) |
+| :--: | :--: | :--: | :--: |
+
+### Overview
+
+GDCN (Gated Deep Cross Network) is a CTR prediction model that learns explicit, bounded-degree feature interactions and uses an information gate in each cross layer to filter out unimportant cross features. The model is published in the following paper:
+
++ [Towards Deeper, Lighter and Interpretable Cross Network for CTR Prediction](https://dl.acm.org/doi/pdf/10.1145/3583780.3615089), in CIKM 2023.
+
+Two variants are implemented here: GDCNS stacks the gated cross network in front of the DNN, while GDCNP runs the gated cross network and the DNN in parallel and concatenates their outputs.
+
+**Key components:**
+
++ *Gated cross network*: each gated cross layer performs explicit feature crossing with bounded degree and filters the result through an information gate:
+
+  $$x_{l+1} = x_0 \odot (W_l^{c} x_l + b_l) \odot \sigma(W_l^{g} x_l) + x_l$$
+
++ *Dynamic embedding size*: a rule of thumb for choosing the embedding size of each feature field from its vocabulary size, e.g., for use with the field-wise ```embedding_dim``` option in ```feature_specs``` (see the snippet below):
+
+  $$emb\_dim = 6 \times (vocab\_size)^{1/4}$$
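+
+A minimal helper for the rule of thumb above (illustrative only; it is not part of the training code and the function name is hypothetical):
+
+```python
+# Sketch: map a field's vocabulary size to an embedding size via emb_dim = 6 * vocab_size ** 0.25.
+def dynamic_emb_dim(vocab_size: int) -> int:
+    return max(1, round(6 * vocab_size ** 0.25))
+
+dynamic_emb_dim(10000)  # -> 60
+```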
+
+### Configuration
+
+The `model_config.yaml` file contains all the model hyper-parameters as follows.
+
+| Params | Type | Default | Description |
+| ------ | ---- | ------- | ----------- |
+| model | str | "GDCNP" | model name, which should be the same as the model class name (GDCNP or GDCNS) |
+| dataset_id | str | "TBD" | dataset_id to be determined |
+| loss | str | "binary_crossentropy" | loss function |
+| metrics | list | ['logloss', 'AUC'] | a list of metrics for evaluation |
+| task | str | "binary_classification" | supported task types: ```"regression"```, ```"binary_classification"``` |
+| optimizer | str | "adam" | optimizer used for training |
+| learning_rate | float | 1.0e-3 | learning rate |
+| embedding_regularizer | float/str | 0 | regularization weight for embedding matrix: L2 regularization is applied by default. Other optional examples: ```"l2(1.e-3)"```, ```"l1(1.e-3)"```, ```"l1_l2(1.e-3, 1.e-3)"```. |
+| net_regularizer | float/str | 0 | regularization weight for network parameters: L2 regularization is applied by default. Other optional examples: ```"l2(1.e-3)"```, ```"l1(1.e-3)"```, ```"l1_l2(1.e-3, 1.e-3)"```. |
+| batch_size | int | 10000 | batch size, usually a large number for CTR prediction tasks |
+| embedding_dim | int | 32 | embedding dimension of features. Note that a field-wise embedding_dim can be specified in ```feature_specs```. |
+| dnn_hidden_units | list | [1024, 512, 256] | hidden units in the DNN |
+| dnn_activations | str/list | "relu" | activation function in the DNN. Layer-wise activations can be specified as a list, e.g., ["relu", "leakyrelu", "sigmoid"] |
+| num_cross_layers | int | 3 | number of gated cross layers in the gated cross network |
+| net_dropout | float | 0 | dropout rate in the DNN |
+| batch_norm | bool | False | whether to apply batch normalization in the DNN |
+| epochs | int | 100 | the maximum number of training epochs; training may stop early according to the monitor metrics |
+| shuffle | bool | True | whether to shuffle the data samples in each training epoch |
+| seed | int | 20222023 | the random seed used for reproducibility |
+| monitor | str/dict | {'AUC': 1, 'logloss': -1} | the monitor metrics for early stopping. It supports a single metric, e.g., ```"AUC"```, or multiple metrics combined via a dict, e.g., {"AUC": 2, "logloss": -1} means ```2*AUC - logloss```. |
+| monitor_mode | str | 'max' | ```"max"``` means the higher the better, while ```"min"``` means the lower the better |
+| model_root | str | './checkpoints/' | the directory to save model checkpoints and running logs |
+| num_workers | int | 3 | the number of workers for the data loader |
+| verbose | int | 1 | 0 for silent logging, 1 for verbose logging with tqdm |
+| early_stop_patience | int | 2 | training stops when the monitor metric fails to improve for ```early_stop_patience``` (e.g., 2) consecutive evaluation intervals |
+| pickle_feature_encoder | bool | True | whether to pickle the feature encoder during preprocessing. It is used when the input ```data_format="csv"```. |
+| save_best_only | bool | True | whether to save the best model checkpoint only |
+| eval_steps | int/None | None | evaluate the model on the validation data every ```eval_steps``` batches. By default, ```None``` means evaluating at every epoch. |
+| debug_mode | bool | False | used for code testing. When set to ```True```, the ```experiment_id``` is randomly generated to avoid interleaving when running multiple processes for parameter tuning via ```run_param_tuner.py```. |
+| group_id | None (optional) | None | required for metrics like ```gAUC```, ```NDCG``` |
+| use_features | None (optional) | None | used for feature selection, i.e., only an ordered subset of features is used as model input |
+| feature_specs | dict (optional) | None | used for specifying field-wise configurations, such as ```embedding_dim``` and ```feature_encoder``` for a specific field |
+
+
+### Implementation
+
+**Code structure:**
+
+```
+├── config                   # configuration directory
+│   ├── dataset_config.yaml  # dataset config file
+│   └── model_config.yaml    # model config file
+├── src                      # model source directory
+│   └── GDCN.py              # GDCN model implementation
+├── fuxictr_version.py       # fuxictr loading and version check
+├── README.md                # usage instructions
+├── requirements.txt         # dependency list
+└── run_expid.py             # entry script for running an experiment
+```
+
+**Requirements:**
+
+The model is tested with the following dependencies.
+
++ fuxictr==2.0.0
+
++ pytorch==1.11
+
+**Get started:**
+
+Running the model on the tiny data:
+
+```
+python run_expid.py --expid GDCNP_test --gpu 0
+```
diff --git a/model_zoo/GDCN/config/dataset_config.yaml b/model_zoo/GDCN/config/dataset_config.yaml
new file mode 100644
index 0000000..c45414b
--- /dev/null
+++ b/model_zoo/GDCN/config/dataset_config.yaml
@@ -0,0 +1,8 @@
+### Tiny data for tests only
+tiny_h5:
+    data_root: ../../data/
+    data_format: h5
+    train_data: ../../data/tiny_h5/train.h5
+    valid_data: ../../data/tiny_h5/valid.h5
+    test_data: ../../data/tiny_h5/test.h5
+
diff --git a/model_zoo/GDCN/config/model_config.yaml b/model_zoo/GDCN/config/model_config.yaml
new file mode 100644
index 0000000..f2cc6ef
--- /dev/null
+++ b/model_zoo/GDCN/config/model_config.yaml
@@ -0,0 +1,83 @@
+Base:
+    model_root: './checkpoints/'
+    num_workers: 3
+    verbose: 1
+    early_stop_patience: 2
+    pickle_feature_encoder: True
+    save_best_only: True
+    eval_steps: null
+    debug_mode: False
+    group_id: null
+    use_features: null
+    feature_specs: null
+    feature_config: null
+
+GDCNP_test:
+    model: GDCNP
+    dataset_id: tiny_h5
+    loss: 'binary_crossentropy'
+    metrics: ['logloss', 'AUC']
+    task: binary_classification
+    optimizer: adam
+    learning_rate: 1.0e-3
+    embedding_regularizer: 1.e-8
+    net_regularizer: 0
+    batch_size: 128
+    embedding_dim: 4
+    dnn_hidden_units: [64, 32]
+    dnn_activations: relu
+    num_cross_layers: 3
+    net_dropout: 0
+    batch_norm: False
+    epochs: 1
+    shuffle: True
+    seed: 2019
+    monitor: 'AUC'
+    monitor_mode: 'max'
+
+GDCNS_test:
+    model: GDCNS
+    dataset_id: tiny_h5
+    loss: 'binary_crossentropy'
+    metrics: ['logloss', 'AUC']
+    task: binary_classification
+    optimizer: adam
+    learning_rate: 1.0e-3
+    embedding_regularizer: 1.e-8
+    net_regularizer: 0
+    batch_size: 128
+    embedding_dim: 4
+    dnn_hidden_units: [64, 32]
+    dnn_activations: relu
+    num_cross_layers: 3
+    net_dropout: 0
+    batch_norm: False
+    epochs: 1
+    shuffle: True
+    seed: 2019
+    monitor: 'AUC'
+    monitor_mode: 'max'
+
+GDCN_default: # This is a config template
+    model: GDCNP
+    dataset_id: TBD
+    loss: 'binary_crossentropy'
+    metrics: ['logloss', 'AUC']
+    task: binary_classification
+    optimizer: adam
+    learning_rate: 1.0e-3
+    embedding_regularizer: 0
+    net_regularizer: 0
+    batch_size: 10000
+    embedding_dim: 32
+    dnn_hidden_units: [1024, 512, 256]
+    dnn_activations: relu
+    num_cross_layers: 3
+    net_dropout: 0
+    batch_norm: False
+    epochs: 100
+    shuffle: True
+    seed: 20222023
+    monitor: {'AUC': 1, 'logloss': -1}
+    monitor_mode: 'max'
+
diff --git a/model_zoo/GDCN/fuxictr_version.py b/model_zoo/GDCN/fuxictr_version.py
new file mode 100644
index 0000000..666e5f0
--- /dev/null
+++ b/model_zoo/GDCN/fuxictr_version.py
@@ -0,0 +1,3 @@
+# pip install -U fuxictr
+import fuxictr
+assert fuxictr.__version__ >= "2.0.0"
diff --git a/model_zoo/GDCN/run_expid.py b/model_zoo/GDCN/run_expid.py
new file mode 100644
index 0000000..2d44f3b
--- /dev/null
+++ b/model_zoo/GDCN/run_expid.py
@@ -0,0 +1,88 @@
+# =========================================================================
+# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+
+import os
+os.chdir(os.path.dirname(os.path.realpath(__file__)))
+import sys
+import logging
+import fuxictr_version
+from fuxictr import datasets
+from datetime import datetime
+from fuxictr.utils import load_config, set_logger, print_to_json, print_to_list
+from fuxictr.features import FeatureMap
+from fuxictr.pytorch.torch_utils import seed_everything
+from fuxictr.pytorch.dataloaders import H5DataLoader
+from fuxictr.preprocess import FeatureProcessor, build_dataset
+import src as model_zoo
+import gc
+import argparse
+from pathlib import Path
+
+
+if __name__ == '__main__':
+    ''' Usage: python run_expid.py --config {config_dir} --expid {experiment_id} --gpu {gpu_device_id}
+    '''
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--config', type=str, default='./config/', help='The config directory.')
+    parser.add_argument('--expid', type=str, default='GDCNP_test', help='The experiment id to run.')
+    parser.add_argument('--gpu', type=int, default=-1, help='The gpu index, -1 for cpu')
+    args = vars(parser.parse_args())
+
+    experiment_id = args['expid']
+    params = load_config(args['config'], experiment_id)
+    params['gpu'] = args['gpu']
+    set_logger(params)
+    logging.info("Params: " + print_to_json(params))
+    seed_everything(seed=params['seed'])
+
+    data_dir = os.path.join(params['data_root'], params['dataset_id'])
+    feature_map_json = os.path.join(data_dir, "feature_map.json")
+    if params["data_format"] == "csv":
+        # Build feature_map and transform h5 data
+        feature_encoder = FeatureProcessor(**params)
+        params["train_data"], params["valid_data"], params["test_data"] = \
+            build_dataset(feature_encoder, **params)
+    feature_map = FeatureMap(params['dataset_id'], data_dir)
+    feature_map.load(feature_map_json, params)
+    logging.info("Feature specs: " + print_to_json(feature_map.features))
+
+    model_class = getattr(model_zoo, params['model'])
+    model = model_class(feature_map, **params)
+    model.count_parameters() # print number of parameters used in model
+
+    train_gen, valid_gen = H5DataLoader(feature_map, stage='train', **params).make_iterator()
+    model.fit(train_gen, validation_data=valid_gen, **params)
+
+    logging.info('****** Validation evaluation ******')
+    valid_result = model.evaluate(valid_gen)
+    del train_gen, valid_gen
+    gc.collect()
+
+    logging.info('******** Test evaluation ********')
+    test_gen = H5DataLoader(feature_map, stage='test', **params).make_iterator()
+    test_result = {}
+    if test_gen:
+        test_result = model.evaluate(test_gen)
+
+    result_filename = Path(args['config']).name.replace(".yaml", "") + '.csv'
+    with open(result_filename, 'a+') as fw:
+        fw.write(' {},[command] python {},[exp_id] {},[dataset_id] {},[train] {},[val] {},[test] {}\n' \
+            .format(datetime.now().strftime('%Y%m%d-%H%M%S'),
+                    ' '.join(sys.argv), experiment_id, params['dataset_id'],
+                    "N.A.", print_to_list(valid_result), print_to_list(test_result)))
+
diff --git a/model_zoo/GDCN/src/GDCN.py b/model_zoo/GDCN/src/GDCN.py
new file mode 100644
index 0000000..d953c9f
--- /dev/null
+++ b/model_zoo/GDCN/src/GDCN.py
@@ -0,0 +1,145 @@
+# =========================================================================
+# Copyright (C) 2023 sdilbaz@github
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+import torch
+from torch import nn
+from fuxictr.pytorch.models import BaseModel
+from fuxictr.pytorch.layers import FeatureEmbedding, MLP_Block
+
+
+class GDCNP(BaseModel):
+    def __init__(self,
+                 feature_map,
+                 model_id="GDCNP",
+                 gpu=-1,
+                 learning_rate=1e-3,
+                 embedding_dim=10,
+                 dnn_hidden_units=[],
+                 dnn_activations="ReLU",
+                 num_cross_layers=3,
+                 net_dropout=0,
+                 batch_norm=False,
+                 embedding_regularizer=None,
+                 net_regularizer=None,
+                 **kwargs):
+        super(GDCNP, self).__init__(feature_map,
+                                    model_id=model_id,
+                                    gpu=gpu,
+                                    embedding_regularizer=embedding_regularizer,
+                                    net_regularizer=net_regularizer,
+                                    **kwargs)
+        self.embedding_layer = FeatureEmbedding(feature_map, embedding_dim)
+        input_dim = feature_map.sum_emb_out_dim()
+        self.dnn = MLP_Block(input_dim=input_dim,
+                             output_dim=None, # return the last hidden layer (no output head)
+                             hidden_units=dnn_hidden_units,
+                             hidden_activations=dnn_activations,
+                             output_activation=None,
+                             dropout_rates=net_dropout,
+                             batch_norm=batch_norm) \
+                   if dnn_hidden_units else None # in case of only crossing net used
+        self.cross_net = GateCrossLayer(input_dim, num_cross_layers)
+        self.fc = torch.nn.Linear(dnn_hidden_units[-1] + input_dim, 1)
+
+        self.compile(kwargs["optimizer"], kwargs["loss"], learning_rate)
+        self.reset_parameters()
+        self.model_to_device()
+
+    def forward(self, inputs):
+        X = self.get_inputs(inputs)
+        feature_emb = self.embedding_layer(X, flatten_emb=True)
+        cross_cn = self.cross_net(feature_emb)
+        cross_mlp = self.dnn(feature_emb)
+        y_pred = self.fc(torch.cat([cross_cn, cross_mlp], dim=1))
+        y_pred = self.output_activation(y_pred)
+        return_dict = {"y_pred": y_pred}
+        return return_dict
+
+
+class GDCNS(BaseModel):
+    def __init__(self,
+                 feature_map,
+                 model_id="GDCNS",
+                 gpu=-1,
+                 learning_rate=1e-3,
+                 embedding_dim=10,
+                 dnn_hidden_units=[],
+                 dnn_activations="ReLU",
+                 num_cross_layers=3,
+                 net_dropout=0,
+                 batch_norm=False,
+                 embedding_regularizer=None,
+                 net_regularizer=None,
+                 **kwargs):
+        super(GDCNS, self).__init__(feature_map,
+                                    model_id=model_id,
+                                    gpu=gpu,
+                                    embedding_regularizer=embedding_regularizer,
+                                    net_regularizer=net_regularizer,
+                                    **kwargs)
+        self.embedding_layer = FeatureEmbedding(feature_map, embedding_dim)
+        input_dim = feature_map.sum_emb_out_dim()
+        self.dnn = MLP_Block(input_dim=input_dim,
+                             output_dim=1, # 1-dim logit output
+                             hidden_units=dnn_hidden_units,
+                             hidden_activations=dnn_activations,
+                             output_activation=None,
+                             dropout_rates=net_dropout,
+                             batch_norm=batch_norm) \
+                   if dnn_hidden_units else None # in case of only crossing net used
+        self.cross_net = GateCrossLayer(input_dim, num_cross_layers)
+
+        self.compile(kwargs["optimizer"], kwargs["loss"], learning_rate)
+        self.reset_parameters()
+        self.model_to_device()
+
+    def forward(self, inputs):
+        X = self.get_inputs(inputs)
+        feature_emb = self.embedding_layer(X, flatten_emb=True)
+        cross_cn = self.cross_net(feature_emb)
+        y_pred = self.dnn(cross_cn)
+        y_pred = self.output_activation(y_pred)
+        return_dict = {"y_pred": y_pred}
+        return return_dict
+
+
+class GateCrossLayer(nn.Module):
+    # The core structure: gated cross layer.
+    def __init__(self, input_dim, cn_layers=3):
+        super().__init__()
+
+        self.cn_layers = cn_layers
+
+        self.w = nn.ModuleList([
+            nn.Linear(input_dim, input_dim, bias=False) for _ in range(cn_layers)
+        ])
+        self.wg = nn.ModuleList([
+            nn.Linear(input_dim, input_dim, bias=False) for _ in range(cn_layers)
+        ])
+
+        self.b = nn.ParameterList([nn.Parameter(
+            torch.zeros((input_dim,))) for _ in range(cn_layers)])
+
+        for i in range(cn_layers):
+            nn.init.uniform_(self.b[i].data)
+
+        self.activation = nn.Sigmoid()
+
+    def forward(self, x):
+        x0 = x
+        for i in range(self.cn_layers):
+            xw = self.w[i](x)                    # feature crossing
+            xg = self.activation(self.wg[i](x))  # information gate
+            x = x0 * (xw + self.b[i]) * xg + x
+        return x
diff --git a/model_zoo/GDCN/src/__init__.py b/model_zoo/GDCN/src/__init__.py
new file mode 100644
index 0000000..37caeaa
--- /dev/null
+++ b/model_zoo/GDCN/src/__init__.py
@@ -0,0 +1,4 @@
+from .GDCN import GDCNP, GDCNS
+
+
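As a quick sanity check of the gated cross layer introduced above, the following sketch (assumptions: `fuxictr` and `torch` are installed and the snippet is run from `model_zoo/GDCN/` so that `src.GDCN` is importable) verifies that each layer preserves the flattened embedding dimension:

```python
import torch
from src.GDCN import GateCrossLayer  # path assumption: run from model_zoo/GDCN/

# Each gated cross layer maps (batch_size, input_dim) -> (batch_size, input_dim),
# so the cross output can be fed into the DNN (GDCNS) or concatenated with the
# DNN output (GDCNP).
x = torch.randn(8, 64)
cross_net = GateCrossLayer(input_dim=64, cn_layers=3)
assert cross_net(x).shape == x.shape
```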