# make_mask.py
import numpy as np
from numpy.random import randint
import random
import math
def get_mask(view_num, data_size, missing_ratio):
    """
    Build a random binary view-availability mask for multi-view data.

    ``floor(data_size * missing_ratio)`` samples are picked at random and
    marked incomplete: each of them keeps at least one view and loses at
    least one view. All remaining samples keep every view.

    :param view_num: number of views (must be >= 2)
    :param data_size: size of data (number of samples / mask rows)
    :param missing_ratio: fraction of samples made incomplete, in [0, 1]
    :return: mask matrix, float array of shape (data_size, view_num) where
        1 marks an observed view and 0 a missing one
    :raises ValueError: if view_num < 2 or missing_ratio is outside [0, 1]
    """
    # Explicit checks instead of `assert`: asserts vanish under `python -O`,
    # and with a single view the rejection loop below could never terminate.
    if view_num < 2:
        raise ValueError(f"view_num must be >= 2, got {view_num!r}")
    # Written as a chained comparison so that NaN is rejected as well.
    if not 0 <= missing_ratio <= 1:
        raise ValueError(
            f"missing_ratio must be within [0, 1], got {missing_ratio!r}"
        )

    miss_sample_num = math.floor(data_size * missing_ratio)

    # Shuffle the sample indices and take the first ones as the incomplete set.
    data_ind = list(range(data_size))
    random.shuffle(data_ind)
    miss_ind = data_ind[:miss_sample_num]

    mask = np.ones([data_size, view_num])
    for sample in miss_ind:
        # Rejection sampling: draw a random observed/missing pattern and
        # retry until it is neither "all observed" nor "all missing".
        while True:
            rand_v = np.random.rand(view_num)
            v_threshold = np.random.rand(1)
            observed = rand_v >= v_threshold
            if 0 < int(observed.sum()) < view_num:
                break
        # Boolean row is cast to 1.0 / 0.0 on assignment into the float mask.
        mask[sample] = observed
    return mask