-
Notifications
You must be signed in to change notification settings - Fork 0
/
SSA_DoF.py
78 lines (72 loc) · 3.24 KB
/
SSA_DoF.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# !/usr/bin/python2.7
# coding=utf-8
# *************************
# decode and forward scheme for comparison in G-SSA-PNC response
# identity matrix is used as beamforming matrix in users
# full decoding is employed in BSs
# *************************
import matplotlib.pyplot as pyplot
from multiprocessing import Pool
import numpy as np
import time
from SSA_fading_channel_model import chanMatrix
def rate_DoF(M, N, K, J, H, SNR):
    """Return the sum rate of the equal-weight (all-ones) combining scheme.

    The channel matrix ``H`` is viewed as a grid of blocks: rows are grouped
    in blocks of ``M`` (one per receiver ``j``) and columns in blocks of ``N``
    (one per transmitter ``k``).  Each transmitter sends along the all-ones
    vector ``1_N`` with power ``SNR / N`` per entry, so the effective channel
    from ``k`` to ``j`` is ``h_jk = H[j-block, k-block] @ 1_N``.

    For every receiver ``j`` in ``range(J)`` the interference-plus-noise
    covariance is ``C_j = I_M + (SNR/N) * sum_{k != j} h_jk h_jk^T`` and the
    rate is ``0.5 * ln(1 + (SNR/N) * h_jj^T C_j^{-1} h_jj)``.

    Args:
        M: number of rows per receiver block of ``H``.
        N: number of columns per transmitter block of ``H``.
        K: number of transmitter (column) blocks treated as interferers.
        J: number of receiver blocks whose rates are summed.  Assumes ``H``
            has at least ``J`` row blocks and ``max(J, K)`` column blocks.
        H: real 2-D numpy array holding the channel (transposes are used,
            not conjugate transposes).
        SNR: linear-scale SNR.

    Returns:
        The sum over receivers of the per-receiver rates, clamped at 0.
        Uses the natural log (``slogdet``), so the unit is nats.
        NOTE(review): the caller labels the plot "bps"; confirm the
        intended unit.

    Raises:
        Exception: if ``slogdet`` returns a sign other than 1, 0 or -1
            (e.g. NaN from a non-finite channel).
    """
    # float() avoids Python 2 integer division when SNR is passed as an int.
    snr_per_stream = float(SNR) / N
    # Combining vector: it multiplies N-column blocks, so it must have N rows.
    ones_n = np.ones((N, 1))
    sum_rate = 0.0
    for j in range(J):
        rows = slice(j * M, (j + 1) * M)
        # The covariance is rebuilt for every receiver so that interference
        # seen by one receiver does not leak into the next one's.
        cov = np.eye(M)
        for k in range(K):
            if k == j:
                continue
            h_jk = np.dot(H[rows, k * N:(k + 1) * N], ones_n)
            cov = cov + snr_per_stream * np.dot(h_jk, h_jk.T)
        h_jj = np.dot(H[rows, j * N:(j + 1) * N], ones_n)
        # solve() instead of inv(): same quantity h^T C^{-1} h, better conditioned.
        gain = 1.0 + snr_per_stream * np.dot(h_jj.T, np.linalg.solve(cov, h_jj))
        sign, logdet = np.linalg.slogdet(gain)
        if sign == 1:
            sum_rate = sum_rate + 0.5 * logdet
        elif sign in (0, -1):
            # Non-positive argument: this receiver contributes nothing.
            continue
        else:
            raise Exception('Error occurs in logdet computation!!!')
    return max(sum_rate, 0)
if __name__ == "__main__":
    # Block sizes and counts passed to rate_DoF: M rows per receiver block,
    # N columns per transmitter block, K transmitter blocks, J receiver blocks.
    M = 3
    N = 3
    K = 2
    J = 2
    # Alternative SNR sweeps (linear scale) kept for reference:
    # SNR = [10 ** 1,10 ** 1.25, 10 ** 1.5, 10 ** 1.75, 10 ** 2, 10 ** 2.25, 10 ** 2.5, 10 ** 2.75, 10 ** 3,10 ** 3.25, 10 ** 3.5]
    # SNR = [10**1, 10**2, 10**2.5, 10**3, 10**3.5, 1e4, 10**4.5, 10**5]
    # SNR = [10**1, 10**2, 10**2.5, 10**3, 10**3.5, 1e4, 10**4.5, 10**5, 10**5.5, 10**6, 10**6.5, 1e7, 10**7.5, 10**8]
    SNR = [10 ** 0.5, 10 ** 0.75]
    # Number of Monte-Carlo channel draws per SNR point.
    n_iter = 2000
    out_dir = '/home/haizi/PycharmProjects/SSA/Simu_result/'
    # One timestamp per run so the data file and the figure share a name stem.
    stamp = time.ctime()

    p = Pool(20)
    dof_sum_rate_list = []
    try:
        for snr in SNR:
            # Channels are drawn sequentially in the parent process (alternatives:
            # chanMatrix(M, N, K, J) or np.eye(K * M)); rates are evaluated in the pool.
            multiple_res = [
                p.apply_async(rate_DoF, (M, N, K, J, np.random.randn(K * M, K * N), snr))
                for _ in range(n_iter)
            ]
            dof_sum_rate = sum(res.get() for res in multiple_res) / float(n_iter)
            dof_sum_rate_list.append(dof_sum_rate)
    finally:
        # Same cleanup `with Pool(...)` performs on Python 3; written out
        # because Pool is not a context manager on Python 2.7.
        p.terminate()
        p.join()

    snr_db = 10 * np.log10(SNR)
    Full_Result = np.column_stack((snr_db, dof_sum_rate_list))
    name_stem = out_dir + 'DoF' + ' K=' + str(K) + ' iter =' + str(n_iter) + stamp
    np.savetxt(name_stem + 'Simu_Data.txt', Full_Result, fmt='%1.5e')
    pyplot.plot(snr_db, dof_sum_rate_list, 'b*-', label='DoF')
    pyplot.xlabel('SNR/dB')
    # NOTE(review): rate_DoF uses the natural log (nats); confirm this label.
    pyplot.ylabel('Sum Rate/bps')
    pyplot.legend(loc='upper left')
    pyplot.savefig(name_stem + 'fig', format='eps')
    pyplot.show()