-
-
Notifications
You must be signed in to change notification settings - Fork 112
A simple solution to work on the copy of the graph #15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b218d43
88fc3c6
8136428
3fd3acc
0032173
f3b6715
72cb7ac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,7 +27,7 @@ def __getattr__(self, item): | |
raise error from e | ||
|
||
|
||
class Model(object): | ||
class Model(): | ||
def __init__(self, name=None, graph=None, session=None, **config): | ||
self._cfg = Config(**config) | ||
self.name = name | ||
|
@@ -38,6 +38,7 @@ def __init__(self, name=None, graph=None, session=None, **config): | |
session = tf.Session(graph=graph) | ||
self.session = session | ||
self.observe(**config) | ||
self.temp_graph = tf.Graph() | ||
|
||
def define(self, f): | ||
self._f = f | ||
|
@@ -54,6 +55,7 @@ def _init_variables(self): | |
info_collector = interceptors.CollectVariablesInfo() | ||
with self.graph.as_default(), ed.interception(info_collector): | ||
self._f(self.cfg) | ||
tf.contrib.graph_editor.copy(self.graph, self.temp_graph) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So 1) You create a temp graph for inner manipulations |
||
self._variables = info_collector.result | ||
|
||
def test_point(self, sample=True): | ||
|
@@ -66,12 +68,15 @@ def not_observed(var, *args, **kwargs): # pylint: disable=unused-argument | |
def get_mode(state, rv, *args, **kwargs): # pylint: disable=unused-argument | ||
return rv.distribution.mode() | ||
chain.insert(0, interceptors.Generic(after=get_mode)) | ||
|
||
with self.graph.as_default(), ed.interception(interceptors.Chain(*chain)): | ||
tf.contrib.graph_editor.copy(self.graph, self.temp_graph) | ||
# pylint: disable=not-context-manager | ||
with self.temp_graph.as_default(), ed.interception(interceptors.Chain(*chain)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You reuse tmp graph created on step 1) suppose There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do not know a solution for this case yet. I've seen some kind of working and relevant example in tensorflow codebase |
||
self._f(self.cfg) | ||
with self.session.as_default(): | ||
returns = self.session.run(list(values_collector.result.values())) | ||
return dict(zip(values_collector.result.keys(), returns)) | ||
# pylint: enable=not-context-manager | ||
with tf.Session(graph=self.temp_graph) as sess: | ||
returns = sess.run(list(values_collector.result.values())) | ||
keys = values_collector.result.keys() | ||
return dict(zip(keys, returns)) | ||
|
||
def target_log_prob_fn(self, *args, **kwargs): # pylint: disable=unused-argument | ||
""" | ||
|
@@ -83,7 +88,8 @@ def log_joint_fn(*args, **kwargs): # pylint: disable=unused-argument | |
states = dict(zip(self.unobserved.keys(), args)) | ||
states.update(self.observed) | ||
interceptor = interceptors.CollectLogProb(states) | ||
with ed.interception(interceptor): | ||
tf.contrib.graph_editor.copy(self.graph, self.temp_graph) | ||
with self.temp_graph.as_default(), ed.interception(interceptor): # pylint: disable=not-context-manager | ||
self._f(self._cfg) | ||
|
||
log_prob = sum(interceptor.log_probs) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,7 @@ | |
VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape,rv') | ||
|
||
|
||
class Interceptor(object): | ||
class Interceptor(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good practice to inherit from object, even if you don't need to! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why? I thought in py3 we don't need that anymore. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I am facing a linting error when I import from |
||
def name_scope(self): | ||
return tf.name_scope(self.__class__.__name__.lower()) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would use this pattern for creating and destroying the temporary graph. I'm using a pymc3 model here because I think the pattern is similar, but let me know if it isn't! Note that you will only need to implement the equivalent of the
temp_model
method below.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
decorator for inclass function is causing problems.. I will look into