Test for whether TransformModules work for density estimation #2544

Merged (10 commits into pyro-ppl:dev, Jul 7, 2020)

Conversation

stefanwebb (Contributor):

I've added a test for whether TransformModules work for density estimation after discovering that Spline doesn't... I'll fix the bug in Spline in this PR.

@stefanwebb stefanwebb marked this pull request as ready for review July 2, 2020 20:34
@stefanwebb (Contributor, Author):

I worked out that the bug was that several transforms didn't update their parameters for each call of the forward or inverse operation. This is a problem after you've taken a gradient step during learning...

We now test that density estimation is possible and it passes for all transforms!
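
For context, here is a minimal sketch of this kind of density-estimation check, in the style of the Pyro normalizing flows tutorial. The choice of Spline, the hyperparameters, and the flow_dist.clear_cache() call are illustrative and assume a recent Pyro/PyTorch; this is not necessarily the exact test added in this PR:

import torch
import pyro.distributions as dist
import pyro.distributions.transforms as T

# Toy 2D data that the flow should learn to fit.
data = torch.randn(256, 2) * 0.5 + 1.0

base_dist = dist.Normal(torch.zeros(2), torch.ones(2))
transform = T.Spline(2)  # a TransformModule with learnable parameters
flow_dist = dist.TransformedDistribution(base_dist, [transform])

optimizer = torch.optim.Adam(transform.parameters(), lr=1e-2)
initial_loss = -flow_dist.log_prob(data).mean().item()
for _ in range(200):
    optimizer.zero_grad()
    loss = -flow_dist.log_prob(data).mean()
    loss.backward()
    optimizer.step()
    # Discard cached values so the next call recomputes the transform
    # with the updated parameters.
    flow_dist.clear_cache()
final_loss = -flow_dist.log_prob(data).mean().item()

# The bug described above shows up here as the loss failing to improve,
# because the transform keeps reusing parameters computed before the
# gradient steps.
assert final_loss < initial_loss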

@stefanwebb (Contributor, Author):

@fritzo when this merges it will unblock #2542

pyro/distributions/transforms/householder.py (review thread on an outdated diff):
# u_unnormed ~ (count_transforms, input_dim)
# Hence, input_dim must divide
u_unnormed = self.nn(context)
if self.count_transforms == 1:
fritzo (Member):

Not part of this PR, but it is an antipattern to make return type depend on a scalar value. Doing so breaks the ability to write generic code using this interface. Has this behavior been released yet?
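
For reference, a minimal sketch of the behavior being flagged, assuming pyro.nn.DenseNN's convention of returning a single tensor when one block of parameters is requested and a tuple of tensors otherwise (the dimensions here are made up for illustration):

import torch
from pyro.nn import DenseNN

context = torch.randn(10, 4)

# Requesting a single block of parameters: the network returns one tensor.
nn_single = DenseNN(4, [16], param_dims=[5])
u = nn_single(context)  # a Tensor, roughly of shape (10, 5)

# Requesting several blocks: the network returns a tuple of tensors, so
# every consumer has to branch on which case it is in (as in the
# count_transforms == 1 check above).
nn_multi = DenseNN(4, [16], param_dims=[5, 5])
u1, u2 = nn_multi(context)  # two Tensors, each roughly of shape (10, 5)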

stefanwebb (Contributor, Author):

Yes, I'm afraid it's been released... Can you elaborate on why this would break the ability to write generic code? I think the underlying problem is actually how DenseNN shapes its outputs, and this is something we should think about more carefully during a major refactoring.

fritzo (Member) commented Jul 2, 2020:

warning: the following is highly subjective

An example of some generic code you might like to write is, say, a ConditionalCatTransform that concatenates two conditional transforms. You might want to fuse their neural nets to save on tensor ops and share strength. How would you do this? Maybe

class ConditionalCatTransform(...):
    def __init__(self, parts):
        super().__init__(cache_size=1)
        self.parts = parts  # ignoring ModuleList magic for demo purposes
        self.nn = memoize(cat_dense_nn([part.nn for part in self.parts]))
        end = 0
        for part in self.parts:
            beg, end = end, end + part.count_transforms
            # The following line assumes DenseNN returns a tuple:
            part.nn = lambda context: self.nn(context)[beg:end]

then we could define helpers memoize = functools.lru_cache(maxsize=1) and

def cat_dense_nn(parts):
    input_dims = parts[0].input_dims
    hidden_dims = parts[0].hidden_dims
    param_dims = sum([part.param_dims for part in parts], [])
    return DenseNN(input_dims, hidden_dims, param_dims) 

That's what we could have written if the types were consistent. It would have been pretty simple generic code. But it looks like DenseNN returns an output type that depends on an int value, so we would need to complicate our wrapper with extra logic:

  class ConditionalCatTransform(...):
      def __init__(self, parts):
          ...
          for part in self.parts:
              beg, end = end, end + part.count_transforms
-             part.nn = lambda context: self.nn(context)[beg:end]
+             if beg + 1 == end:
+                 part.nn = lambda context: self.nn(context)[beg]
+             else:
+                 part.nn = lambda context: self.nn(context)[beg:end]

Now that's just a little more complex. But I feel, as an author of abstract code, that the complexity tax is best paid by one-off code, so that abstractions can be built tax free. I think this tax structure is one of the main differences between programming languages, e.g. R and MATLAB tend to tax library code, whereas Python and C tend to tax one-off code.

stefanwebb (Contributor, Author):

Okay, let's think about this when we have a chance to overhaul DenseNN

@stefanwebb stefanwebb requested a review from fritzo July 2, 2020 23:22
@stefanwebb (Contributor, Author):

@fritzo I'm not sure what's wrong with the AIR tutorial in Travis...

@fritzo (Member) commented Jul 5, 2020:

@neerajprad Do you have any ideas why the air tutorial might be failing? I haven't been able to reproduce locally.

@neerajprad (Member):

I think #2549 should temporarily silence the FutureWarning from jupyter_client that's causing the tutorial tests to fail.
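
For reference, one way such a warning can be temporarily silenced in test configuration is with a warnings filter; this is only an illustration, not necessarily what #2549 does:

import warnings

# Illustrative only: ignore the FutureWarning raised from jupyter_client so
# that unrelated tutorial tests are not failed by the deprecation notice.
warnings.filterwarnings("ignore", category=FutureWarning, module="jupyter_client")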

@fritzo (Member) commented Jul 6, 2020:

@stefanwebb can you try merging in dev and pushing to trigger ci?

@stefanwebb stefanwebb requested a review from fritzo July 7, 2020 00:02
@stefanwebb (Contributor, Author):

Thanks @neerajprad! It all passes now 😄

@fritzo fritzo merged commit 939c04d into pyro-ppl:dev Jul 7, 2020
@stefanwebb stefanwebb deleted the autodiff-bug branch July 7, 2020 15:10