-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathget_data_synthetic.py
75 lines (54 loc) · 2.09 KB
/
get_data_synthetic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import torch.nn as nn
import torch.distributions as tdist
import matplotlib.pyplot as plt
import math
import sys
import pickle
from dro_mev_functions.DRO_MEV_nn import *
from dro_mev_functions.DRO_MEV_train import *
from dro_mev_functions.DRO_MEV_util import *
if __name__ == '__main__':
    # Generate a synthetic two-component copula mixture and pickle it.
    #
    # Usage: python get_data_synthetic.py {asl|sl}
    #   asl -> asymmetric-logistic mixture, written to mixture_asl_asl.p
    #   sl  -> symmetric-logistic mixture, written to mixture_sl_sl.p
    # As a side effect, a scatter plot of an stdfNSD instance's `M` attribute
    # is saved to nsd.pdf before the mixture is generated.
    if len(sys.argv) < 2 or sys.argv[1] not in ('asl', 'sl'):
        # Previously any value other than 'asl' silently fell through to the
        # symmetric-logistic branch; reject unknown values explicitly.
        raise ValueError("Must input either: 'asl' or 'sl' ")
    data_gen_type = sys.argv[1]
    N = 10000  # number of samples passed to sample_mixture

    # Independent of the mixture below: plot the first two columns of nsd.M.
    nsd = stdfNSD(torch.tensor([3.0, 3.0]), torch.tensor(2.1))
    x = nsd.M
    plt.scatter(x[:, 0], x[:, 1])
    plt.savefig('nsd.pdf')

    if data_gen_type == "asl":
        ########## Asymmetric Logistic ##############
        print('=====> Generating Asymmetric Logistic Mixture')
        # One rate per mixture component; meaning is defined by
        # sample_mixture — TODO(review): document.
        rates = torch.tensor([0.0001, 0.9])
        d = 2

        # Component 1: alpha = 0.8, thetas drawn uniformly at random.
        alpha = 0.8 * torch.ones(1)
        alphas = torch.tensor((0.5 * torch.ones_like(alpha), alpha))
        thetas = torch.rand(d)
        thetas = torch.stack((thetas, 1 - thetas), dim=0)
        as1 = AsymmetricLogisticCopula(alphas, thetas)

        # Component 2: alpha2 = 0.1. The original code built alphas2 from
        # `alpha` (a copy-paste slip), leaving alpha2 unused and giving both
        # components alpha = 0.8.
        alpha2 = 0.1 * torch.ones(1)
        alphas2 = torch.tensor((0.5 * torch.ones_like(alpha2), alpha2))
        thetas2 = 0.01 * torch.rand(d)
        thetas2 = torch.stack((thetas2, 1 - thetas2), dim=0)
        as2 = AsymmetricLogisticCopula(alphas2, thetas2)

        dists = [as1, as2]
        probs = torch.tensor([0.90, 0.1])  # mixture weights
        # The trailing 100 is forwarded to sample_mixture; its meaning is
        # defined there — TODO(review): document.
        s = sample_mixture(dists, probs, rates, N, 100)
        print(s)
        out_path = 'mixture_asl_asl.p'
    else:
        ########## Symmetric Logistic ##############
        print('=====> Generating Symmetric Logistic Mixture')
        rates = torch.tensor([0.01, 0.8])
        # Presumably (dimension, alpha) — NOTE(review): confirm against
        # SymmetricLogisticCopula's signature.
        s1 = SymmetricLogisticCopula(2, 0.95)
        s2 = SymmetricLogisticCopula(2, 0)
        dists = [s1, s2]
        probs = torch.tensor([0.95, 0.05])  # mixture weights
        s = sample_mixture(dists, probs, rates, N, 100)
        out_path = 'mixture_sl_sl.p'

    # Persist the samples; `pickle` is already imported at module level.
    with open(out_path, 'wb') as f:
        pickle.dump(s, f)
    print('Max:{}'.format(s.max().item()))