-
Notifications
You must be signed in to change notification settings - Fork 750
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add slice_axis methods to Distribution #397
Conversation
sliced_distr = self.__class__( | ||
*[a.slice_axis(axis, begin, end) for a in self.args] | ||
) | ||
assert isinstance(sliced_distr, type(self)) | ||
return sliced_distr |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could use type(self)
instead of self.__class__
and return the object directly. However, that way mypy complains that "too many constructor arguments are passed", so it probably infers that type(self) == Distribution
instead of the most specific class. This way instead, I get to fool mypy (a-ha!).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the problem is that mypy cannot infer the number of arguments of the specific distribution class. Distribution
takes no arguments, so that's why it probably complains.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes
@@ -174,6 +174,10 @@ def s(bin_probs): | |||
|
|||
return _sample_multiple(s, self.bin_probs, num_samples=num_samples) | |||
|
|||
@property | |||
def args(self) -> List: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't we also add an abstract definition in the base class?
However, I don't like this args
thing. You impose a constraint on the arguments of all distributions that they are all sliceable. Also, by using a list order is important and can easily mixed up.
As an alternative we could have a "dimensions" attribute which returns the name of the fields in question and then use them to construct a new instance.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The property is defined in the base class already.
You impose a constraint on the arguments of all distributions that they are all sliceable.
I agree, but that's more about the slice_axis
method (we should be careful there I guess) than args
.
Also, by using a list order is important and can easily mixed up.
Do you suggest using a dictionary?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess it's fine for now, although I have some concerns about this axis thing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A less ambiguous solution (but a more verbose one) would be to define the slice_axis
method for each distribution. A compromise could be to have this as fallback definition, and specialize slice_axis
in case the args
property doesn't work for some reason.
@@ -174,6 +174,10 @@ def s(bin_probs): | |||
|
|||
return _sample_multiple(s, self.bin_probs, num_samples=num_samples) | |||
|
|||
@property | |||
def args(self) -> List: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess it's fine for now, although I have some concerns about this axis thing.
Codecov Report
@@ Coverage Diff @@
## master #397 +/- ##
==========================================
- Coverage 80.02% 80.02% -0.01%
==========================================
Files 146 146
Lines 8402 8424 +22
==========================================
+ Hits 6724 6741 +17
- Misses 1678 1683 +5
|
Description of changes: This allows to slice distributions across arbitrary axes, using the basically the same signatures as mx.nd.slice and mx.nd.slice_axis.
More tests/distributions to come, feedback wanted.
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.