Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#37 from heavengate/refine_metric
Browse files Browse the repository at this point in the history
refine metric
  • Loading branch information
heavengate authored Apr 15, 2020
2 parents f9f2d42 + da9aca3 commit dc2a5e5
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 17 deletions.
5 changes: 0 additions & 5 deletions examples/bmn/bmn_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@ def __init__(self, cfg, mode):
elif self.mode == 'infer':
self.get_infer_dataset_dict()

def add_metric_op(self, preds, label):
    """Pick the tensors BMN metrics need from model outputs and labels.

    `preds` is the triple (pred_bm, pred_start, pred_end); the video index
    is carried as the last entry of `label`. The four tensors are returned
    as a list, which becomes the argument list of `update`.
    """
    bm_scores, start_scores, end_scores = preds
    vid = label[-1]
    return [bm_scores, start_scores, end_scores, vid]  # return list

def update(self, pred_bm, pred_start, pred_end, fid):
# generate proposals
pred_start = pred_start[0]
Expand Down
38 changes: 30 additions & 8 deletions hapi/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,16 @@ def reset(self):
format(self.__class__.__name__))

@abc.abstractmethod
def update(self, *args):
    """
    Update states for metric.

    Inputs of :code:`update` are the outputs of :code:`Metric.add_metric_op`;
    if :code:`add_metric_op` is not defined, the inputs of :code:`update`
    will be the flattened arguments of the **outputs** of the model and the
    **labels** from data:
    :code:`update(output1, output2, ..., label1, label2, ...)`

    See :code:`Metric.add_metric_op`.
    """
    # Subclasses must override; the class name in the message helps locate
    # the offending metric implementation.
    raise NotImplementedError("function 'update' not implemented in {}.".
                              format(self.__class__.__name__))
Expand All @@ -72,11 +79,26 @@ def name(self):
raise NotImplementedError("function 'name' not implemented in {}.".
format(self.__class__.__name__))

def add_metric_op(self, *args):
    """
    This API is advanced usage to accelerate metric calculating; calculations
    from outputs of the model to the states which should be updated by Metric
    can be defined here, where Paddle OPs are also supported. Outputs of this
    API will be the inputs of :code:`Metric.update`.

    If :code:`add_metric_op` is defined, it will be called with the
    **outputs** of the model and the **labels** from data as arguments; all
    outputs and labels will be concatenated and flattened, and each field
    passed as a separate argument as follows:
    :code:`add_metric_op(output1, output2, ..., label1, label2, ...)`

    If :code:`add_metric_op` is not defined, the default behaviour is to pass
    the input through to the output, so the output format will be:
    :code:`return output1, output2, ..., label1, label2, ...`

    See :code:`Metric.update`.
    """
    # Default: identity pass-through of all positional arguments.
    return args


class Accuracy(Metric):
Expand All @@ -91,12 +113,12 @@ def __init__(self, topk=(1, ), name=None, *args, **kwargs):
self._init_name(name)
self.reset()

def add_metric_op(self, pred, label, *args, **kwargs):
pred = fluid.layers.argsort(pred[0], descending=True)[1][:, :self.maxk]
correct = pred == label[0]
def add_metric_op(self, pred, label, *args):
pred = fluid.layers.argsort(pred, descending=True)[1][:, :self.maxk]
correct = pred == label
return correct

def update(self, correct, *args, **kwargs):
def update(self, correct, *args):
accs = []
for i, k in enumerate(self.topk):
num_corrects = correct[:, :k].sum()
Expand Down
9 changes: 5 additions & 4 deletions hapi/model.py
Original file line number Diff line number Diff line change
def to_list(value):
    """Normalize *value* into a list.

    - ``None`` passes through unchanged (caller-visible sentinel).
    - A list or tuple is copied into a new ``list`` (so tuples become
      mutable lists and callers never alias the input container).
    - Anything else is wrapped as a single-element list.
    """
    if value is None:
        return value
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


Expand Down Expand Up @@ -473,7 +473,7 @@ def _make_program(self, mode):
if mode != 'test':
for metric in self.model._metrics:
metrics.append(
to_list(metric.add_metric_op(outputs, labels)))
to_list(metric.add_metric_op(*(outputs + labels))))

if mode == 'train' and self.model._optimizer:
self._loss_endpoint = fluid.layers.sum(losses)
Expand Down Expand Up @@ -593,7 +593,7 @@ def train(self, inputs, labels=None):
metrics = []
for metric in self.model._metrics:
metric_outs = metric.add_metric_op(
to_list(outputs), to_list(labels))
*(to_list(outputs) + to_list(labels)))
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)

Expand Down Expand Up @@ -632,7 +632,8 @@ def eval(self, inputs, labels=None):
self._merge_count[self.mode + '_total'] += samples
self._merge_count[self.mode + '_batch'] = samples

metric_outs = metric.add_metric_op(to_list(outputs), labels)
metric_outs = metric.add_metric_op(
*(to_list(outputs) + to_list(labels)))
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)

Expand Down

0 comments on commit dc2a5e5

Please sign in to comment.