Skip to content

Commit

Permalink
fix some bugs
Browse files Browse the repository at this point in the history
- Fix model save error when using acc metric model save issue
- Fix model error when all feature columns are dense
  • Loading branch information
shenweichen authored Jun 19, 2022
2 parents b4d8181 + 300f115 commit 2cd84f3
Show file tree
Hide file tree
Showing 14 changed files with 114 additions and 96 deletions.
22 changes: 14 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,25 @@ jobs:
timeout-minutes: 120
strategy:
matrix:
python-version: [3.6,3.7]
torch-version: [1.1.0,1.2.0,1.3.0,1.4.0,1.5.0,1.6.0,1.7.0,1.8.1]
python-version: [3.6,3.7,3.8]
torch-version: [1.1.0,1.2.0,1.3.0,1.4.0,1.5.0,1.6.0,1.7.1,1.8.1,1.9.0,1.10.2,1.11.0]

# exclude:
# - python-version: 3.5
# tf-version: 1.1.0
exclude:
- python-version: 3.6
torch-version: 1.11.0
- python-version: 3.8
torch-version: 1.1.0
- python-version: 3.8
torch-version: 1.2.0
- python-version: 3.8
torch-version: 1.3.0

steps:

- uses: actions/checkout@v1
- uses: actions/checkout@v3

- name: Setup python environment
uses: actions/setup-python@v1
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

Expand All @@ -47,7 +53,7 @@ jobs:
pip install -q sklearn
pytest --cov=deepctr_torch --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1.0.2
uses: codecov/codecov-action@v3.1.0
with:
token: ${{secrets.CODECOV_TOKEN}}
file: ./coverage.xml
Expand Down
58 changes: 21 additions & 37 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,34 +47,19 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St

## DisscussionGroup & Related Projects

