Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Gluon.probability #18403

Merged
merged 163 commits into from
Jul 7, 2020
Merged
Show file tree
Hide file tree
Changes from 149 commits
Commits
Show all changes
163 commits
Select commit Hold shift + click to select a range
4207f3c
package created
xidulu Dec 16, 2019
de160a6
mvn WIP
xidulu Dec 17, 2019
85bcd07
normal wip, to be tested
xidulu Dec 18, 2019
49bb79e
update
Dec 18, 2019
a172adc
docstring added, normal mostly done
xidulu Dec 19, 2019
b8458fd
update docstring and naming
Dec 20, 2019
2e45472
add test file
xidulu Dec 20, 2019
550dbae
Bernoulli WIP
xidulu Dec 20, 2019
5928490
bernoulli wip
xidulu Dec 23, 2019
954527c
bernoulli doc done
xidulu Dec 23, 2019
9f8876a
dense variational WIP
xidulu Dec 27, 2019
e7329ca
add kl infra
xidulu Dec 31, 2019
64d9703
update from upstream
xidulu Dec 31, 2019
8e4afeb
implement normal kl method
xidulu Dec 31, 2019
733ddd3
refactor kl
xidulu Jan 4, 2020
138b2b4
add not implemented handling, rename kl_storage
xidulu Jan 5, 2020
b37a749
add abstract method and Categorical class
xidulu Jan 5, 2020
791916c
rewrite logit2prob prob2logit for multiclass support
xidulu Jan 6, 2020
d8b4228
Merge remote-tracking branch 'upstream/master' into distribution_dev
xidulu Jan 6, 2020
7b1e7c7
normal broadcast_to implemented
xidulu Jan 6, 2020
f24a450
categorical mostly done
xidulu Jan 7, 2020
2369f35
update distribution
Jan 7, 2020
94de579
update distributions/utils.py
Jan 7, 2020
cd578a9
Merge remote-tracking branch 'origin/distribution_dev' into distribut…
xidulu Jan 8, 2020
1ff9d05
add dot ahead of import
xidulu Jan 8, 2020
255aaef
update from upstream
xidulu Jan 13, 2020
2709680
fix normal F
xidulu Jan 13, 2020
7f8b91f
Update from upstream
xidulu Jan 14, 2020
72dfbde
bernoulli, normal brief tests implemented
xidulu Jan 15, 2020
43e5076
add hybridize tests
xidulu Jan 15, 2020
5f862cc
transformation infras done
xidulu Jan 18, 2020
f6c0446
affine transformation, implemented tested
xidulu Jan 20, 2020
e307494
add tests cases
xidulu Jan 20, 2020
74a8c62
add sum_right_most
xidulu Jan 21, 2020
a4cd45e
fix get F bug
xidulu Jan 22, 2020
a863a87
compose transform implemented, tested
xidulu Jan 23, 2020
d214b2e
fix
Jan 23, 2020
3d5cfe3
add event_dim
xidulu Jan 24, 2020
a7c2172
update from upstream
xidulu Jan 24, 2020
87eb50f
fetch mvn from upstremm
xidulu Jan 24, 2020
380ed79
clean code, implement normal cdf and tests
xidulu Jan 24, 2020
d535af3
constraint in bernoulli done
xidulu Jan 27, 2020
a97937b
fix constraint
xidulu Jan 28, 2020
c2799df
finish half normal
xidulu Jan 28, 2020
6c4b84a
add cached_property
Jan 29, 2020
8fa160c
add test on cached_property
Jan 29, 2020
2c6dd02
add more features to distribution and constratins
Jan 30, 2020
bd6277f
change constraint
Feb 6, 2020
f506014
fix bernoulli
Feb 6, 2020
4dedd8f
add independent
Feb 11, 2020
e8d3ebd
add independent tests
xidulu Feb 12, 2020
4ef62db
update naming of cached_property
Feb 12, 2020
9a96d74
revert
Feb 12, 2020
efae056
add constraints
Feb 12, 2020
8a2d7d1
add Cat
Feb 13, 2020
44b7e14
add Stack for imperative mode
Feb 13, 2020
1462629
add Stack for imperative mode
Feb 13, 2020
fcd32a1
add bernoulli entropy
xidulu Feb 13, 2020
fb57a57
Merge remote-tracking branch 'upstream/master' into distribution_dev
xidulu Feb 15, 2020
c0b33ae
categorical WIP
xidulu Feb 17, 2020
cc7aa0b
categorical sampling implemented
xidulu Feb 18, 2020
8a67be6
finish categorical log_prob, sampling
xidulu Feb 22, 2020
16f5d07
enumerate_support finished
xidulu Feb 23, 2020
c9cde78
polish StochasticBlock, add test
xidulu Feb 23, 2020
cb8362e
add test for stochastic sequential
xidulu Feb 23, 2020
f0f8f3a
clean loss list in __call__
xidulu Feb 24, 2020
730d0d6
fix affine, implement sigmoid, softmax
xidulu Feb 25, 2020
1b8354a
add gumbel, relaxed bernoulli
xidulu Feb 26, 2020
9567a35
relaxed one-hot sampling implemented
xidulu Feb 27, 2020
105936b
Merge remote-tracking branch 'upstream/master' into distribution_dev
xidulu Feb 29, 2020
13714e4
gamma done
xidulu Feb 29, 2020
61dba33
gamma, dirichlet implemented
xidulu Mar 1, 2020
816b01d
beta done
xidulu Mar 2, 2020
cf183ee
gumbel softmax log-likelihood implemented
xidulu Mar 3, 2020
ebc3099
Merge remote-tracking branch 'upstream/master' into distribution_dev
xidulu Mar 4, 2020
2453489
refactor tests, implement exponential, fix compose transform
xidulu Mar 4, 2020
8a3a7d7
weibull implemented, transformed distribution cdf icdf added
xidulu Mar 5, 2020
35d4fee
pareto implemented
xidulu Mar 6, 2020
0c4a282
uniform wip
xidulu Mar 8, 2020
8e659a4
uniform done
xidulu Mar 8, 2020
b8d654e
rewrite lgamma, implement chi2
xidulu Mar 8, 2020
d4997d8
fix chi2 scale
xidulu Mar 9, 2020
e85dec3
F distributiion done
xidulu Mar 9, 2020
280869c
t implemented
xidulu Mar 11, 2020
16dca53
fix tiny problem
xidulu Mar 13, 2020
00b76ec
cauchy done
xidulu Mar 14, 2020
7f1da0f
add half cauchy
xidulu Mar 16, 2020
3d91201
multinomial done, tests to be added
xidulu Mar 16, 2020
ab9e812
add multinomial test
xidulu Mar 17, 2020
aa4b9ca
laplace done
xidulu Mar 17, 2020
33089fb
MVN done, tests todo
xidulu Mar 20, 2020
9141f20
mvn polished
xidulu Mar 21, 2020
6c37057
fix a few precison issues
xidulu Mar 23, 2020
f8de23d
add erf, erfinv unified api and learnable transform
xidulu Mar 24, 2020
fff2a49
fix mvn attribute check
xidulu Mar 27, 2020
2c805ea
MVN done
xidulu Mar 27, 2020
dd5119c
poisson done
xidulu Mar 30, 2020
ec31a4e
hack poisson for size support
xidulu Mar 31, 2020
dfcbbca
geometric finished
xidulu Apr 5, 2020
14531a1
negative binomial done
xidulu Apr 5, 2020
156b526
binomial done
xidulu Apr 7, 2020
f13f6cb
implement some kl
xidulu Apr 13, 2020
31d0f6d
add more kl
xidulu Apr 14, 2020
35e567f
refactor kl test
xidulu Apr 20, 2020
a4d0f42
add more kl
xidulu Apr 21, 2020
3b839dd
binomial kl todo
xidulu Apr 23, 2020
304d7cd
update from upstream
xidulu Apr 26, 2020
76efc48
change constraint logical op implement
xidulu Apr 26, 2020
8d2efe5
implement gamma entropy
xidulu Apr 26, 2020
c330eab
finish beta dirchlet entropy
xidulu Apr 27, 2020
32c2184
finishi all entropy
xidulu Apr 27, 2020
f13fd19
kl finished
xidulu Apr 28, 2020
907462e
add constraint test
xidulu Apr 30, 2020
6f47521
domain map done
xidulu May 4, 2020
67d0158
remove bayesian dense
xidulu May 5, 2020
bdb732f
fix tiny problems
xidulu May 5, 2020
9a96d54
add kl uniform normal
xidulu May 5, 2020
9966c1b
add kl tests
xidulu May 6, 2020
b7fa0a6
acquire patch from upstream
xidulu May 6, 2020
cdfd0d5
add some doc
xidulu May 6, 2020
369d223
finish doc
xidulu May 7, 2020
367aa7c
refactor kl test(WIP)
xidulu May 10, 2020
5f1be5b
add more kl, fix float32 underflow issue
xidulu May 12, 2020
648d639
make sampling more stable
xidulu May 13, 2020
3bf65cd
handle inconsistent mode
xidulu May 14, 2020
add173d
replace boolean idx with np.where
xidulu May 17, 2020
bc7f856
fix file name
xidulu May 19, 2020
7982d29
add more doc
xidulu May 19, 2020
9effea9
add constraint check
xidulu May 25, 2020
d5a4449
add half_normal/cauchy pdf cdf support check
xidulu May 25, 2020
5c00c06
Merge remote-tracking branch 'upstream/master' into distribution_dev
xidulu May 26, 2020
9152e0d
fix import problem
xidulu May 26, 2020
3b15cf2
change nosetest to pytest
xidulu May 27, 2020
f6930b5
remove buggy lines
xidulu May 28, 2020
6c45745
change alias register path
xidulu May 30, 2020
c1eb3e6
attempt to fix ci
xidulu Jun 1, 2020
a390a56
fix lint, change a few tests
xidulu Jun 4, 2020
d7e33ca
fix lint
xidulu Jun 8, 2020
25cc066
modify hybrid sequential
xidulu Jun 9, 2020
c0fb8ce
fix lint
xidulu Jun 9, 2020
68cb6da
change import order
xidulu Jun 9, 2020
fef4c8e
add test gluon probability v2
xidulu Jun 15, 2020
e85e67d
fix hybridize flag
xidulu Jun 15, 2020
d356a38
change implementation of stochastic block
xidulu Jun 16, 2020
894a773
fix lint
xidulu Jun 16, 2020
d74e380
fix comments
xidulu Jun 23, 2020
fd4a418
Merge remote-tracking branch 'upstream/master' into distribution_dev
xidulu Jun 23, 2020
4ad89e5
fix block
xidulu Jun 23, 2020
c8f121b
modify domain map
xidulu Jun 24, 2020
da86be9
add raises for improper add_loss
xidulu Jun 30, 2020
88b6adf
add raises for improper add_loss
xidulu Jun 30, 2020
3fab70c
add extra cases
xidulu Jun 30, 2020
0291ed4
change collectLoss decorator to mandatory
xidulu Jul 1, 2020
fa4a06c
skip stochastic block tests
xidulu Jul 2, 2020
53089dc
Merge remote-tracking branch 'upstream/master' into distribution_dev
xidulu Jul 2, 2020
f77b1c6
remove test cases
xidulu Jul 2, 2020
c50bad0
put gpu tests back
xidulu Jul 3, 2020
6bdaf9f
add test_gluon_stochastic_block back
xidulu Jul 3, 2020
9689c59
remove export test
xidulu Jul 3, 2020
5007a04
put a test back
xidulu Jul 3, 2020
c018fe2
tiny refactor
xidulu Jul 4, 2020
974a9a7
add memory leak flag
xidulu Jul 5, 2020
9b1564c
small changes
xidulu Jul 5, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/mxnet/gluon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,5 @@
from . import model_zoo

