-
-
Notifications
You must be signed in to change notification settings - Fork 987
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Test for whether TransformModule
s work for density estimation
#2544
Conversation
I worked out that the bug was that several transforms didn't update their parameters for each call of the forward or inverse operation. This is a problem after you've taken a gradient step during learning... We now test that density estimation is possible and it passes for all transforms! |
# u_unnormed ~ (count_transforms, input_dim) | ||
# Hence, input_dim must divide | ||
u_unnormed = self.nn(context) | ||
if self.count_transforms == 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not part of this PR, but it is an antipattern to make return type depend on a scalar value. Doing so breaks the ability to write generic code using this interface. Has this behavior been released yet?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I'm afraid it's been released... Can you elaborate on why this would break the ability to write generic code? I think the underlying problem is actually how DenseNN
shapes it's outputs, and this is something we should think about more carefully during a major refactoring
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warning: the following is highly subjective
An example of some generic code you might like to write is, say, a ConditionalCatTransform
that concatenates two conditional transforms. You might want to fuse their neural nets to save on tensor ops and share strength. How would you do this? Maybe
class ConditionalCatTransform(...):
def __init__(self, parts):
super().__init__(self, cache_size=1)
self.parts = parts # ignoring ModuleList magic for demo purposes
self.nn = memoize(cat_dense_nn([part.nn for part in self.parts]))
end = 0
for part in self.parts:
beg, end = end, end + part.count_transforms
# The following line assumes DenseNN returns a tuple:
part.nn = lambda context: self.nn(context)[beg:end]
then we could define helpers memoize = functools.lru_cache(max_size=1)
and
def cat_dense_nn(parts):
input_dims = parts[0].input_dims
hidden_dims = parts[0].hidden_dims
param_dims = sum([part.param_dims for part in parts], [])
return DenseNN(input_dims, hidden_dims, param_dims)
That's what we could have written if the types were consistent. It would have been pretty simple generic code. But it looks like DenseNN
returns an output type that depends on an int value, so we would need to complicate our wrapper with extra logic:
class ConditionalCatTransform(...):
def __init__(self, parts):
...
for part in self.parts:
beg, end = end, end + part.count_transforms
- part.nn = lambda context: self.nn(context)[beg:end]
+ if beg + 1 == end:
+ part.nn = lamba context: self.nn(context)[beg]
+ else:
+ part.nn = lambda context: self.nn(context)[beg:end]
Now that's just a little more complex. But I feel, as an author of abstract code, that the complexity tax is best paid by one-off code, so that abstractions can be built tax free. I think this tax structure is one of the main differences between programming languages, e.g. R and MATLAB tend to tax library code, whereas Python and C tend to tax one-off code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, let's think about this when we have a chance to overhaul DenseNN
@fritzo I'm not sure what's wrong with the AIR tutorial in Travis... |
@neerajprad Do you have any ideas why the air tutorial might be failing? I haven't been able to reproduce locally. |
I think #2549 should temporarily silence the FutureWarning from jupyter_client that's causing the tutorial tests to fail. |
@stefanwebb can you try merging in dev and pushing to trigger ci? |
Thanks @neerajprad! It all passes now 😄 |
I've added a test for whether TransformModule's work for density estimation after discovering that
Spline
doesn't... I'll fix the bug inSpline
in this PR