Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bug fix]: fix bug of multiple losses of rank model #335

Merged
merged 7 commits into from
Feb 3, 2023
46 changes: 33 additions & 13 deletions easy_rec/python/model/multi_task_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,23 +51,35 @@ def _init_towers(self, task_tower_configs):
def _add_to_prediction_dict(self, output):
for task_tower_cfg in self._task_towers:
tower_name = task_tower_cfg.tower_name
self._prediction_dict.update(
self._output_to_prediction_impl(
output[tower_name],
loss_type=task_tower_cfg.loss_type,
num_class=task_tower_cfg.num_class,
suffix='_%s' % tower_name))
if len(task_tower_cfg.losses) == 0:
self._prediction_dict.update(
self._output_to_prediction_impl(
output[tower_name],
loss_type=task_tower_cfg.loss_type,
num_class=task_tower_cfg.num_class,
suffix='_%s' % tower_name))
else:
for loss in task_tower_cfg.losses:
self._prediction_dict.update(
self._output_to_prediction_impl(
output[tower_name],
loss_type=loss.loss_type,
num_class=task_tower_cfg.num_class,
suffix='_%s' % tower_name))

def build_metric_graph(self, eval_config):
"""Build metric graph for multi task model."""
metric_dict = {}
for task_tower_cfg in self._task_towers:
tower_name = task_tower_cfg.tower_name
for metric in task_tower_cfg.metrics_set:
loss_types = {task_tower_cfg.loss_type}
if len(task_tower_cfg.losses) > 0:
loss_types = {loss.loss_type for loss in task_tower_cfg.losses}
metric_dict.update(
self._build_metric_impl(
metric,
loss_type=task_tower_cfg.loss_type,
loss_type=loss_types,
label_name=self._label_name_dict[tower_name],
num_class=task_tower_cfg.num_class,
suffix='_%s' % tower_name))
Expand Down Expand Up @@ -123,9 +135,17 @@ def get_outputs(self):
outputs = []
for task_tower_cfg in self._task_towers:
tower_name = task_tower_cfg.tower_name
outputs.extend(
self._get_outputs_impl(
task_tower_cfg.loss_type,
task_tower_cfg.num_class,
suffix='_%s' % tower_name))
return outputs
if len(task_tower_cfg.losses) == 0:
outputs.extend(
self._get_outputs_impl(
task_tower_cfg.loss_type,
task_tower_cfg.num_class,
suffix='_%s' % tower_name))
else:
for loss in task_tower_cfg.losses:
outputs.extend(
self._get_outputs_impl(
loss.loss_type,
task_tower_cfg.num_class,
suffix='_%s' % tower_name))
return list(set(outputs))
71 changes: 48 additions & 23 deletions easy_rec/python/model/rank_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,15 @@ def _output_to_prediction_impl(self,
return prediction_dict

def _add_to_prediction_dict(self, output):
prediction_dict = self._output_to_prediction_impl(
output, loss_type=self._loss_type, num_class=self._num_class)
self._prediction_dict.update(prediction_dict)
if len(self._losses) == 0:
prediction_dict = self._output_to_prediction_impl(
output, loss_type=self._loss_type, num_class=self._num_class)
self._prediction_dict.update(prediction_dict)
else:
for loss in self._losses:
prediction_dict = self._output_to_prediction_impl(
output, loss_type=loss.loss_type, num_class=self._num_class)
self._prediction_dict.update(prediction_dict)

def build_rtp_output_dict(self):
"""Forward tensor as `rank_predict`, which is a special node for RTP."""
Expand All @@ -85,15 +91,22 @@ def build_rtp_output_dict(self):
rank_predict = op.outputs[0]
except KeyError:
forwarded = None
if self._loss_type == LossType.CLASSIFICATION:
loss_types = {self._loss_type}
if len(self._losses) > 0:
loss_types = {loss.loss_type for loss in self._losses}
binary_loss_set = {
LossType.CLASSIFICATION, LossType.F1_REWEIGHTED_LOSS,
LossType.PAIR_WISE_LOSS
}
if loss_types & binary_loss_set:
if 'probs' in self._prediction_dict:
forwarded = self._prediction_dict['probs']
else:
raise ValueError(
'failed to build RTP rank_predict output: classification model ' +
"expect 'probs' prediction, which is not found. Please check if" +
' build_predict_graph() is called.')
elif self._loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
elif loss_types & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
if 'y' in self._prediction_dict:
forwarded = self._prediction_dict['y']
else:
Expand All @@ -105,7 +118,7 @@ def build_rtp_output_dict(self):
else:
# Use str.format here: the misspelled 'foramt' raised AttributeError instead
# of emitting the warning. The set of effective loss types is reported.
logging.warning(
    'failed to build RTP rank_predict: unsupported loss type {}'.format(
        loss_types))
if forwarded is not None:
rank_predict = tf.identity(forwarded, name='rank_predict')
if rank_predict is not None:
Expand Down Expand Up @@ -183,6 +196,8 @@ def _build_metric_impl(self,
label_name,
num_class=1,
suffix=''):
if not isinstance(loss_type, set):
loss_type = {loss_type}
from easy_rec.python.core.easyrec_metrics import metrics_tf
from easy_rec.python.core import metrics as metrics_lib
binary_loss_set = {
Expand All @@ -191,8 +206,7 @@ def _build_metric_impl(self,
}
metric_dict = {}
if metric.WhichOneof('metric') == 'auc':
assert loss_type in binary_loss_set

assert loss_type & binary_loss_set
if num_class == 1:
label = tf.to_int64(self._labels[label_name])
metric_dict['auc' + suffix] = metrics_tf.auc(
Expand All @@ -208,7 +222,7 @@ def _build_metric_impl(self,
else:
raise ValueError('Wrong class number')
elif metric.WhichOneof('metric') == 'gauc':
assert loss_type in binary_loss_set
assert loss_type & binary_loss_set
if num_class == 1:
label = tf.to_int64(self._labels[label_name])
uids = self._feature_dict[metric.gauc.uid_field]
Expand All @@ -231,7 +245,7 @@ def _build_metric_impl(self,
else:
raise ValueError('Wrong class number')
elif metric.WhichOneof('metric') == 'session_auc':
assert loss_type in binary_loss_set
assert loss_type & binary_loss_set
if num_class == 1:
label = tf.to_int64(self._labels[label_name])
metric_dict['session_auc' + suffix] = metrics_lib.session_auc(
Expand All @@ -249,7 +263,7 @@ def _build_metric_impl(self,
else:
raise ValueError('Wrong class number')
elif metric.WhichOneof('metric') == 'max_f1':
assert loss_type in binary_loss_set
assert loss_type & binary_loss_set
if num_class == 1:
label = tf.to_int64(self._labels[label_name])
metric_dict['max_f1' + suffix] = metrics_lib.max_f1(
Expand All @@ -261,50 +275,50 @@ def _build_metric_impl(self,
else:
raise ValueError('Wrong class number')
elif metric.WhichOneof('metric') == 'recall_at_topk':
assert loss_type in binary_loss_set
assert loss_type & binary_loss_set
assert num_class > 1
label = tf.to_int64(self._labels[label_name])
metric_dict['recall_at_topk' + suffix] = metrics_tf.recall_at_k(
label, self._prediction_dict['logits' + suffix],
metric.recall_at_topk.topk)
elif metric.WhichOneof('metric') == 'mean_absolute_error':
label = tf.to_float(self._labels[label_name])
if loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
metric_dict['mean_absolute_error' +
suffix] = metrics_tf.mean_absolute_error(
label, self._prediction_dict['y' + suffix])
elif loss_type == LossType.CLASSIFICATION and num_class == 1:
elif loss_type & {LossType.CLASSIFICATION} and num_class == 1:
metric_dict['mean_absolute_error' +
suffix] = metrics_tf.mean_absolute_error(
label, self._prediction_dict['probs' + suffix])
else:
assert False, 'mean_absolute_error is not supported for this model'
elif metric.WhichOneof('metric') == 'mean_squared_error':
label = tf.to_float(self._labels[label_name])
if loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
metric_dict['mean_squared_error' +
suffix] = metrics_tf.mean_squared_error(
label, self._prediction_dict['y' + suffix])
elif num_class == 1 and loss_type in binary_loss_set:
elif num_class == 1 and loss_type & binary_loss_set:
metric_dict['mean_squared_error' +
suffix] = metrics_tf.mean_squared_error(
label, self._prediction_dict['probs' + suffix])
else:
assert False, 'mean_squared_error is not supported for this model'
elif metric.WhichOneof('metric') == 'root_mean_squared_error':
label = tf.to_float(self._labels[label_name])
if loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
metric_dict['root_mean_squared_error' +
suffix] = metrics_tf.root_mean_squared_error(
label, self._prediction_dict['y' + suffix])
elif loss_type == LossType.CLASSIFICATION and num_class == 1:
elif loss_type & {LossType.CLASSIFICATION} and num_class == 1:
metric_dict['root_mean_squared_error' +
suffix] = metrics_tf.root_mean_squared_error(
label, self._prediction_dict['probs' + suffix])
else:
assert False, 'root_mean_squared_error is not supported for this model'
elif metric.WhichOneof('metric') == 'accuracy':
assert loss_type == LossType.CLASSIFICATION
assert loss_type & {LossType.CLASSIFICATION}
assert num_class > 1
label = tf.to_int64(self._labels[label_name])
metric_dict['accuracy' + suffix] = metrics_tf.accuracy(
Expand All @@ -313,20 +327,24 @@ def _build_metric_impl(self,

def build_metric_graph(self, eval_config):
    """Build the metric dict for every metric listed in ``eval_config``.

    The effective loss types are collected into a set: the ``loss_type`` of
    each entry in ``self._losses`` when that list is non-empty, otherwise the
    single ``self._loss_type``.  That set is handed to
    ``self._build_metric_impl`` for each metric in ``eval_config.metrics_set``
    and the returned dicts are merged.

    Args:
      eval_config: object exposing an iterable ``metrics_set``.

    Returns:
      dict merged from the results of ``_build_metric_impl``.
    """
    if len(self._losses) > 0:
        loss_types = {loss.loss_type for loss in self._losses}
    else:
        loss_types = {self._loss_type}
    metric_dict = {}
    for metric in eval_config.metrics_set:
        # loss_type is passed exactly once, as the set of effective types.
        metric_dict.update(
            self._build_metric_impl(
                metric,
                loss_type=loss_types,
                label_name=self._label_name,
                num_class=self._num_class))
    return metric_dict

def _get_outputs_impl(self, loss_type, num_class=1, suffix=''):
if loss_type in [
binary_loss_set = {
LossType.CLASSIFICATION, LossType.F1_REWEIGHTED_LOSS,
LossType.PAIR_WISE_LOSS
]:
}
if loss_type in binary_loss_set:
if num_class == 1:
return ['probs' + suffix, 'logits' + suffix]
else:
Expand All @@ -340,4 +358,11 @@ def _get_outputs_impl(self, loss_type, num_class=1, suffix=''):
raise ValueError('invalid loss type: %s' % LossType.Name(loss_type))

def get_outputs(self):
    """Return the names of the prediction outputs of this model.

    With no entries in ``self._losses`` the result is whatever
    ``self._get_outputs_impl`` returns for ``self._loss_type``.  Otherwise the
    outputs of every configured loss are concatenated and de-duplicated.

    Returns:
      list of output names without duplicates, in first-seen order.
    """
    if len(self._losses) == 0:
        return self._get_outputs_impl(self._loss_type, self._num_class)

    # De-duplicate by hand rather than list(set(...)): set iteration order for
    # strings varies between interpreter runs, which would make the returned
    # order unstable.
    all_outputs = []
    for loss in self._losses:
        for name in self._get_outputs_impl(loss.loss_type, self._num_class):
            if name not in all_outputs:
                all_outputs.append(name)
    return all_outputs