Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A simple solution to work on the copy of the graph #15

Closed
wants to merge 7 commits into from

Conversation

sharanry
Copy link
Contributor

No description provided.

@sharanry sharanry changed the base branch from master to functional July 25, 2018 06:12
@ferrine
Copy link
Member

ferrine commented Jul 25, 2018

Tests seem to be broken, do they pass localy? Solution looks nice, are there any corner cases? What's happening when we use nested models?

@ferrine
Copy link
Member

ferrine commented Jul 25, 2018

Also what if self._cfg contains tf variables?

@sharanry
Copy link
Contributor Author

@ferrine, some problems with using the the right graph with Session(), that is why test is failing. Fixing it.

self._cfg contains the model configuration passed either during model initialization or using model.configure()

@sharanry sharanry changed the title A simple solution to work on the copy of the graph [WIP] A simple solution to work on the copy of the graph Jul 25, 2018
@ColCarroll
Copy link
Member

Looks nice - once tests are fixed, it might be worth wrapping this in its own decorator or method.

with ed.interception(interceptor):
temp_graph = tf.Graph()
tf.contrib.graph_editor.copy(self.graph, temp_graph)
with temp_graph.as_default(), ed.interception(interceptor):
self._f(self._cfg)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not related to this PR, but this line should probably be self._f(self.cfg), right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO it shouldnt make a difference. self.cfg returns self._cfg. Although I am not sure why we are saving it as _cfg

@sharanry
Copy link
Contributor Author

The CI is failing because they removed tf-nightly==1.8.0.dev20180331. Need to update the requirements.

@coveralls
Copy link

Pull Request Test Coverage Report for Build 53

  • 18 of 18 (100.0%) changed or added relevant lines in 3 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+0.4%) to 85.0%

Totals Coverage Status
Change from base Build 42: 0.4%
Covered Lines: 204
Relevant Lines: 240

💛 - Coveralls

@sharanry sharanry changed the title [WIP] A simple solution to work on the copy of the graph A simple solution to work on the copy of the graph Jul 26, 2018
Copy link
Member

@ColCarroll ColCarroll left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry for late review!

@@ -38,6 +38,7 @@ def __init__(self, name=None, graph=None, session=None, **config):
session = tf.Session(graph=graph)
self.session = session
self.observe(**config)
self.temp_graph = tf.Graph()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would use this pattern for creating and destroying the temporary graph. I'm using a pymc3 model here because I think the pattern is similar, but let me know if it isn't! Note that you will only need to implement the equivalent of the temp_model method below.

import pymc3 as pm


def check_in_model_context():
    try:
        m = pm.modelcontext(None)
        return f'In model context! {len(m.vars)} vars!'
    except TypeError:
        return 'Not in model context!'

class Model(object):
    def __init__(self):
        self._temp_model = None
        
    @contextmanager
    def temp_model(self):
        self._temp_model = pm.Model()
        try:
            with self._temp_model:
                yield
        finally:
            self._temp_model = None
    
    def calculation(self):
        print(check_in_model_context())
        with self.temp_model():
            pm.Normal('x', 0, 1)
            print(check_in_model_context())
            
        print(check_in_model_context())
        with self.temp_model():
            print(check_in_model_context())
        print(check_in_model_context())

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

decorator for inclass function is causing problems.. I will look into

@@ -12,7 +12,7 @@
VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape,rv')


class Interceptor(object):
class Interceptor():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good practice to inherit from object, even if you don't need to!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? I thought in py3 we don't need that anymore.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I am facing a linting error when I import from object

@ferrine
Copy link
Member

ferrine commented Jul 29, 2018 via email

@twiecki
Copy link
Member

twiecki commented Jul 30, 2018

Explicit is better than implicit

There is nothing explicit about this, it's a python 2 hack to get super working (or something like that).

@ColCarroll
Copy link
Member

Sort of a jokey discussion, but some nuggets of wisdom in there, notably that there aren't strong feelings either way:
https://forum.dabeaz.com/t/inheriting-from-object/161

I'm fine leaving it out.

Copy link
Member

@ferrine ferrine left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The proposed approach has important corner cases not supported

@@ -54,6 +55,7 @@ def _init_variables(self):
info_collector = interceptors.CollectVariablesInfo()
with self.graph.as_default(), ed.interception(info_collector):
self._f(self.cfg)
tf.contrib.graph_editor.copy(self.graph, self.temp_graph)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So 1) You create a temp graph for inner manipulations

with self.graph.as_default(), ed.interception(interceptors.Chain(*chain)):
tf.contrib.graph_editor.copy(self.graph, self.temp_graph)
# pylint: disable=not-context-manager
with self.temp_graph.as_default(), ed.interception(interceptors.Chain(*chain)):
Copy link
Member

@ferrine ferrine Aug 2, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You reuse tmp graph created on step 1) suppose cfg has any tensor from original graph and you call your model with temp_graph.as_default(). That creates all variables for temp_graph. This looks fine unless the case I'm talking about. If you attempt to do an operation on two variables from different graphs, you will get an exception
image

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not know a solution for this case yet. I've seen some kind of working and relevant example in tensorflow codebase
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/function.py#L347

@twiecki
Copy link
Member

twiecki commented Jan 5, 2019

Closing as the main prototype is from the london branch.

@twiecki twiecki closed this Jan 5, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants