This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
Gluon.probability #18403
Merged
Changes from 149 of 163 commits (authored by xidulu)
4207f3c package created
de160a6 mvn WIP
85bcd07 normal wip, to be tested
49bb79e update
a172adc docstring added, normal mostly done
b8458fd update docstring and naming
2e45472 add test file
550dbae Bernoulli WIP
5928490 bernoulli wip
954527c bernoulli doc done
9f8876a dense variational WIP
e7329ca add kl infra
64d9703 update from upstream
8e4afeb implement normal kl method
733ddd3 refactor kl
138b2b4 add not-implemented handling, rename kl_storage
b37a749 add abstract method and Categorical class
791916c rewrite logit2prob prob2logit for multiclass support
d8b4228 Merge remote-tracking branch 'upstream/master' into distribution_dev
7b1e7c7 normal broadcast_to implemented
f24a450 categorical mostly done
2369f35 update distribution
94de579 update distributions/utils.py
cd578a9 Merge remote-tracking branch 'origin/distribution_dev' into distribut…
1ff9d05 add dot ahead of import
255aaef update from upstream
2709680 fix normal F
7f8b91f update from upstream
72dfbde bernoulli, normal brief tests implemented
43e5076 add hybridize tests
5f862cc transformation infra done
f6c0446 affine transformation implemented, tested
e307494 add test cases
74a8c62 add sum_right_most
a4cd45e fix get F bug
a863a87 compose transform implemented, tested
d214b2e fix
3d5cfe3 add event_dim
a7c2172 update from upstream
87eb50f fetch mvn from upstream
380ed79 clean code, implement normal cdf and tests
d535af3 constraint in bernoulli done
a97937b fix constraint
c2799df finish half normal
6c4b84a add cached_property
8fa160c add test on cached_property
2c6dd02 add more features to distribution and constraints
bd6277f change constraint
f506014 fix bernoulli
4dedd8f add independent
e8d3ebd add independent tests
4ef62db update naming of cached_property
9a96d74 revert
efae056 add constraints
8a2d7d1 add Cat
44b7e14 add Stack for imperative mode
1462629 add Stack for imperative mode
fcd32a1 add bernoulli entropy
fb57a57 Merge remote-tracking branch 'upstream/master' into distribution_dev
c0b33ae categorical WIP
cc7aa0b categorical sampling implemented
8a67be6 finish categorical log_prob, sampling
16f5d07 enumerate_support finished
c9cde78 polish StochasticBlock, add test
cb8362e add test for stochastic sequential
f0f8f3a clean loss list in __call__
730d0d6 fix affine, implement sigmoid, softmax
1b8354a add gumbel, relaxed bernoulli
9567a35 relaxed one-hot sampling implemented
105936b Merge remote-tracking branch 'upstream/master' into distribution_dev
13714e4 gamma done
61dba33 gamma, dirichlet implemented
816b01d beta done
cf183ee gumbel softmax log-likelihood implemented
ebc3099 Merge remote-tracking branch 'upstream/master' into distribution_dev
2453489 refactor tests, implement exponential, fix compose transform
8a3a7d7 weibull implemented, transformed distribution cdf icdf added
35d4fee pareto implemented
0c4a282 uniform wip
8e659a4 uniform done
b8d654e rewrite lgamma, implement chi2
d4997d8 fix chi2 scale
e85dec3 F distribution done
280869c t implemented
16dca53 fix tiny problem
00b76ec cauchy done
7f1da0f add half cauchy
3d91201 multinomial done, tests to be added
ab9e812 add multinomial test
aa4b9ca laplace done
33089fb MVN done, tests todo
9141f20 mvn polished
6c37057 fix a few precision issues
f8de23d add erf, erfinv unified api and learnable transform
fff2a49 fix mvn attribute check
2c805ea MVN done
dd5119c poisson done
ec31a4e hack poisson for size support
dfcbbca geometric finished
14531a1 negative binomial done
156b526 binomial done
f13f6cb implement some kl
31d0f6d add more kl
35e567f refactor kl test
a4d0f42 add more kl
3b839dd binomial kl todo
304d7cd update from upstream
76efc48 change constraint logical op implement
8d2efe5 implement gamma entropy
c330eab finish beta dirichlet entropy
32c2184 finish all entropy
f13fd19 kl finished
907462e add constraint test
6f47521 domain map done
67d0158 remove bayesian dense
bdb732f fix tiny problems
9a96d54 add kl uniform normal
9966c1b add kl tests
b7fa0a6 acquire patch from upstream
cdfd0d5 add some doc
369d223 finish doc
367aa7c refactor kl test (WIP)
5f1be5b add more kl, fix float32 underflow issue
648d639 make sampling more stable
3bf65cd handle inconsistent mode
add173d replace boolean idx with np.where
bc7f856 fix file name
7982d29 add more doc
9effea9 add constraint check
d5a4449 add half_normal/cauchy pdf cdf support check
5c00c06 Merge remote-tracking branch 'upstream/master' into distribution_dev
9152e0d fix import problem
3b15cf2 change nosetest to pytest
f6930b5 remove buggy lines
6c45745 change alias register path
c1eb3e6 attempt to fix ci
a390a56 fix lint, change a few tests
d7e33ca fix lint
25cc066 modify hybrid sequential
c0fb8ce fix lint
68cb6da change import order
fef4c8e add test gluon probability v2
e85e67d fix hybridize flag
d356a38 change implementation of stochastic block
894a773 fix lint
d74e380 fix comments
fd4a418 Merge remote-tracking branch 'upstream/master' into distribution_dev
4ad89e5 fix block
c8f121b modify domain map
da86be9 add raises for improper add_loss
88b6adf add raises for improper add_loss
3fab70c add extra cases
0291ed4 change collectLoss decorator to mandatory
fa4a06c skip stochastic block tests
53089dc Merge remote-tracking branch 'upstream/master' into distribution_dev
f77b1c6 remove test cases
c50bad0 put gpu tests back
6bdaf9f add test_gluon_stochastic_block back
9689c59 remove export test
5007a04 put a test back
c018fe2 tiny refactor
974a9a7 add memory leak flag
9b1564c small changes
Changed file (+2 lines):

```diff
@@ -40,3 +40,5 @@
 from . import model_zoo

 from . import contrib
+
+from . import probability
```
New file (+26 lines):

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import
"""Probability module"""

from .block import *

from .distributions import *

from .transformation import *
```
New file (+22 lines):

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import
"""Stochastic block."""

from .stochastic_block import *
```
python/mxnet/gluon/probability/block/stochastic_block.py (127 additions, 0 deletions):

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=abstract-method
"""Stochastic block class."""
__all__ = ['StochasticBlock', 'StochasticSequential']

from functools import wraps
from ...block import HybridBlock
from ...utils import _indent


class StochasticBlock(HybridBlock):
    """`StochasticBlock` extends `HybridBlock` to support accumulating loss
    in the forward phase, which is extremely useful in building Bayesian Neural
    Networks, where the loss function is composed of a classification loss and
    a KL loss.
    """

    def __init__(self, **kwargs):
        super(StochasticBlock, self).__init__(**kwargs)
        self._losses = []
        self._losscache = []

    def add_loss(self, loss):
        self._losscache.append(loss)

    @staticmethod
    def collectLoss(func):
        """To accumulate loss during the forward phase, one could first decorate
        hybrid_forward with `StochasticBlock.collectLoss`,
        and then collect the loss tensor `x` by calling self.add_loss(x).
        For example, in the following forward function,
        we generate samples from a Gaussian parameterized by `loc` and `scale` and
        accumulate the KL-divergence between it and its prior into the block's loss storage:

            @StochasticBlock.collectLoss
            def hybrid_forward(self, F, loc, scale):
                qz = mgp.Normal(loc, scale)
                # prior
                pz = mgp.Normal(F.np.zeros_like(loc), F.np.ones_like(scale))
                self.add_loss(mgp.kl_divergence(qz, pz))
                return qz.sample()
        """
        @wraps(func)
        def inner(self, *args, **kwargs):
            # Loss from hybrid_forward
            func_out = func(self, *args, **kwargs)
            collected_loss = self._losscache
            self._losscache = []
            return (func_out, collected_loss)

        return inner

    def __call__(self, *args, **kwargs):
        # pylint: disable=arguments-differ
        out = super().__call__(*args, **kwargs)
        self._losses.extend(out[1])
        return out[0]

    @property
    def losses(self):
        return self._losses


class StochasticSequential(StochasticBlock):
    """Stack StochasticBlock sequentially.
    """

    def __init__(self, **kwargs):
        super(StochasticSequential, self).__init__(**kwargs)
        self._layers = []

    def add(self, *blocks):
        """Adds block on top of the stack."""
        for block in blocks:
            self._layers.append(block)
            self.register_child(block)

    @StochasticBlock.collectLoss
    def forward(self, x, *args):
        # pylint: disable=arguments-differ
        for block in self._children.values():
            x = block()(x, *args)
            args = []
            if isinstance(x, (tuple, list)):
                args = x[1:]
                x = x[0]
        if args:
            x = tuple([x] + list(args))
        for block in self._layers:
            if hasattr(block, '_losses'):
                self.add_loss(block._losses)
        return x

    def __repr__(self):
        s = '{name}(\n{modstr}\n)'
        modstr = '\n'.join(['  ({key}): {block}'.format(key=key,
                                                        block=_indent(block().__repr__(), 2))
                            for key, block in self._children.items()])
        return s.format(name=self.__class__.__name__, modstr=modstr)

    def __getitem__(self, key):
        layers = list(self._children.values())[key]
        if isinstance(layers, list):
            net = type(self)()
            net.add(*(l() for l in layers))
            return net
        else:
            return layers()

    def __len__(self):
        return len(self._children)
```
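The `collectLoss` mechanics above can be illustrated without MXNet: the decorator drains a per-call loss cache and returns it alongside the forward output, and `__call__` splits the pair back apart. The sketch below is a simplified, framework-free stand-in; `ToyStochasticBlock`, `ToyBlock`, and the plain-float "losses" are hypothetical names, not the package's API.

```python
from functools import wraps


class ToyStochasticBlock:
    """Minimal stand-in for StochasticBlock: accumulate per-call losses."""

    def __init__(self):
        self._losses = []      # losses collected across all calls
        self._losscache = []   # losses added during the current forward pass

    def add_loss(self, loss):
        self._losscache.append(loss)

    @staticmethod
    def collect_loss(func):
        # Wrap forward() so it returns (output, losses added during this call).
        @wraps(func)
        def inner(self, *args, **kwargs):
            out = func(self, *args, **kwargs)
            collected, self._losscache = self._losscache, []
            return out, collected
        return inner

    def __call__(self, *args, **kwargs):
        out, collected = self.forward(*args, **kwargs)
        self._losses.extend(collected)
        return out

    @property
    def losses(self):
        return self._losses


class ToyBlock(ToyStochasticBlock):
    @ToyStochasticBlock.collect_loss
    def forward(self, x):
        self.add_loss(0.5)   # stands in for a per-block KL term
        return x * 2


b = ToyBlock()
y = b(3)
print(y, b.losses)  # 6 [0.5]
```

The caller only ever sees the forward output; the side-channel losses accumulate on the block, mirroring how `StochasticBlock.losses` is read after the forward pass.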
New file (+86 lines):

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import
"""Distribution classes."""

from .distribution import *

from .exp_family import *

from .exponential import *

from .weibull import *

from .pareto import *

from .uniform import *

from .normal import *

from .laplace import *

from .cauchy import *

from .half_cauchy import *

from .poisson import *

from .geometric import *

from .negative_binomial import *

from .gamma import *

from .dirichlet import *

from .beta import *

from .chi2 import *

from .fishersnedecor import *

from .studentT import *

from .half_normal import *

from .independent import *

from .bernoulli import *

from .binomial import *

from .relaxed_bernoulli import *

from .gumbel import *

from .categorical import *

from .one_hot_categorical import *

from .relaxed_one_hot_categorical import *

from .multinomial import *

from .multivariate_normal import *

from .transformed_distribution import *

from .divergence import *

from .utils import *
```
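Among the modules re-exported above, `divergence` supplies the `kl_divergence` used in the `StochasticBlock` docstring example. For two univariate Gaussians that KL term has a closed form, which the sketch below computes in plain Python; `kl_normal_normal` is a hypothetical helper for illustration, not the package's registered implementation.

```python
import math


def kl_normal_normal(loc_q, scale_q, loc_p, scale_p):
    """Closed-form KL(q || p) for two univariate Gaussians:
    log(s_p/s_q) + (s_q^2 + (m_q - m_p)^2) / (2 s_p^2) - 1/2
    """
    var_ratio = (scale_q / scale_p) ** 2
    t1 = ((loc_q - loc_p) / scale_p) ** 2
    return 0.5 * (var_ratio + t1 - 1.0 - math.log(var_ratio))


# KL between identical distributions is zero
print(kl_normal_normal(0.0, 1.0, 0.0, 1.0))  # 0.0
```

With a standard-normal prior p = N(0, 1), this is exactly the per-sample KL penalty a Bayesian layer would push into the block's loss storage via `add_loss`.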
Review discussion:

This seems to assume that `@StochasticBlock.collectLoss` is used by the user for his forward function? If so: currently it's not mandatory to use `@StochasticBlock.collectLoss`. If it's not mandatory, you need to also handle the case where users don't specify the `@StochasticBlock.collectLoss` decorator. Could you help to clarify?

`@StochasticBlock.collectLoss` is only needed when the users want to compute a per-block loss such as a KL divergence. If such a loss is not demanded, the users can simply ignore this decorator or use the default `HybridBlock`.

Thanks for clarifying @szhengac. As per the current implementation, if users ignore this decorator and don't add it when they don't need it, an error will be raised, as the `__call__` function always expects two outputs. If the error is correct and expected (i.e. adding `@StochasticBlock.collectLoss` is mandatory), why not remove the decorator for simplicity and handle its implementation by default (in `__call__`)?

If it is not needed, `out[1]` will be an empty list, so there is no error incurred here. One reason that we use this decorator is that we also need to use it in `StochasticSequential` for collecting losses over multiple blocks. And, as `self.add_loss(kl)` is called inside `forward` or `hybrid_forward`, I think we cannot move the implementation details to `__call__`.

`out[1]` will only be an empty list if the user specifies `@StochasticBlock.collectLoss`. If they don't specify the decorator, `out[1]` will not exist and there'd be an error. We should have a good error message, or at least document that users must use `@StochasticBlock.collectLoss`.

Oh yes, you are right. We need to think about it. Thanks for pointing this out.

@leezu Thanks for pointing this out; an error should be raised when a user attempts to call `self.add_loss` in a function not decorated by `collectLoss`. I will look into this issue.

Thank you @xidulu. Also note that this issue is present when users don't call `collectLoss`. That's because the `__call__` method here is always invoked when users use `StochasticBlock`.

@leezu I added two checks here:
https://github.com/apache/incubator-mxnet/pull/18403/files#diff-85458cf5116b137da8148bf5b38bcfaeR74
https://github.com/apache/incubator-mxnet/pull/18403/files#diff-85458cf5116b137da8148bf5b38bcfaeR78

To make it clearer, I list several possible situations:

1. Users call `add_loss` inside functions decorated by `collectLoss`: `add_loss` appends losses into `_losscache`, `_losscache` then gets cleared in `collectLoss`, and `len(_losscache)` is 0 when `__call__` is invoked.
2. Users call `add_loss` without using `collectLoss`: `add_loss` appends losses into `_losscache`, `_losscache` still contains values when entering `__call__`, and in this case an exception is raised.
3. Users use `collectLoss` without calling `add_loss`: `self._losses = out[1] = []`.
4. Users use `StochasticBlock` without calling `collectLoss` or `add_loss`: `len(out) == 1`, so `out[1]` is not accessed.
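The situations enumerated above boil down to one guard in `__call__`: a non-empty loss cache on entry means `add_loss` ran outside a `collectLoss`-decorated forward and should fail loudly. The sketch below is a simplified, framework-free illustration of that guard; `GuardedBlock` and `Undecorated` are hypothetical names, not the PR's actual code.

```python
class GuardedBlock:
    """Sketch of the add_loss/collectLoss consistency check."""

    def __init__(self):
        self._losses = []
        self._losscache = []

    def add_loss(self, loss):
        self._losscache.append(loss)

    def __call__(self, *args):
        out = self.forward(*args)
        if self._losscache:
            # Situation 2: add_loss was called but no collectLoss decorator
            # drained the cache -- raise instead of silently dropping losses.
            raise ValueError(
                "add_loss() requires decorating forward with collectLoss")
        if isinstance(out, tuple) and len(out) == 2:
            # Situations 1 and 3: collectLoss returned (output, losses).
            self._losses.extend(out[1])
            return out[0]
        # Situation 4: plain forward, nothing to collect.
        return out


class Undecorated(GuardedBlock):
    def forward(self, x):
        self.add_loss(1.0)   # add_loss without collectLoss: must raise
        return x


try:
    Undecorated()(3)
except ValueError as e:
    print("raised:", e)
```

Detecting the decorator via the returned 2-tuple is fragile (a forward that legitimately returns two values would be misread), which is presumably why the PR ultimately made `collectLoss` mandatory instead.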
@leezu Update: I made further changes here to avoid confusion. Now users are forced to use the `collectLoss` decorator in all cases; otherwise an exception is raised.