Add Student's t RandomVariable
#1211
Conversation
Codecov Report
Additional details and impacted files

```
@@            Coverage Diff            @@
##             main    #1211      +/-   ##
==========================================
  Coverage   74.10%   74.11%
==========================================
  Files         174      174
  Lines       48624    48636      +12
  Branches    10351    10351
==========================================
+ Hits        36035    36047      +12
  Misses      10301    10301
  Partials     2288     2288
```
aesara/tensor/random/basic.py
Outdated
```diff
@@ -431,6 +431,45 @@ def rng_fn_scipy(cls, rng, shape, scale, size):
 gamma = GammaRV()


+class StandardGammaRV(GammaRV):
```
If we only intend to add `standard_gamma` for NumPy interface compatibility, then the `GammaRV` `Op` alone should suffice, no? For instance, `standard_gamma` could be an `Op` constructor function that effectively removes the shape and rate arguments from `GammaRV`.

As always, we want to avoid adding new `Op`s whenever we reasonably can.

N.B. This is an example of the concern stated in aesara-devs/aemcmc#67 (comment).
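For instance, a hypothetical sketch of that constructor-function approach; the `gamma(shape, rate, size=None)` call signature is assumed here, not taken from this diff:

```python
from aesara.tensor.random.basic import gamma


def standard_gamma(shape_param, size=None, **kwargs):
    # NumPy's standard_gamma(k) draws Gamma(k, scale=1), i.e. rate=1;
    # reuse the existing GammaRV Op instead of adding a new one.
    return gamma(shape_param, 1.0, size=size, **kwargs)
```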
I think the issue is that you can't use `RandomStream` if you don't have a specific `Op`. That's why `standard_normal` was made a subclass as well, IIRC.
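For context, a minimal `RandomStream` usage sketch (assuming `RandomStream` comes from `aesara.tensor.random.utils`):

```python
from aesara.tensor.random.utils import RandomStream

srng = RandomStream(seed=2023)
# Each attribute access must resolve to a concrete RandomVariable Op:
x = srng.normal(0.0, 1.0)
z = srng.standard_normal()  # only works because a dedicated Op exists
```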
> If we only intend to add `standard_gamma` for NumPy interface compatibility, then the `GammaRV` `Op` alone should suffice, no? For instance, `standard_gamma` could be an `Op` constructor function that effectively removes the shape and rate arguments from `GammaRV`.
>
> As always, we want to avoid adding new `Op`s whenever we reasonably can.
>
> N.B. This is an example of the concern stated in aesara-devs/aemcmc#67 (comment).
I agree with all this.
@ricardoV94 we can monkey patch the base classes. I'm not a big fan of this approach, but it should work:
```python
from aesara.tensor.random.basic import NormalRV, normal


def create_standard_normal():
    def standard_call(self, size=None, **kwargs):
        # Delegate to the original __call__ with loc=0 and scale=1.
        return self.general_call(0.0, 1.0, size, **kwargs)

    RV = NormalRV
    # Patch the class: keep the original __call__ around, then replace it.
    RV.general_call = RV.__call__
    RV.__call__ = standard_call
    return RV()


standard_normal = create_standard_normal()

print(type(normal))
# <class 'aesara.tensor.random.basic.NormalRV'>
print(type(standard_normal))
# <class 'aesara.tensor.random.basic.NormalRV'>
```
Does it work with RandomStream like that?
I think it will; `RandomStream` checks that the attribute is both in `aesara.tensor.random.basic` and is an instance of `RandomVariable`. We'll soon know for sure.
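A rough sketch of that lookup, for illustration only (not `RandomStream`'s actual code):

```python
import aesara.tensor.random.basic as basic
from aesara.tensor.random.op import RandomVariable


def resolve(name):
    # Accept `name` only if it is defined in aesara.tensor.random.basic
    # and the object is a RandomVariable instance.
    obj = getattr(basic, name, None)
    if isinstance(obj, RandomVariable):
        return obj
    raise AttributeError(f"no RandomVariable named {name!r}")
```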
This quickly gets complicated: monkey patching affects the class globally, so I would need to replace the method on the instance directly.
By the way, I've noticed that NumPy's `Generator` has a method for every RV.
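For example:

```python
import numpy as np

rng = np.random.default_rng(0)

# Generator exposes a dedicated method per distribution, including the
# "standard" variants discussed here:
rng.standard_normal()
rng.standard_t(df=3.0)
rng.standard_gamma(shape=2.0)
```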
Maybe we can add a canonical rewrite that replaces the standard versions immediately?
Or can we tweak RandomStream perhaps?
> Maybe we can add a canonical rewrite that replaces the standard versions immediately?
That would certainly work. But I'd rather not rely on rewrites to "fix" a problem in the representation of objects in the IR, which is also my concern in #1213. It would be better if they were represented by the same type in the original graph.
Adding/replacing most methods on class instances works fine, for instance:

```python
from aesara.tensor.random.basic import StudentTRV


def create_standard_t():
    C = StudentTRV()

    def new_call(self, df, size=None, **kwargs):
        # Forward to the full StudentT signature with loc=0 and scale=1.
        return self.__call__(df, 0, 1, size, **kwargs)

    # Bind the function to this one instance only.
    C.call = new_call.__get__(C)
    return C


standard_t = create_standard_t()

print(standard_t.call(2.0).owner.inputs)
# [RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F618591A420>), TensorConstant{[]}, TensorConstant{11}, TensorConstant{2.0}, TensorConstant{0}, TensorConstant{1}]
```
However, special methods like `__call__` are looked up on the class of the object, not on its instance. So if we monkey patch the method, it will end up affecting all instances of the class, and thus instances of the non-standard `RandomVariable`. So I think that's a dead end.
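A minimal illustration of that lookup rule, in plain Python:

```python
class RV:
    def __call__(self):
        return "class __call__"


rv = RV()
rv.__call__ = lambda: "instance __call__"

print(rv.__call__())  # "instance __call__": explicit attribute access finds it
print(rv())           # "class __call__": the call operator looks it up on the type
```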
Yeah, I think we should be able to update `RandomStream` to handle these cases without too much difficulty and/or compromise.
Commits ad83911 to 252a4b5
Removed the commits related to `standard_gamma`.
NumPy only defines `standard_t` (its naming is inconsistent with the rest of the location-scale library), so I defined it as a `ScipyRandomVariable` instead. I followed what was done for the other members of the location-scale family; a sketch of that pattern appears after the checklist below.

Here are a few important guidelines and requirements to check before your PR can be merged:

- `pre-commit` is installed and set up.

Ticks one off of #1093
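For reference, a sketch of the pattern mentioned above, modeled on the other location-scale `ScipyRandomVariable`s in `aesara/tensor/random/basic.py` (attribute values and signatures follow that pattern and are not copied from the merged diff):

```python
import scipy.stats as stats

from aesara.tensor.random.basic import ScipyRandomVariable


class StudentTRV(ScipyRandomVariable):
    name = "t"
    ndim_supp = 0
    ndims_params = [0, 0, 0]
    dtype = "floatX"
    _print_name = ("StudentT", "\\operatorname{StudentT}")

    def __call__(self, df, loc=0.0, scale=1.0, size=None, **kwargs):
        return super().__call__(df, loc, scale, size=size, **kwargs)

    @classmethod
    def rng_fn_scipy(cls, rng, df, loc, scale, size):
        # Defer to SciPy for sampling, as the other ScipyRandomVariables do.
        return stats.t.rvs(df, loc=loc, scale=scale, size=size, random_state=rng)


t = StudentTRV()
```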