
[Feature Request] Make streaming metrics resettable #4814

Closed
untom opened this issue Oct 7, 2016 · 24 comments
Labels
stat:contribution welcome Status - Contributions welcome type:feature Feature requests

Comments

@untom

untom commented Oct 7, 2016

Hi!
Currently, the streaming metrics are (as far as I know) not resettable. I'd like to be able to, e.g., reset the counters after each epoch. That way, a very bad accuracy at the beginning of training won't still influence the accuracy value ten epochs later. It also makes it easier to compare my results to runs obtained outside TensorFlow.

The only workaround I found is to do sess.run(tf.initialize_local_variables()) after each epoch, but of course this can have bad side effects if I have other local variables that I don't want to reset.

Or is there a way to achieve what I want that I didn't think of?
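For context: TF's streaming metrics keep their state in two local variables, total and count, and resetting a metric amounts to re-initializing that state. The mechanics can be sketched without TensorFlow; the following is plain Python with illustrative names, not TF API:

```python
# Plain-Python sketch of a streaming mean with the two pieces of state
# ("total" and "count") that TF's streaming metrics keep in local variables.
# Resetting the metric is just re-initializing that state to zero, so a bad
# first epoch no longer pollutes later epochs' values.
class StreamingMean:
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        # Analogous to running the metric's update_op.
        self.total += value
        self.count += 1

    def result(self):
        # Analogous to evaluating the metric's value tensor.
        return self.total / self.count if self.count else float('nan')

    def reset(self):
        # Analogous to re-running the initializer of the metric's variables.
        self.total = 0.0
        self.count = 0


metric = StreamingMean()
for v in [0.0, 1.0]:          # a "bad" first epoch
    metric.update(v)
print(metric.result())        # 0.5
metric.reset()                # start the next epoch fresh
for v in [1.0, 1.0]:
    metric.update(v)
print(metric.result())        # 1.0 (no influence from the first epoch)
```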

@tatatodd tatatodd added enhancement stat:contribution welcome Status - Contributions welcome labels Oct 7, 2016
@tatatodd
Contributor

tatatodd commented Oct 7, 2016

@untom Thanks for filing the issue! The request sounds reasonable to me. We welcome contributions!

Maybe @nathansilberman @langmore might know whether there's some existing way to reset streaming metrics; I couldn't come up with anything either.

@AshishBora

AshishBora commented Oct 17, 2016

Streaming metrics add two local variables, total and count. You can find and reset them to get the required behavior. Please see the example below. Hope this helps.

import tensorflow as tf

value = 0.1
with tf.name_scope('foo'):
    mean_value, update_op = tf.contrib.metrics.streaming_mean(value)

init_op = tf.variables_initializer(tf.local_variables())
# The metric's state lives in local variables under the 'foo' scope.
stream_vars = [i for i in tf.local_variables() if i.name.split('/')[0] == 'foo']
reset_op = tf.variables_initializer(stream_vars)

with tf.Session() as sess:
    sess.run(init_op)
    for j in range(3):
        for i in range(9):
            _, total, count = sess.run([update_op] + stream_vars)
            mean_val = sess.run(mean_value)
            print(total, count, mean_val)
        sess.run(reset_op)  # zero total and count between epochs
        print()

@untom
Author

untom commented Oct 19, 2016

Thanks, that is very useful. It seems to me that it would make sense if the metrics all returned a reset_op together with the other ops. But that would break existing code. A workaround could be to only return the reset_op if the user specifies an optional return_reset_op=True parameter when creating the op. Is that a sensible approach?

@littleDing

Agreed with @untom .
When building a multi-target training graph, there may be tens of predictions to evaluate against tens of labels. It would be a nightmare to generate these local-variable reset ops manually.
Following @AshishBora 's way, I'm getting the last four variables from the local_variables list, but this doesn't seem extensible.

@aselle aselle added type:feature Feature requests and removed enhancement labels Feb 9, 2017
@astorfi

astorfi commented Jun 2, 2017

Thank you @AshishBora for your guidance. However, I am encountering an issue using TF-Slim. The problem is that TF-Slim uses tf.train.Supervisor(), after which the graph is finalized and cannot be modified. I get the following error:

raise RuntimeError("Graph is finalized and cannot be modified.") RuntimeError: Graph is finalized and cannot be modified.

So the above solution won't work.

Is there any other solution?
Thanks

@nathansilberman
Contributor

nathansilberman commented Jun 2, 2017

You can run the following to reset the local variables used by metric computation:
session.run(tf.local_variables_initializer())

To avoid the issue of graph finalization, just create a reference to this op BEFORE session creation:

reset_op = tf.local_variables_initializer()
# ... session created ...
session.run(reset_op)
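The ordering constraint here can be sketched without TensorFlow: any ops, including the reset op, must be created before the graph is finalized (as tf.train.Supervisor does). A minimal plain-Python illustration, where the Graph class is a toy stand-in rather than the real tf.Graph:

```python
# Toy stand-in for a finalizable graph: once finalized, adding ops raises,
# so the reset op must be created beforehand.
class Graph:
    def __init__(self):
        self.ops = []
        self.finalized = False

    def add_op(self, name):
        if self.finalized:
            raise RuntimeError("Graph is finalized and cannot be modified.")
        self.ops.append(name)
        return name


g = Graph()
reset_op = g.add_op('reset_op')   # create BEFORE finalization
g.finalized = True                # the Supervisor finalizes the graph here
# g.add_op('too_late')            # would raise RuntimeError
print(reset_op in g.ops)          # True: the op is available to run later
```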

@foxik
Contributor

foxik commented Jun 2, 2017

Note that this is tightly coupled to #9498.

@astorfi

astorfi commented Jun 2, 2017

Thank you @nathansilberman .
Unfortunately it didn't work. It looks like the Supervisor is smarter than that!
Moreover, the problem with tf.local_variables_initializer() is that the resulting reset_op may reset all local variables.
I am not 100% sure, though.
But thank you so much for the hint. I think I can find the solution from this.

@nathansilberman
Contributor

nathansilberman commented Jun 2, 2017 via email

@astorfi

astorfi commented Jun 2, 2017

@untom I made it, thanks to @nathansilberman .

I made a pull request to TensorFlow changing slim.learning.train, adding a reset_op for the metrics.
Different changes must still be made to the train_image_classifier.py provided by slim, but this first pull request needs to be accepted before that.

pull request:
#10400
#10400 (comment)

@petrux

petrux commented Jun 22, 2017

@untom I just stumbled upon your discussion. I am struggling with the same issue and just started wrapping all my code up here. Please note that it is still a development and unstable branch and may change rapidly in the next few hours. Any feedback or suggestions are more than welcome. The idea is to have an easy way of getting both streaming and batch-based computation of metrics. I would be careful with the idea suggested by @AshishBora, since it sounds like breaking encapsulation -- but I see it's the quickest way to go.

@shoeffner

shoeffner commented Jul 12, 2017

Combining the suggestions by @AshishBora (#4814 (comment)) and @nathansilberman (#4814 (comment)), I came up with this function to create all three ops with one call while keeping the variables encapsulated (this fixes the side-effect issue @untom mentioned (#4814 (comment))):

def create_reset_metric(metric, scope='reset_metrics', **metric_args):
  with tf.variable_scope(scope) as scope:
    metric_op, update_op = metric(**metric_args)
    vars = tf.contrib.framework.get_variables(
                 scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
    reset_op = tf.variables_initializer(vars)
  return metric_op, update_op, reset_op

An example of creating the operations inside the graph then looks like this:

epoch_loss, epoch_loss_update, epoch_loss_reset = create_reset_metric(
                    tf.contrib.metrics.streaming_mean_squared_error, 'epoch_loss',
                    predictions=output, labels=target)

@untom
Author

untom commented Jul 12, 2017

Does local_variables_initializer only return variables from the current scope? Otherwise this can have bad side effects.

@shoeffner

Indeed, it has bad side effects! Thanks, I changed it to use vars = tf.contrib.framework.get_variables(scope, collection=tf.GraphKeys.LOCAL_VARIABLES) now, which already filters by scope.

@unship

unship commented Sep 6, 2017

What's the best practice for using this?

import tensorflow as tf

pred = tf.placeholder(shape=[3],dtype=tf.int32)
label = tf.placeholder(shape=[3],dtype=tf.int32)
v,op = tf.metrics.accuracy(pred,label)
reset_op = tf.local_variables_initializer()

sess = tf.InteractiveSession()
sess.run(reset_op)
for i in range(10):
    print(sess.run([v,op,reset_op],{pred:[0,1,2],label:[0,1,2]}))

I get

[0.0, 1.0, None]
[1.0, 1.0, None]
[0.0, 0.0, None]
[inf, 1.0, None]
[0.0, 1.0, None]
[0.0, 1.0, None]
[1.0, 1.0, None]
[inf, inf, None]
[0.0, 0.0, None]
[0.0, inf, None]

I add control dependency to this

import tensorflow as tf

pred = tf.placeholder(shape=[3],dtype=tf.int32)
label = tf.placeholder(shape=[3],dtype=tf.int32)
v,op = tf.metrics.accuracy(pred,label)
reset_op = tf.local_variables_initializer()
op = tf.tuple([op],control_inputs=[reset_op])[0]
v = tf.tuple([v],control_inputs=[op])[0]

sess = tf.InteractiveSession()
sess.run(reset_op)
for i in range(10):
    print(sess.run([v],{pred:[0,1,2],label:[0,1,2]}))

The results are weird, too:

[0.0]
[1.0]
[0.0]
[0.0]
[0.0]
[inf]
[0.0]
[1.0]
[1.0]
[1.0]

@shoeffner

shoeffner commented Sep 10, 2017

You are doing what I did wrong in the first place: You are using the tf.local_variables_initializer(). Try using another reset op which only resets the variables you need, e.g. by selecting the variables you need:

with tf.variable_scope("reset_metrics_accuracy_scope") as scope:
    v, op = tf.metrics.accuracy(pred, label)
    vars = tf.contrib.framework.get_variables(scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
    reset_op = tf.variables_initializer(vars)

You would then call sess.run(op) whenever your metric should be updated, sess.run(reset_op) whenever you want to reset it, and sess.run(v) whenever you want to know its value.
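The update/read/reset cycle described above can be sketched without TensorFlow. In this plain-Python illustration, a hypothetical ResettableAccuracy object stands in for the metric: update() plays the role of sess.run(op), result() of sess.run(v), and reset() of sess.run(reset_op):

```python
# Framework-free sketch of the per-epoch update/read/reset cycle.
class ResettableAccuracy:
    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, predictions, labels):
        # Analogous to sess.run(op): accumulate into the metric's state.
        self.correct += sum(p == l for p, l in zip(predictions, labels))
        self.total += len(labels)

    def result(self):
        # Analogous to sess.run(v): read without changing state.
        return self.correct / self.total if self.total else 0.0

    def reset(self):
        # Analogous to sess.run(reset_op): zero the state between epochs.
        self.correct = 0
        self.total = 0


accuracy = ResettableAccuracy()
epochs = [
    [([0, 1, 2], [0, 0, 2])],   # epoch 1: one batch, 2/3 correct
    [([0, 1, 2], [0, 1, 2])],   # epoch 2: one batch, 3/3 correct
]
for batches in epochs:
    for preds, labels in batches:
        accuracy.update(preds, labels)     # per batch
    print(round(accuracy.result(), 3))     # at epoch end
    accuracy.reset()                       # between epochs
```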

@b3nk4n

b3nk4n commented Oct 17, 2017

@shoeffner : I'm not sure if this has changed in the latest TF version (1.3), but 'vars' is an empty list when I use your snippet above.

To be more precise, I always run into the exception in my code, which is based on your proposal:

class MetricOp(object):
    def __init__(self, name, value, update, reset):
        self._name = name
        self._value = value
        self._update = update
        self._reset = reset

    @property
    def name(self):
        return self._name

    @property
    def value(self):
        return self._value

    @property
    def update(self):
        return self._update

    @property
    def reset(self):
        return self._reset


def create_metric(scope: str, metric: callable, **metric_args) -> MetricOp:
    with tf.variable_scope(scope) as scope:
        metric_op, update_op = metric(**metric_args)
        scope_vars = tf.contrib.framework.get_variables(
            scope, collection=tf.GraphKeys.LOCAL_VARIABLES)

        if len(scope_vars) == 0:
            raise Exception("No local variables found.")

        reset_op = tf.variables_initializer(scope_vars)
    return MetricOp('MetricOp', metric_op, update_op, reset_op)

Sorry, I could fix that myself. It was caused by wrapping my call to this function in a name_scope:

with tf.name_scope('Metrics'):  # <---
    targets = tf.argmax(y, 1)
    accuracy_op = metrics.create_metric("Accuracy", tf.metrics.accuracy,
                                        labels=targets, predictions=model.prediction_op)
    precision_op = metrics.create_metric("Precision", tf.metrics.precision,
                                         labels=targets, predictions=model.prediction_op)

@kitamura-tetsuo

kitamura-tetsuo commented Nov 21, 2017

@bsautermeister
When wrapping with a name_scope, you can use scope.original_name_scope instead of scope:

vars = tf.contrib.framework.get_variables(
             scope.original_name_scope, collection=tf.GraphKeys.LOCAL_VARIABLES)

@studentSam0000

studentSam0000 commented Dec 14, 2017

@sguada @AshishBora @astorfi

Hi, thank you for the leads. I am new to TF. I'd like to use the separate evaluation process in tf.slim with slim.evaluation.evaluation_loop. Is there a workaround to reset the streaming_accuracy after a predefined num_evals if I use slim.evaluation.evaluation_loop?

@shoeffner

@studentSam0000 You might want to take a look at this tf slim example, where they already use some metrics. However, if you really want to reset your streaming accuracy, you could probably use a hook in which you call the reset op, something along these lines:

class ResetHook(tf.train.SessionRunHook):
    """Hook to perform reset metrics every N steps."""

    def __init__(self, reset_op, every_step=50):
        self.reset_op = reset_op
        self.every_step = every_step
        self.reset = False

    def begin(self):
        self._global_step_tensor = tf.train.get_global_step()
        if self._global_step_tensor is None:
            raise RuntimeError("Global step should be created to use ResetHook.")

    def before_run(self, run_context):
        if self.reset:
            return tf.train.SessionRunArgs(fetches=self.reset_op)
        return tf.train.SessionRunArgs(fetches=self._global_step_tensor)

    def after_run(self, run_context, run_values):
        if self.reset:
            self.reset = False
            return
        global_step = run_values.results
        if global_step % self.every_step == 0:
            self.reset = True

Using it with the snippet above (#4814 (comment)), you can then build something like this:

epoch_loss, epoch_loss_update, epoch_loss_reset = create_reset_metric(
                    tf.contrib.metrics.streaming_mean_squared_error, 'epoch_loss',
                    predictions=output, labels=target)

reset_hook = ResetHook(epoch_loss_reset, 10)

tf.contrib.slim.evaluation.evaluation_loop('local', 'checkpoints', 'logs', num_evals=1000, 
    ..., hooks=[reset_hook])

I haven't tested the hook, but I used a similar hook to perform traces a while ago and just adjusted it a little, so you should get the idea of how it works.
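The hook's scheduling is slightly subtle: the reset op is fetched on the run after the step that triggers it. That logic can be checked with a plain-Python simulation; the simulate function below is illustrative only, not TF API:

```python
# Plain-Python simulation of the hook's scheduling: after a run in which
# global_step % every_step == 0, the *next* run executes the reset op
# (during a reset run no step check happens, mirroring after_run's early return).
def simulate(every_step, num_runs):
    reset_pending = False
    events = []
    for global_step in range(1, num_runs + 1):
        if reset_pending:             # before_run: fetch the reset op
            events.append(('reset', global_step))
            reset_pending = False     # after_run: clear the flag, return early
        elif global_step % every_step == 0:
            reset_pending = True      # after_run: schedule a reset
    return events


print(simulate(3, 8))   # resets fire on the runs after steps 3 and 6
```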

@skeggse

skeggse commented Mar 20, 2018

I've been using the METRIC_VARIABLES collection as the var_list for variables_initializer to reset metrics with a control dependency before updating them, similar in spirit to @shoeffner's approach.
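A plain-Python sketch of this collection-based approach, with a shared list standing in for the METRIC_VARIABLES collection; all names here are illustrative, not TF API:

```python
# Each metric registers its state in a shared list (playing the role of
# tf.GraphKeys.METRIC_VARIABLES), so a single reset clears every metric at once,
# analogous to tf.variables_initializer over that collection.
METRIC_VARIABLES = []


class Var:
    def __init__(self, value=0.0):
        self.value = value
        METRIC_VARIABLES.append(self)   # register on creation


class Mean:
    def __init__(self):
        self.total = Var()
        self.count = Var()

    def update(self, x):
        self.total.value += x
        self.count.value += 1

    def result(self):
        return self.total.value / self.count.value if self.count.value else 0.0


def reset_all():
    # One "initializer" over the whole collection resets every metric.
    for var in METRIC_VARIABLES:
        var.value = 0.0


loss = Mean()
acc = Mean()
loss.update(2.0)
acc.update(1.0)
reset_all()                            # resets both metrics' state
print(loss.result(), acc.result())     # 0.0 0.0
```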

@foxik
Contributor

foxik commented Feb 17, 2020

@gadagashwini With TF 2, this can be closed.

@foxik
Contributor

foxik commented May 21, 2021

@rmothukuru With TF 2, this can be now closed (the tf.keras.metrics are resettable).
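For reference, TF 2's tf.keras.metrics objects expose update_state()/result()/reset_state() (named reset_states() in older TF 2 releases). A plain-Python mirror of that interface, illustrative only:

```python
# Plain-Python mirror of the tf.keras.metrics interface:
# update_state() accumulates, result() reads, reset_state() clears.
class MeanMetric:
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0

    def reset_state(self):
        self.total = 0.0
        self.count = 0


m = MeanMetric()
m.update_state(4.0)
m.update_state(2.0)
print(m.result())    # 3.0
m.reset_state()
print(m.result())    # 0.0
```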

@rmothukuru
Contributor

@foxik,
Thank you for the confirmation. Closing the issue as it has been fixed.
