Add slice_axis methods to Distribution #397

Merged
lostella merged 4 commits into awslabs:master from distribution-slice on Oct 17, 2019

lostella
Contributor

@lostella lostella commented Oct 17, 2019

Description of changes: This allows slicing distributions along arbitrary axes, using basically the same signatures as mx.nd.slice and mx.nd.slice_axis.
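To make the intent concrete, here is a minimal sketch of what slicing a distribution along an axis means. It is not the code from this PR: nested Python lists stand in for MXNet tensors, and `slice_list` and `ToyGaussian` are hypothetical names invented for illustration. The `(axis, begin, end)` arguments follow the `mx.nd.slice_axis` convention mentioned above, with `end=None` meaning "up to the end of the axis"; negative axes are not handled here.

```python
from typing import Any, List, Optional


def slice_list(data: List[Any], axis: int, begin: int, end: Optional[int]) -> List[Any]:
    """Stand-in for a tensor slice: keep indices [begin, end) along `axis`."""
    if axis == 0:
        return data[begin:end]
    return [slice_list(row, axis - 1, begin, end) for row in data]


class ToyGaussian:
    """Hypothetical distribution whose parameters are nested lists."""

    def __init__(self, mu, sigma):
        self.mu = mu
        self.sigma = sigma

    def slice_axis(self, axis: int, begin: int, end: Optional[int]) -> "ToyGaussian":
        # Slicing the distribution means slicing every parameter the same way.
        return ToyGaussian(
            slice_list(self.mu, axis, begin, end),
            slice_list(self.sigma, axis, begin, end),
        )


distr = ToyGaussian(
    mu=[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
    sigma=[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]],
)
last_two = distr.slice_axis(axis=1, begin=1, end=None)
print(last_two.mu)  # [[1.0, 2.0], [4.0, 5.0]]
```

The result is again a distribution of the same kind, only over a smaller batch/event shape.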

More tests/distributions to come, feedback wanted.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@lostella lostella requested a review from jaheba October 17, 2019 08:54
@lostella lostella changed the title Add slice, slice_axis methods to Distribution Add slice_axis methods to Distribution Oct 17, 2019
Comment on lines 263 to 267
        sliced_distr = self.__class__(
            *[a.slice_axis(axis, begin, end) for a in self.args]
        )
        assert isinstance(sliced_distr, type(self))
        return sliced_distr
Contributor Author

I could use type(self) instead of self.__class__ and return the object directly. However, that way mypy complains that "too many constructor arguments are passed", so it probably infers that type(self) == Distribution instead of the most specific class. This way instead, I get to fool mypy (a-ha!).

Contributor

I think the problem is that mypy cannot infer the number of arguments of the specific distribution class. Distribution takes no arguments, so that's why it probably complains.

Contributor Author

yes
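For readers following the thread: the complaint discussed here is purely a static-analysis matter. At runtime, `type(self)` and `self.__class__` both resolve to the concrete subclass for ordinary classes, so either spelling constructs the most specific distribution. A minimal sketch, with hypothetical class and method names that are not taken from the PR:

```python
class Distribution:
    def rebuild(self, *args):
        # For ordinary classes both expressions name the concrete subclass.
        assert type(self) is self.__class__
        return self.__class__(*args)


class Gaussian(Distribution):
    def __init__(self, mu, sigma):
        self.mu = mu
        self.sigma = sigma


g = Gaussian(0.0, 1.0).rebuild(2.0, 3.0)
print(type(g).__name__)  # Gaussian
```

The base class constructor takes no arguments, while the subclass constructor takes two; that mismatch is what a checker reasoning only about `Distribution` would flag.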

src/gluonts/distribution/distribution.py (outdated, resolved)
@@ -174,6 +174,10 @@ def s(bin_probs):

        return _sample_multiple(s, self.bin_probs, num_samples=num_samples)

    @property
    def args(self) -> List:
Contributor

Shouldn't we also add an abstract definition in the base class?

However, I don't like this args thing. You impose a constraint on the arguments of all distributions that they are all sliceable. Also, by using a list, order is important and can easily be mixed up.

As an alternative we could have a "dimensions" attribute which returns the names of the fields in question and then use them to construct a new instance.
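A rough sketch of how such a "dimensions" attribute might look, under several assumptions: the attribute name is taken from the comment above, but its type, the keyword-based reconstruction, and the `slice_list` helper (nested lists standing in for tensors) are all hypothetical and not part of the PR.

```python
from typing import Any, List, Optional, Tuple


def slice_list(data: List[Any], axis: int, begin: int, end: Optional[int]) -> List[Any]:
    """Stand-in for a tensor slice: keep indices [begin, end) along `axis`."""
    if axis == 0:
        return data[begin:end]
    return [slice_list(row, axis - 1, begin, end) for row in data]


class Distribution:
    # Hypothetical attribute: names of the constructor fields to slice.
    dimensions: Tuple[str, ...] = ()

    def slice_axis(self, axis: int, begin: int, end: Optional[int]) -> "Distribution":
        sliced = {
            name: slice_list(getattr(self, name), axis, begin, end)
            for name in self.dimensions
        }
        # Keyword construction, so the order of `dimensions` does not matter.
        return self.__class__(**sliced)


class Uniform(Distribution):
    dimensions = ("high", "low")  # deliberately not in constructor order

    def __init__(self, low, high):
        self.low, self.high = low, high


u = Uniform(low=[0.0, 1.0, 2.0], high=[5.0, 6.0, 7.0]).slice_axis(axis=0, begin=1, end=3)
print(u.low, u.high)  # [1.0, 2.0] [6.0, 7.0]
```

Because fields are passed by name, a mix-up in ordering cannot silently swap parameters; fields not listed in `dimensions` would still need separate handling.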

Contributor Author

The property is defined in the base class already.

> You impose a constraint on the arguments of all distributions that they are all sliceable.

I agree, but that's more about the slice_axis method (we should be careful there I guess) than args.

> Also, by using a list, order is important and can easily be mixed up.

Do you suggest using a dictionary?

Contributor

I guess it's fine for now, although I have some concerns about this axis thing.

Contributor Author

@lostella lostella Oct 17, 2019

A less ambiguous solution (but a more verbose one) would be to define the slice_axis method for each distribution. A compromise could be to have this as a fallback definition, and specialize slice_axis in case the args property doesn't work for some reason.
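The compromise described above could be sketched roughly as follows. This is an illustration under assumptions, not the merged implementation: nested lists stand in for tensors, `slice_list` is a hypothetical helper, and `ScaledGaussian` is an invented example of a distribution with a non-sliceable constructor argument.

```python
from typing import Any, List, Optional


def slice_list(data: List[Any], axis: int, begin: int, end: Optional[int]) -> List[Any]:
    """Stand-in for a tensor slice: keep indices [begin, end) along `axis`."""
    if axis == 0:
        return data[begin:end]
    return [slice_list(row, axis - 1, begin, end) for row in data]


class Distribution:
    @property
    def args(self) -> List:
        raise NotImplementedError

    def slice_axis(self, axis: int, begin: int, end: Optional[int]) -> "Distribution":
        # Fallback: assumes every entry of `args` is sliceable and that
        # `args` is ordered like the constructor's positional parameters.
        return self.__class__(*[slice_list(a, axis, begin, end) for a in self.args])


class Gaussian(Distribution):
    def __init__(self, mu, sigma):
        self.mu, self.sigma = mu, sigma

    @property
    def args(self) -> List:
        return [self.mu, self.sigma]


class ScaledGaussian(Distribution):
    """Hypothetical case where the fallback fails: `scale` is a plain float."""

    def __init__(self, mu, sigma, scale: float):
        self.mu, self.sigma, self.scale = mu, sigma, scale

    def slice_axis(self, axis: int, begin: int, end: Optional[int]) -> "ScaledGaussian":
        # Specialized: slice the tensor-like fields, pass the scalar through.
        return ScaledGaussian(
            slice_list(self.mu, axis, begin, end),
            slice_list(self.sigma, axis, begin, end),
            self.scale,
        )


g = Gaussian([0.0, 1.0, 2.0], [1.0, 1.0, 1.0]).slice_axis(0, 0, 2)
s = ScaledGaussian([0.0, 1.0, 2.0], [1.0, 1.0, 1.0], 10.0).slice_axis(0, 1, None)
print(g.mu, s.mu, s.scale)  # [0.0, 1.0] [1.0, 2.0] 10.0
```

Most distributions only need to define `args`; the override is reserved for the cases where that assumption breaks.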

codecov-io commented Oct 17, 2019

Codecov Report

Merging #397 into master will decrease coverage by <.01%.
The diff coverage is 81.48%.

@@            Coverage Diff             @@
##           master     #397      +/-   ##
==========================================
- Coverage   80.02%   80.02%   -0.01%     
==========================================
  Files         146      146              
  Lines        8402     8424      +22     
==========================================
+ Hits         6724     6741      +17     
- Misses       1678     1683       +5

| Impacted Files | Coverage Δ |
|---|---|
| src/gluonts/distribution/distribution.py | 82.89% <100%> (+0.95%) ⬆️ |
| src/gluonts/distribution/gaussian.py | 93.44% <100%> (+0.33%) ⬆️ |
| src/gluonts/distribution/binned.py | 95.41% <66.66%> (-0.82%) ⬇️ |
| src/gluonts/distribution/student_t.py | 98.3% <75%> (-1.7%) ⬇️ |
| src/gluonts/distribution/uniform.py | 91.07% <75%> (-1.39%) ⬇️ |
| src/gluonts/distribution/neg_binomial.py | 96.92% <75%> (-1.47%) ⬇️ |
| src/gluonts/distribution/laplace.py | 98.3% <75%> (-1.7%) ⬇️ |

@lostella lostella merged commit ac782de into awslabs:master Oct 17, 2019
@lostella lostella deleted the distribution-slice branch October 17, 2019 09:27
3 participants