-
Notifications
You must be signed in to change notification settings - Fork 344
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add CMBF model #217
Add CMBF model #217
Conversation
easy_rec/python/model/rank_model.py
Outdated
@@ -300,7 +322,7 @@ def _build_distribute_metric_impl(self, | |||
else: | |||
raise ValueError('Wrong class number') | |||
elif metric.WhichOneof('metric') == 'gauc': | |||
assert loss_type == LossType.CLASSIFICATION | |||
assert loss_type in [LossType.CLASSIFICATION, LossType.F1_REWEIGHTED_LOSS, LossType.PAIR_WISE_LOSS] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个有没有更好的方式来实现,这样每次新加一个 loss 函数,都需要在每个 metric 下面都加一下,感觉有点繁琐
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯,已经修改了
easy_rec/python/model/cmbf.py
Outdated
attention_probs_dropout_prob=self._model_config.attention_probs_dropout_prob, | ||
name='text_self_attention' | ||
) # shape: [batch_size, txt_seq_length, hidden_size] | ||
print('txt_attention_fea:', txt_attention_fea.shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
print 可以改成 logging 好像更合适一点
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
print打印在stdout,比较清晰地查看;如果用logging在打印在stderr,stderr里面有很多系统输出,日志量非常大,不容易发现需要被注意的关键信息。
@@ -27,7 +27,7 @@ def __init__(self, | |||
labels, is_training) | |||
self._loss_type = self._model_config.loss_type | |||
self._num_class = self._model_config.num_class | |||
|
|||
self._losses = self._model_config.losses |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么不能用 self._losses 而要额外定义一个 self._losses
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
losses是为了支持同时使用多个损失函数,而且可以为每个损失函数配置不同的相对权重。
@@ -71,6 +71,8 @@ message FeatureConfig { | |||
|
|||
// delimeter to separate sequence multi-values | |||
optional string seq_multi_sep = 101; | |||
// truncate sequence data to max_seq_len | |||
optional uint32 max_seq_len = 102; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
序列特征不需要这一个值,可以自动算出来的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个值是用来截断序列的,如果序列长度很大,则后续模型的计算量会很大。
CI PY3 Test Failed |
CI Test Failed |
CI PY3 Test Failed |
CI Test Failed |
CI PY3 Test Failed |
CI Test Failed |
CI Test Passed |
CI Test Passed |
CI PY3 Test Passed |
CI PY3 Test Failed |
CI Test Passed |
CI PY3 Test Passed |
CI PY3 Test Passed |
CI Test Passed |
# | ||
# @unittest.skipIf(gl is None, 'graphlearn is not installed') | ||
# def test_dssm_neg_sampler_v2(self): | ||
# self._success = test_utils.test_single_train_eval( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these unit tests should not be removed.
No description provided.