from . import contrib

from . import probability
26 changes: 26 additions & 0 deletions python/mxnet/gluon/probability/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import
"""Probability module"""

from .block import *

from .distributions import *

from .transformation import *
22 changes: 22 additions & 0 deletions python/mxnet/gluon/probability/block/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import
"""Stochastic block."""

from .stochastic_block import *
127 changes: 127 additions & 0 deletions python/mxnet/gluon/probability/block/stochastic_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=abstract-method
"""Stochastic block class."""
__all__ = ['StochasticBlock', 'StochasticSequential']

from functools import wraps
from ...block import HybridBlock
from ...utils import _indent


class StochasticBlock(HybridBlock):
"""`StochasticBlock` extends `HybridBlock` to support accumulating loss
in the forward phase, which is extremely useful in building Bayesian Neural Network,
where the loss function is composed of a classification loss and a KL loss.

"""

def __init__(self, **kwargs):
super(StochasticBlock, self).__init__(**kwargs)
self._losses = []
self._losscache = []

def add_loss(self, loss):
self._losscache.append(loss)

@staticmethod
def collectLoss(func):
"""To accumulate loss during the forward phase, one could first decorate
hybrid_forward with `StochasticBlock.collectLoss,
and then collect the loss tensor `x` by calling self.add_loss(x).
For example, in the following forward function,
we generate samples from a Gaussian parameterized by `loc` and `scale` and
accumulate the KL-divergence between it and its prior into the block's loss storage.:
@StochasticBlock.collectLoss
def hybrid_forward(self, F, loc, scale):
qz = mgp.Normal(loc, scale)
# prior
pz = mgp.Normal(F.np.zeros_like(loc), F.np.ones_like(scale))
self.add_loss(mgp.kl_divergence(qz, pz))
return qz.sample()
"""
@wraps(func)
def inner(self, *args, **kwargs):
# Loss from hybrid_forward
func_out = func(self, *args, **kwargs)
collected_loss = self._losscache
self._losscache = []
return (func_out, collected_loss)