<html>
<table style="margin-left: 20px; margin-right: auto;">
<tr>
<td>
公众号:<b>浅梦学习笔记</b><br><br>
<a href="https://github.com/shenweichen/deepctr-torch">
<img align="center" src="./docs/pics/code.png" />
</a>
</td>
<td>
微信:<b>deepctrbot</b><br><br>
<a href="https://github.com/shenweichen/deepctr-torch">
<img align="center" src="./docs/pics/deepctrbot.png" />
</a>
</td>
<td>
<ul>
<li><a href="https://github.com/shenweichen/AlgoNotes">AlgoNotes</a></li>
<li><a href="https://github.com/shenweichen/DeepCTR">DeepCTR</a></li>
<li><a href="https://github.com/shenweichen/DeepMatch">DeepMatch</a></li>
<li><a href="https://github.com/shenweichen/GraphEmbedding">GraphEmbedding</a></li>
</ul>
</td>
</tr>
</table>
</html>
- [Github Discussions](https://github.com/shenweichen/DeepCTR/discussions)
- Wechat Discussions

|公众号:浅梦学习笔记|微信:deepctrbot|学习小组 [加入](https://t.zsxq.com/026UJEuzv) [主题集合](https://mp.weixin.qq.com/mp/appmsgalbum?__biz=MjM5MzY4NzE3MA==&action=getalbum&album_id=1361647041096843265&scene=126#wechat_redirect)|
|:--:|:--:|:--:|
| [![公众号](./docs/pics/code.png)](https://github.com/shenweichen/AlgoNotes)| [![微信](./docs/pics/deepctrbot.png)](https://github.com/shenweichen/AlgoNotes)|[![学习小组](./docs/pics/planet_github.png)](https://t.zsxq.com/026UJEuzv)|

- Related Projects

- [AlgoNotes](https://github.com/shenweichen/AlgoNotes)
- [DeepCTR](https://github.com/shenweichen/DeepCTR)
- [DeepMatch](https://github.com/shenweichen/DeepMatch)
- [GraphEmbedding](https://github.com/shenweichen/GraphEmbedding)

## Main Contributors([welcome to join us!](./CONTRIBUTING.md))

Expand All @@ -84,59 +69,58 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St
<td>
​ <a href="https://github.com/shenweichen"><img width="70" height="70" src="https://github.com/shenweichen.png?s=40" alt="pic"></a><br>
​ <a href="https://github.com/shenweichen">Shen Weichen</a> ​
<p>Core Dev<br> Zhejiang Unversity <br> <br> </p>​
<p> Alibaba Group </p>​
</td>
<td>
​ <a href="https://github.com/zanshuxun"><img width="70" height="70" src="https://github.com/zanshuxun.png?s=40" alt="pic"></a><br>
​ <a href="https://github.com/zanshuxun">Zan Shuxun</a>
<p>Core Dev<br> Beijing University <br> of Posts and <br> Telecommunications</p>​
<p> Alibaba Group </p>​
</td>
<td>
<a href="https://github.com/weberrr"><img width="70" height="70" src="https://github.com/weberrr.png?s=40" alt="pic"></a><br>
<a href="https://github.com/weberrr">Wang Ze</a> ​
<p>Core Dev<br> Beihang University <br> <br> </p>​
<p> Meituan </p>​
</td>
<td>
​ <a href="https://github.com/wutongzhang"><img width="70" height="70" src="https://github.com/wutongzhang.png?s=40" alt="pic"></a><br>
<a href="https://github.com/wutongzhang">Zhang Wutong</a>
<p>Core Dev<br> Beijing University <br> of Posts and <br> Telecommunications</p>​
<p> Tencent </p>​
</td>
<td>
​ <a href="https://github.com/ZhangYuef"><img width="70" height="70" src="https://github.com/ZhangYuef.png?s=40" alt="pic"></a><br>
​ <a href="https://github.com/ZhangYuef">Zhang Yuefeng</a>
<p>Core Dev<br>
Peking University <br> <br> </p>​
<p> Peking University </p>​
</td>
</tr>
<tr align="center">
<td>
​ <a href="https://github.com/JyiHUO"><img width="70" height="70" src="https://github.com/JyiHUO.png?s=40" alt="pic"></a><br>
​ <a href="https://github.com/JyiHUO">Huo Junyi</a>
<p>Core Dev<br>
<p>
University of Southampton <br> <br> </p>​
</td>
<td>
​ <a href="https://github.com/Zengai"><img width="70" height="70" src="https://github.com/Zengai.png?s=40" alt="pic"></a><br>
​ <a href="https://github.com/Zengai">Zeng Kai</a> ​
<p>Dev<br>
<p>
SenseTime <br> <br> </p>​
</td>
<td>
​ <a href="https://github.com/chenkkkk"><img width="70" height="70" src="https://github.com/chenkkkk.png?s=40" alt="pic"></a><br>
​ <a href="https://github.com/chenkkkk">Chen K</a> ​
<p>Dev<br>
<p>
NetEase <br> <br> </p>​
</td>
<td>
​ <a href="https://github.com/WeiyuCheng"><img width="70" height="70" src="https://github.com/WeiyuCheng.png?s=40" alt="pic"></a><br>
​ <a href="https://github.com/WeiyuCheng">Cheng Weiyu</a> ​
<p>Dev<br>
<p>
Shanghai Jiao Tong University</p>​
</td>
<td>
​ <a href="https://github.com/tangaqi"><img width="70" height="70" src="https://github.com/tangaqi.png?s=40" alt="pic"></a><br>
​ <a href="https://github.com/tangaqi">Tang</a>
<p>Test<br>
<p>
Tongji University <br> <br> </p>​
</td>
</tr>
Expand Down
2 changes: 1 addition & 1 deletion deepctr_torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
from . import models
from .utils import check_version

__version__ = '0.2.7'
__version__ = '0.2.8'
check_version(__version__)
10 changes: 7 additions & 3 deletions deepctr_torch/models/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Author:
Weichen Shen,weichenswc@163.com
zanshuxun, zanshuxun@aliyun.com
"""
from __future__ import print_function
Expand Down Expand Up @@ -75,7 +76,7 @@ def forward(self, X, sparse_feat_refine_weight=None):

sparse_embedding_list += varlen_embedding_list

linear_logit = torch.zeros([X.shape[0], 1]).to(sparse_embedding_list[0].device)
linear_logit = torch.zeros([X.shape[0], 1]).to(self.device)
if len(sparse_embedding_list) > 0:
sparse_embedding_cat = torch.cat(sparse_embedding_list, dim=-1)
if sparse_feat_refine_weight is not None:
Expand Down Expand Up @@ -476,6 +477,10 @@ def _log_loss(self, y_true, y_pred, eps=1e-7, normalize=True, sample_weight=None
sample_weight,
labels)

@staticmethod
def _accuracy_score(y_true, y_pred):
return accuracy_score(y_true, np.where(y_pred > 0.5, 1, 0))

def _get_metrics(self, metrics, set_eps=False):
metrics_ = {}
if metrics:
Expand All @@ -490,8 +495,7 @@ def _get_metrics(self, metrics, set_eps=False):
if metric == "mse":
metrics_[metric] = mean_squared_error
if metric == "accuracy" or metric == "acc":
metrics_[metric] = lambda y_true, y_pred: accuracy_score(
y_true, np.where(y_pred > 0.5, 1, 0))
metrics_[metric] = self._accuracy_score
self.metrics_names.append(metric)
return metrics_

Expand Down
Binary file added docs/pics/code2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/pics/planet_github.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion docs/requirements.readthedocs.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
Cython>=0.28.5
tensorflow==1.15.4
tensorflow==2.7.2
scikit-learn==1.0
4 changes: 2 additions & 2 deletions docs/source/FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ To save/load weights:

```python
import torch
model = DeepFM()
model = DeepFM(...)
torch.save(model.state_dict(), 'DeepFM_weights.h5')
model.load_state_dict(torch.load('DeepFM_weights.h5'))
```
Expand All @@ -15,7 +15,7 @@ To save/load models:

```python
import torch
model = DeepFM()
model = DeepFM(...)
torch.save(model, 'DeepFM.h5')
model = torch.load('DeepFM.h5')
```
Expand Down
3 changes: 2 additions & 1 deletion docs/source/History.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# History
- 06/14/2021 : [v0.2.7](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.6) released.Add [AFN](./Features.html#afn-adaptive-factorization-network-learning-adaptive-order-feature-interactions) and fix some bugs.
- 06/19/2022 : [v0.2.8](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.8) released.Fix some bugs.
- 06/14/2021 : [v0.2.7](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.7) released.Add [AFN](./Features.html#afn-adaptive-factorization-network-learning-adaptive-order-feature-interactions) and fix some bugs.
- 04/04/2021 : [v0.2.6](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.6) released.Add [IFM](./Features.html#ifm-input-aware-factorization-machine) and [DIFM](./Features.html#difm-dual-input-aware-factorization-machine);Support multi-gpus running([example](./FAQ.html#how-to-run-the-demo-with-multiple-gpus)).
- 02/12/2021 : [v0.2.5](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.5) released.Fix bug in DCN-M.
- 12/05/2020 : [v0.2.4](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.4) released.Imporve compatibility & fix issues.Add History callback.([example](https://deepctr-torch.readthedocs.io/en/latest/FAQ.html#set-learning-rate-and-use-earlystopping)).
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
release = '0.2.7'
release = '0.2.8'


# -- General configuration ---------------------------------------------------
Expand Down
9 changes: 6 additions & 3 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,21 @@ You can read the latest code at https://github.com/shenweichen/DeepCTR-Torch and

News
-----
06/19/2022 : Fix some bugs. `Changelog <https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.8>`_

06/14/2021 : Add `AFN <./Features.html#afn-adaptive-factorization-network-learning-adaptive-order-feature-interactions>`_ and fix some bugs. `Changelog <https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.7>`_

04/04/2021 : Add `IFM <./Features.html#ifm-input-aware-factorization-machine>`_ and `DIFM <./Features.html#difm-dual-input-aware-factorization-machine>`_ . Support multi-gpus running(`example <./FAQ.html#how-to-run-the-demo-with-multiple-gpus>`_). `Changelog <https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.6>`_

02/12/2021 : Fix bug in DCN-M. `Changelog <https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.4>`_

DisscussionGroup
-----------------------

公众号:**浅梦学习笔记** wechat ID: **deepctrbot**
公众号:**浅梦学习笔记** wechat ID: **deepctrbot**

`Discussions <https://github.com/shenweichen/DeepCTR/discussions>`_ `学习小组主题集合 <https://mp.weixin.qq.com/mp/appmsgalbum?__biz=MjM5MzY4NzE3MA==&action=getalbum&album_id=1361647041096843265&scene=126#wechat_redirect>`_

.. image:: ../pics/code.png
.. image:: ../pics/code2.jpg

.. toctree::
:maxdepth: 2
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
long_description = fh.read()

REQUIRED_PACKAGES = [
'torch>=1.1.0', 'tqdm', 'sklearn', 'tensorflow'
'torch>=1.1.0', 'tqdm', 'scikit-learn', 'tensorflow'
]

setuptools.setup(
name="deepctr-torch",
version="0.2.7",
version="0.2.8",
author="Weichen Shen",
author_email="weichenswc@163.com",
description="Easy-to-use,Modular and Extendible package of deep learning based CTR(Click Through Rate) prediction models with PyTorch",
Expand Down Expand Up @@ -37,6 +37,7 @@
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development',
Expand Down
24 changes: 18 additions & 6 deletions tests/models/DeepFM_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,33 @@


@pytest.mark.parametrize(
'use_fm,hidden_size,sparse_feature_num',
[(True, (32,), 3),
(False, (32,), 3),
(False, (32,), 2), (False, (32,), 1), (True, (), 1), (False, (), 2)
'use_fm,hidden_size,sparse_feature_num,dense_feature_num',
[(True, (32,), 3, 3),
(False, (32,), 3, 3),
(False, (32,), 2, 2),
(False, (32,), 1, 1),
(True, (), 1, 1),
(False, (), 2, 2),
(True, (32,), 0, 3),
(True, (32,), 3, 0),
(False, (32,), 0, 3),
(False, (32,), 3, 0),
]
)
def test_DeepFM(use_fm, hidden_size, sparse_feature_num):
def test_DeepFM(use_fm, hidden_size, sparse_feature_num, dense_feature_num):
model_name = "DeepFM"
sample_size = SAMPLE_SIZE
x, y, feature_columns = get_test_data(
sample_size, sparse_feature_num=sparse_feature_num, dense_feature_num=sparse_feature_num)
sample_size, sparse_feature_num=sparse_feature_num, dense_feature_num=dense_feature_num)

model = DeepFM(feature_columns, feature_columns, use_fm=use_fm,
dnn_hidden_units=hidden_size, dnn_dropout=0.5, device=get_device())
check_model(model, model_name, x, y)

# no linear part
model = DeepFM([], feature_columns, use_fm=use_fm,
dnn_hidden_units=hidden_size, dnn_dropout=0.5, device=get_device())
check_model(model, model_name + '_no_linear', x, y)

if __name__ == "__main__":
pass
Loading

0 comments on commit 2cd84f3

Please sign in to comment.