Skip to content

Commit

Permalink
- Fix TransAct bug (#132)
Browse files Browse the repository at this point in the history
- Set polars <= 1.0.0
- Remove group_id from feature_map.json
  • Loading branch information
xpai committed Dec 24, 2024
1 parent 9d89a67 commit 562061c
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 16 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ Click-through rate (CTR) prediction is a critical task for various industrial ap
| 44 | DLP-KDD'19 | [BST](./model_zoo/BST) | [Behavior Sequence Transformer for E-commerce Recommendation in Alibaba](https://arxiv.org/abs/1905.06874) :triangular_flag_on_post:**Alibaba** | [:arrow_upper_right:](https://github.com/reczoo/BARS/tree/main/ranking/ctr/BST) | `torch` |
| 45 | CIKM'20 | [DMIN](./model_zoo/DMIN) | [Deep Multi-Interest Network for Click-through Rate Prediction](https://dl.acm.org/doi/10.1145/3340531.3412092) :triangular_flag_on_post:**Alibaba** | [:arrow_upper_right:](https://github.com/reczoo/BARS/tree/main/ranking/ctr/DMIN) | `torch` |
| 46 | AAAI'20 | [DMR](./model_zoo/DMR) | [Deep Match to Rank Model for Personalized Click-Through Rate Prediction](https://ojs.aaai.org/index.php/AAAI/article/view/5346) :triangular_flag_on_post:**Alibaba** | [:arrow_upper_right:](https://github.com/reczoo/BARS/tree/main/ranking/ctr/DMR) | `torch` |
| 47 | DLP-KDD'22 | [ETA](./model_zoo/ETA) | [Efficient Long Sequential User Data Modeling for Click-Through Rate Prediction](https://arxiv.org/abs/2209.12212) :triangular_flag_on_post:**Alibaba** | | `torch` |
| 48 | CIKM'22 | [SDIM](./model_zoo/SDIM) | [Sampling Is All You Need on Modeling Long-Term User Behaviors for CTR Prediction](https://arxiv.org/abs/2205.10249) :triangular_flag_on_post:**Meituan** | | `torch` |
| 47 | DLP-KDD'22 | [ETA](./model_zoo/LongCTR/ETA) | [Efficient Long Sequential User Data Modeling for Click-Through Rate Prediction](https://arxiv.org/abs/2209.12212) :triangular_flag_on_post:**Alibaba** | | `torch` |
| 48 | CIKM'22 | [SDIM](./model_zoo/LongCTR/SDIM) | [Sampling Is All You Need on Modeling Long-Term User Behaviors for CTR Prediction](https://arxiv.org/abs/2205.10249) :triangular_flag_on_post:**Meituan** | | `torch` |
| 49 | KDD'23 | [TransAct](./model_zoo/TransAct) | [TransAct: Transformer-based Realtime User Action Model for Recommendation at Pinterest](https://arxiv.org/abs/2306.00248) :triangular_flag_on_post:**Pinterest** | [:arrow_upper_right:](https://github.com/reczoo/BARS/tree/main/ranking/ctr/TransAct) | `torch` |
|<tr><th colspan=6 align="center">:open_file_folder: **Dynamic Weight Network**</th></tr>|
| 50 | NeurIPS'22 | [APG](./model_zoo/APG) | [APG: Adaptive Parameter Generation Network for Click-Through Rate Prediction](https://arxiv.org/abs/2203.16218) :triangular_flag_on_post:**Alibaba** | [:arrow_upper_right:](https://github.com/reczoo/BARS/tree/main/ranking/ctr/APG) | `torch` |
Expand Down
4 changes: 1 addition & 3 deletions fuxictr/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def load(self, json_file, params):
self.labels = feature_map.get("labels", [])
self.total_features = feature_map.get("total_features", 0)
self.input_length = feature_map.get("input_length", 0)
self.group_id = feature_map.get("group_id", None)
self.group_id = params.get("group_id", None)
self.default_emb_dim = params.get("embedding_dim", None)
self.features = OrderedDict((k, v) for x in feature_map["features"] for k, v in x.items())
self.num_fields = self.get_num_fields()
Expand Down Expand Up @@ -74,8 +74,6 @@ def save(self, json_file):
feature_map["total_features"] = self.total_features
feature_map["input_length"] = self.input_length
feature_map["labels"] = self.labels
if self.group_id is not None:
feature_map["group_id"] = self.group_id
feature_map["features"] = [{k: v} for k, v in self.features.items()]
with open(json_file, "w") as fd:
json.dump(feature_map, fd, indent=4)
Expand Down
8 changes: 4 additions & 4 deletions model_zoo/LongCTR/ETA/config/dataset_config.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
### Tiny data for tests only
tiny_seq:
data_root: ../../data/
data_root: ../../../data/
data_format: npz
train_data: ../../data/tiny_seq/train.npz
valid_data: ../../data/tiny_seq/valid.npz
test_data: ../../data/tiny_seq/test.npz
train_data: ../../../data/tiny_seq/train.npz
valid_data: ../../../data/tiny_seq/valid.npz
test_data: ../../../data/tiny_seq/test.npz

8 changes: 4 additions & 4 deletions model_zoo/LongCTR/SDIM/config/dataset_config.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
### Tiny data for tests only
tiny_seq:
data_root: ../../data/
data_root: ../../../data/
data_format: npz
train_data: ../../data/tiny_seq/train.npz
valid_data: ../../data/tiny_seq/valid.npz
test_data: ../../data/tiny_seq/test.npz
train_data: ../../../data/tiny_seq/train.npz
valid_data: ../../../data/tiny_seq/valid.npz
test_data: ../../../data/tiny_seq/test.npz

4 changes: 2 additions & 2 deletions model_zoo/TransAct/src/TransAct.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ def __init__(self,
else embedding_dim
)
transformer_in_dim = seq_emb_dim + target_emb_dim
seq_out_dim += (first_k_cols + int(concat_max_pool)) * transformer_in_dim
self.transformer_encoders.append(
TransActTransformer(transformer_in_dim,
dim_feedforward=dim_feedforward,
Expand All @@ -134,7 +133,8 @@ def __init__(self,
first_k_cols=first_k_cols,
concat_max_pool=concat_max_pool)
)
dcn_in_dim = feature_map.sum_emb_out_dim() + seq_out_dim - seq_emb_dim
seq_out_dim += (first_k_cols + int(concat_max_pool)) * transformer_in_dim - seq_emb_dim
dcn_in_dim = feature_map.sum_emb_out_dim() + seq_out_dim
self.crossnet = CrossNetV2(dcn_in_dim, dcn_cross_layers)
self.parallel_dnn = MLP_Block(input_dim=dcn_in_dim,
output_dim=None, # output hidden layer
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ numpy
h5py
tqdm
pyarrow
polars
polars<=1.0.0

0 comments on commit 562061c

Please sign in to comment.