return inner

def __call__(self, *args, **kwargs):
# pylint: disable=arguments-differ
out = super().__call__(*args, **kwargs)
self._losses.extend(out[1])
return out[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to assume that @StochasticBlock.collectLoss is used by the user for his forward function? If so: currently it's not mandatory to use @StochasticBlock.collectLoss. If it's not mandatory, you need to also handle the case where users don't specify @StochasticBlock.collectLoss decorator. Could you help to clarify?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@StochasticBlock.collectLoss is only needed when the users want to compute per-block loss such as KL divergence. If such loss is not demanded, the users can simply ignore this decorator or use default HybridBlock.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for clarifying @szhengac. As per the current implementation, if user ignore this decorator and don't add it when they don't need it, an error will be raised as the __call__ function always expects two outputs. If the error is correct and expected (ie. adding @StochasticBlock.collectLoss is mandatory), why not remove the decorator for simplicity and handle it's implementation by default (in the __call__)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it is not needed, the out[1] will be an empty list, so there is no error incurred here. One reason that we use this decorator is that we also need to use it in StochasticSequential for collecting losses over multiple blocks. And, as self.add_loss(kl) is called inside forward or hybrid_forward, I think we cannot move the implementation details to __call__.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out[1] will only be an empty list if the user specifies @StochasticBlock.collectLoss. If they don't specify the decorator, out[1] will not exist and there'd be an error. We should have a good error message or at least document that users must use @StochasticBlock.collectLoss

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes, you are right. We need to think about it. Thanks for pointing this out.

Copy link
Contributor Author

@xidulu xidulu Jun 30, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@leezu
Thanks for pointing this out, an error should be raised when users is attempting to call self.add_loss in a function not decorated by collectLoss. I will look into this issue.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @xidulu. Also note that this issue is present when users don't call collectLoss. That's because the call method here is always called when users use StochasticBlock

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@leezu

I add two checks here: https://github.com/apache/incubator-mxnet/pull/18403/files#diff-85458cf5116b137da8148bf5b38bcfaeR74
https://github.com/apache/incubator-mxnet/pull/18403/files#diff-85458cf5116b137da8148bf5b38bcfaeR78

To make it clearer, I list several possible situations:

  1. Users call add_loss inside functions decorated by CollectLoss, add_loss appends losses into _losscache, _losscache would then get cleared in CollectLoss, len(_losscache) becomes 0 when call is invoked.

  2. Users call add_loss without using CollectLoss, add_loss appends losses into _losscache, _losscache still contains value when entering call, in this case, a exception will be raised.

  3. Users use CollectLoss without calling add_loss, self._losses = out[1] = []

  4. Users use StochasticBlock without calling CollectLoss or add_loss, len(out) == 1, out[1] will not be accessed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@leezu
Update: I made further changes here to avoid confusion. Now the users are forced to use to collectLoss decorator in all cases, otherwise an exception would be raised.


@property
def losses(self):
return self._losses


class StochasticSequential(StochasticBlock):
"""Stack StochasticBlock sequentially.
"""

def __init__(self, **kwargs):
super(StochasticSequential, self).__init__(**kwargs)
self._layers = []

def add(self, *blocks):
"""Adds block on top of the stack."""
for block in blocks:
self._layers.append(block)
self.register_child(block)

@StochasticBlock.collectLoss
def forward(self, x, *args):
# pylint: disable=arguments-differ
for block in self._children.values():
x = block()(x, *args)
args = []
if isinstance(x, (tuple, list)):
args = x[1:]
x = x[0]
if args:
x = tuple([x] + list(args))
for block in self._layers:
if hasattr(block, '_losses'):
self.add_loss(block._losses)
return x

def __repr__(self):
s = '{name}(\n{modstr}\n)'
modstr = '\n'.join([' ({key}): {block}'.format(key=key,
block=_indent(block().__repr__(), 2))
for key, block in self._children.items()])
return s.format(name=self.__class__.__name__, modstr=modstr)

def __getitem__(self, key):
layers = list(self._children.values())[key]
xidulu marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(layers, list):
net = type(self)()
net.add(*(l() for l in layers))
return net
else:
return layers()

def __len__(self):
return len(self._children)
86 changes: 86 additions & 0 deletions python/mxnet/gluon/probability/distributions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import
"""Distribution classes."""

from .distribution import *

from .exp_family import *

from .exponential import *

from .weibull import *

from .pareto import *

from .uniform import *

from .normal import *

from .laplace import *

from .cauchy import *

from .half_cauchy import *

from .poisson import *

from .geometric import *

from .negative_binomial import *

from .gamma import *

from .dirichlet import *

from .beta import *

from .chi2 import *

from .fishersnedecor import *

from .studentT import *

from .half_normal import *

from .independent import *

from .bernoulli import *

from .binomial import *

from .relaxed_bernoulli import *

from .gumbel import *

from .categorical import *

from .one_hot_categorical import *

from .relaxed_one_hot_categorical import *

from .multinomial import *

from .multivariate_normal import *

from .transformed_distribution import *

from .divergence import *

from .utils import *
Loading