update distributions #205

Open · wants to merge 53 commits into base: master

Commits (53)
75dd5a0  Merge pull request #1 from Jittor/master (Exusial, Feb 4, 2021)
2db7b94  fix bicubic,add fold. (csaucl4, Feb 8, 2021)
40a915d  add eye. (csaucl4, Feb 11, 2021)
fa2f911  add test for qr,bicubic,fold,unfold. (csaucl4, Feb 13, 2021)
f015999  fix bicubic and fold to code_op ver.add grad test. (csaucl4, Feb 23, 2021)
6f2c22f  Merge branch 'master' into master (Exusial, Feb 23, 2021)
cc7c0bb  add docs.update pinv to support (..,M,N) shape (csaucl4, Feb 23, 2021)
bfe54df  Merge branch 'master' into master (Exusial, Feb 26, 2021)
9c722a9  edit maintainer and testfunc's name. (csaucl4, Feb 26, 2021)
cb705f7  fix nn (Gword, Mar 2, 2021)
e5d6262  fix nn (Gword, Mar 2, 2021)
ce24132  fix conflict. (csaucl4, Mar 18, 2021)
df14d21  Merge branch 'ej' (csaucl4, Mar 18, 2021)
a15cb63  add sampler and subset. (csaucl4, Mar 18, 2021)
5c6c160  fix subset,change jittor op to np op. (csaucl4, Mar 19, 2021)
37d723a  fix. (csaucl4, Mar 20, 2021)
9452f8b  fix space. (csaucl4, Mar 20, 2021)
2bc9ee1  fix?. (csaucl4, Mar 20, 2021)
9860581  copy. (csaucl4, Mar 20, 2021)
9711e85  copy. (csaucl4, Mar 20, 2021)
1d76ae2  delete .idea (csaucl4, Mar 22, 2021)
fa62b30  add sampler hook in dataset (cjld, Mar 23, 2021)
74a93e7  add doc (cjld, Mar 23, 2021)
4471771  Merge branch 'master' of https://github.com/Jittor/jittor into add_sa… (cjld, Mar 23, 2021)
a20a1bd  add sampler (cjld, Mar 23, 2021)
0b9c927  add onednn install+support.pass mkl test. (csaucl4, Mar 25, 2021)
d41b718  Merge branch 'master' of https://github.com/Exusial/jittor (csaucl4, Mar 25, 2021)
b9acc79  merge. (csaucl4, Mar 25, 2021)
3c36d9c  Merge branch 'Jittor-master' (csaucl4, Mar 25, 2021)
b001f87  add opencl. (csaucl4, Apr 16, 2021)
bebd3d4  Merge pull request #5 from Jittor/master (Exusial, Apr 16, 2021)
ce0829b  Merge pull request #6 from Jittor/master (Exusial, May 7, 2021)
e390d1d  update extern,dist. (Exusial, May 7, 2021)
2a483cf  update extern,dist. (Exusial, May 7, 2021)
d4f83b0  fix. (Exusial, May 7, 2021)
51a57ab  fix. (Exusial, May 7, 2021)
db6a585  delete useless. (Exusial, May 7, 2021)
9326396  add test. (Exusial, May 7, 2021)
5f3069a  merge. (Exusial, May 7, 2021)
3d2ff97  Merge branch 'Jittor-master' (Exusial, May 7, 2021)
c634741  merge. (Exusial, May 7, 2021)
62e6f81  fix. (Exusial, May 8, 2021)
1b87e73  add finfo. (Exusial, May 10, 2021)
1ef727e  add gemotric (Exusial, May 10, 2021)
57c18e7  fix onehot. (Exusial, May 11, 2021)
94868a1  fix. (Exusial, May 12, 2021)
c7e0e44  fix. (Exusial, May 12, 2021)
ac8c855  merge. (Exusial, May 12, 2021)
934dbbb  Merge branch 'Jittor-master' (Exusial, May 12, 2021)
6ff9597  fix test. (Exusial, May 12, 2021)
3feaea4  update. (Exusial, May 28, 2021)
1656d74  update. (Exusial, May 28, 2021)
57a9fed  Merge branch 'Jd' (Exusial, May 28, 2021)
python/jittor/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1270,4 +1270,4 @@ def get_len(var):
from . import numpy2cupy
from .contrib import concat
from .misc import *
from . import sparse
(the removed and re-added last line differ only by a trailing newline)
python/jittor/distributions.py (101 changes: 97 additions & 4 deletions)
@@ -27,6 +27,18 @@ def simple_presum(x):
        cpu_src=src, cuda_src=src)


def lgamma(x):
    # Element-wise log-gamma via a CPU code op. Note the loop only walks
    # the first dimension, so a 1-D input is assumed.
    header = '''#include <cmath>'''
    src = '''
    @alias(a, in0)
    @alias(b, out0)
    for (int i = 0; i < a_shape0; i++) {
        @b(i) = lgamma(@a(i));
    }
    '''
    return jt.code([x.shape], [x.dtype], [x], cpu_header=header, cpu_src=src)[0]
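
A quick numeric spot check for the lgamma op, assuming it is exercised the way the test file below does (CPU, 1-D float input); Python's math.lgamma supplies the reference values:

    import math
    import jittor as jt
    from jittor import distributions as jd

    x = jt.array([0.5, 2.0, 5.0]).float32()
    print(jd.lgamma(x).data)                          # ~[0.5724, 0.0, 3.1781]
    print([math.lgamma(v) for v in (0.5, 2.0, 5.0)])  # same values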


class OneHotCategorical:
    def __init__(self, probs=None, logits=None):
        assert not (probs is None and logits is None)
@@ -99,7 +111,7 @@ def __init__(self, mu, sigma):
        self.mu = mu
        self.sigma = sigma

    def sample(self, sample_shape=None):
        return jt.normal(jt.array(self.mu), jt.array(self.sigma), size=sample_shape)

    def log_prob(self, x):
@@ -109,6 +121,85 @@ def log_prob(self, x):

    def entropy(self):
        return 0.5 + 0.5 * np.log(2 * np.pi) + jt.log(self.sigma)

    def cdf(self, x):
        # standard normal CDF: 0.5 * (1 + erf((x - mu) / (sigma * sqrt(2))))
        return 0.5 * (1 + jt.erf((x - self.mu) / (self.sigma * math.sqrt(2))))
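
A minimal sketch of the Normal helpers (assuming scalar mu/sigma pass through jt.log as they do elsewhere in this file):

    n = jd.Normal(0., 1.)
    print(n.entropy())            # 0.5 + 0.5*log(2*pi) + log(1) ~= 1.4189
    print(n.cdf(jt.array([0.])))  # ~= 0.5 by symmetry of the standard normal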


class Uniform:
    def __init__(self, low, high):
        self.low = low
        self.high = high
        assert high > low

    def sample(self, sample_shape):
        return jt.uniform(self.low, self.high, sample_shape)

    def log_prob(self, x):
        if x < self.low or x >= self.high:
            # zero density outside the support, so the log-density is -inf
            return -math.inf
        return -jt.log(self.high - self.low)

    def entropy(self):
        return jt.log(self.high - self.low)

    def cdf(self, x):
        return jt.clamp((x - self.low) / (self.high - self.low), min_v=0, max_v=1)
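
For example (a sketch with scalar bounds, as the assert above implies):

    u = jd.Uniform(0., 2.)
    print(u.entropy())            # log(2 - 0) ~= 0.6931
    print(u.cdf(jt.array([1.])))  # 0.5: halfway through [0, 2)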


class Geometric:
    def __init__(self, p=None, logits=None):
        assert (p is not None) or (logits is not None)
        if p is None:
            self.prob = jt.sigmoid(logits)
            self.logits = logits
        else:
            assert 0 < p and p < 1
            self.prob = p
            self.logits = -jt.log(1. / p - 1)

    def sample(self, sample_shape):
        # inverse-transform sampling: floor(log(u) / log(1 - p)) ~ Geometric(p);
        # u is clamped away from 0 so the log stays finite
        tiny = jt.info(self.prob.dtype).tiny
        u = jt.clamp(jt.rand(sample_shape), min_v=tiny)
        return (jt.log(u) / jt.log(1 - self.prob)).floor()

    def log_prob(self, x):
        # pmf(k) = (1 - p)^k * p for k = 0, 1, 2, ...
        return x * jt.log(1 - self.prob) + jt.log(self.prob)

    def entropy(self):
        return binary_cross_entropy_with_logits(jt.array(self.logits), jt.array(self.prob)) / self.prob
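
A spot check of log_prob against the closed-form pmf p*(1-p)^k, which counts failures before the first success (the same convention as torch.distributions.Geometric), assuming jt.log accepts the Python scalars used here:

    g = jd.Geometric(0.5)
    print(g.log_prob(jt.array([2.])))  # log(0.5 * 0.5^2) = 3*log(0.5) ~= -2.0794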


def Poisson_sample(la, size):
    # Sequential-search inversion (Knuth): accumulate the Poisson CDF via
    # the pmf recurrence P(k+1) = la/(k+1) * P(k) until it exceeds u.
    # numpy is used so each element can be indexed inside the loop.
    u = np.random.uniform(size=size).flatten()
    res = np.zeros(u.shape)
    for i in range(len(u)):
        k = 0
        p = math.exp(-la)
        s = p
        while u[i] > s:
            p = la * p / (k + 1)
            s = s + p
            k += 1
        res[i] = k
    return jt.array(res.reshape(size))
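
The loop is sequential search: since the Poisson pmf obeys P(k+1) = la/(k+1) * P(k), the CDF can be accumulated term by term until it passes the uniform draw. A rough check, since the mean of Poisson(la) is la:

    s = jd.Poisson_sample(4.0, (10000,))
    print(s.mean())  # ~= 4.0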


class Poisson:
    def __init__(self, la):
        self.la = la

    def sample(self, sample_shape):
        return Poisson_sample(self.la, sample_shape)

    def log_prob(self, x):
        # log pmf: x*log(la) - la - log(x!), with log(x!) = lgamma(x + 1)
        return x * jt.log(self.la) - self.la - lgamma(x + 1)
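
Spot check at x = 0, where the pmf is exp(-la) and the log-gamma term vanishes (lgamma(1) = 0):

    p = jd.Poisson(1.0)
    print(p.log_prob(jt.array([0.])))  # = [-1.], since pmf(0) = exp(-1)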


class Uniform:
@@ -158,15 +249,17 @@ def kl_divergence(cur_dist, old_dist):
        vr = (cur_dist.sigma / old_dist.sigma)**2
        t1 = ((cur_dist.mu - old_dist.mu) / old_dist.sigma)**2
        return 0.5 * (vr + t1 - 1 - jt.log(vr))
    if isinstance(cur_dist, Categorical) or isinstance(cur_dist, OneHotCategorical):
        t = cur_dist.probs * (cur_dist.logits - old_dist.logits)
        t[jt.array((old_dist.probs == 0))] = math.inf
        t[jt.array((cur_dist.probs == 0))] = 0
        return t.sum(-1)
    if isinstance(cur_dist, Uniform):
        res = jt.log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
        if old_dist.low > cur_dist.low or old_dist.high < cur_dist.high:
            res = math.inf
        return res
    if isinstance(cur_dist, Geometric):
        return -cur_dist.entropy() - jt.log(1 - old_dist.prob) / cur_dist.prob - old_dist.logits
    if isinstance(cur_dist, Poisson):
        return cur_dist.la * (jt.log(cur_dist.la) - jt.log(old_dist.la)) - (cur_dist.la - old_dist.la)
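
A worked check of the Normal branch: for N(0,1) against N(1,1), vr = 1 and t1 = 1, so the divergence is 0.5*(1 + 1 - 1 - log 1) = 0.5 (again assuming scalar parameters pass through jt.log):

    a, b = jd.Normal(0., 1.), jd.Normal(1., 1.)
    print(jd.kl_divergence(a, b))  # 0.5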
python/jittor/test/test_distributions.py (14 changes: 9 additions & 5 deletions)
@@ -18,6 +18,12 @@ def test_presum(self):
        a = jt.array([[1,2,3,4]])
        b = jd.simple_presum(a)
        assert (b.data == [[0,1,3,6,10]]).all()

    def test_lgamma(self):
        import torch
        ta = np.random.uniform(2, 3, (1))
        a = jt.array(ta).float32()
        assert np.allclose(jd.lgamma(a).data, torch.lgamma(torch.tensor(ta)).numpy()), (jd.lgamma(a).data, torch.lgamma(torch.tensor(ta)).numpy())

    def test_one_hot(self):
        a = jd.OneHotCategorical(jt.array([0.25, 0.25, 0.25, 0.25]))
@@ -31,9 +37,9 @@ def test_one_hot(self):
        probs, probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
        probs, probs2 = probs / probs.sum(), probs2 / probs2.sum()
        import torch
-       tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs).to(torch.float32)), torch.distributions.OneHotCategorical(torch.tensor(probs2).to(torch.float32))
        jc, jc2 = jd.OneHotCategorical(jt.array(probs).reshape(1,-1)), jd.OneHotCategorical(jt.array(probs2).reshape(1,-1))
+       tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)), torch.distributions.OneHotCategorical(torch.tensor(probs2))
-       assert np.allclose(jc.entropy().data, tc.entropy().numpy())
+       assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
        x = np.zeros((4,10))
        for _ in range(4):
            nx = np.random.randint(0,9)
@@ -72,7 +78,6 @@ def test_categorical(self):
        for _ in range(4):
            probs, probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
            probs, probs2 = probs / probs.sum(), probs2 / probs2.sum()
-           jc, jc2 = jd.Categorical(jt.array(probs).reshape(1,-1)), jd.Categorical(jt.array(probs2).reshape(1,-1))
            tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)), torch.distributions.Categorical(torch.tensor(probs2))
            assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
            x = np.random.randint(0,10,(4))
@@ -101,8 +106,7 @@ def test_geometric(self):
        assert np.allclose(jg.entropy().data, tg.entropy().numpy())
        x = np.random.randint(1,10)
        assert np.allclose(jg.log_prob(x), tg.log_prob(torch.tensor(x)))
-       # print(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2))
        assert np.allclose(jd.kl_divergence(jg,jg2), torch.distributions.kl_divergence(tg,tg2))

if __name__ == "__main__":
    unittest.main()