|
| 1 | +import numpy as np |
| 2 | +import networkx as nx |
| 3 | +import matplotlib.pyplot as plt |
| 4 | +import pdb |
| 5 | +import copy |
| 6 | +import random |
| 7 | +import sys |
| 8 | + |
| 9 | +from boltzmann_explr_opin import boltzmann_explr_opin |
| 10 | +from compute_prob import compute_prob |
| 11 | +from gen_graph import gen_graph |
| 12 | +from prune_feed import prune_feed |
| 13 | +from normalize_vector import normalize_vector |
| 14 | +from load_graph import load_graph |
| 15 | +from visualize_graph import visualize_graph |
| 16 | +from display_centralities import display_centralities |
| 17 | +from save_results_qopin import save_results_qopin |
| 18 | + |
def qopin_tb(sim, Params):
    """Simulate two competing opinions (Q and R) spreading over a graph.

    Two source nodes inject messages: ``srcQNode`` sends Q messages (stored
    as ``1``) and ``srcRNode`` sends R messages (stored as ``0``).  Every
    round each non-source node with a non-empty feed may pick one stored
    message and forward it to a neighbour.  Q senders pick the neighbour
    from their Q-table via ``boltzmann_explr_opin``; R senders use a flat
    table of ones.  Receiving a message increments the matching column of
    the recipient's belief matrix.  Beliefs and feeds are only modified in
    rounds where ``indRnd % blfBatchSize == 0``; Q-table updates happen in
    every round when ``Params['qLrnEn']`` is true.

    Args:
        sim: mapping; ``sim['iterNum']`` seeds the stdlib ``random`` module
            and the object is passed on to ``save_results_qopin``.
        Params: mapping read for the keys ``blfBatchSize``, ``gammaVal``,
            ``genGraphFlg``, ``graphName``, ``tauMax``, ``learnRate``,
            ``srcQNode``, ``srcRNode``, ``numNodes``, ``numRounds``,
            ``eta``, ``xi``, ``tempVal`` and ``qLrnEn``; it is also passed
            to the graph loader/generator and to ``save_results_qopin``.

    Returns:
        None.  Results are written through ``save_results_qopin`` and two
        matplotlib figures (opinion sums per batch, final opinion
        histogram) are shown.
    """
    # ------------------------------------------------------------------
    # Initialization
    # ------------------------------------------------------------------
    random.seed(sim['iterNum'])
    # NOTE(review): only the stdlib generator is seeded here.  The
    # np.random.rand / np.random.choice draws below use NumPy's global
    # state, which this function never seeds -- confirm whether a helper
    # does so, otherwise runs are not reproducible from iterNum alone.

    # Multiplier applied to the previous belief count before adding one.
    rememberingFactor = 1

    blfBatchSize = Params['blfBatchSize']
    gammVal = Params['gammaVal']

    if Params['genGraphFlg'] == True:
        (g, numNbrDict, gnx) = gen_graph(Params)
    else:
        (g, numNbrDict, gnx) = load_graph(Params)

    # Node positions; only consumed by the (currently disabled)
    # visualize_graph call at the end of the round loop.
    if Params['graphName'] in ('rgg', 'grid2d'):
        pos = nx.get_node_attributes(gnx, 'pos')
    else:
        pos = nx.spectral_layout(gnx, scale=4)

    tauMax = Params['tauMax']
    learnRate = Params['learnRate']

    srcQNode = Params['srcQNode']
    srcRNode = Params['srcRNode']
    numNodes = Params['numNodes']

    # Display centralities and degrees of the nodes
    [cntraltyEigVec] = display_centralities(gnx, srcQNode, srcRNode)

    # Q-table: one action value per neighbour of every node.  g[node] is the
    # neighbour list and therefore doubles as the action set.
    Q = copy.deepcopy(g)
    for indNode in range(numNodes):
        Q[indNode] = [0.5] * len(g[indNode])

    opinionList = []        # sum over nodes of the Q opinion share, per batch
    opinionListOppon = []   # same for the R opinion share

    # Per-node feed:
    # [[Q(1) or R(0) message], [time of reception], [incoming node],
    #  [number of times the message was transmitted]]
    feedList = {x: [[], [], [], []] for x in range(numNodes)}

    # Belief counts per node: column 0 counts R messages, column 1 counts Q.
    blfMat = np.ones((numNodes, 2))
    blfMatAllRnd = []

    # Per-node opinion shares.  Initialised from the starting beliefs so the
    # final histogram is defined even if the round loop never executes.
    initialTotals = np.sum(blfMat, 1)
    muiList = np.divide(blfMat[:, 1], initialTotals)
    muiOpponList = np.divide(blfMat[:, 0], initialTotals)

    numRounds = Params['numRounds']

    # ------------------------------------------------------------------
    # Loop for all rounds
    # ------------------------------------------------------------------
    for indRnd in range(numRounds):

        if np.mod(indRnd, 10) == 0:
            print("Round", indRnd, end=",")
            sys.stdout.flush()

        # Time stamp shared by all rounds of one batch, and whether this
        # round is the one in which beliefs/feeds may change.
        timeSlot = int(indRnd / blfBatchSize)
        isBatchBoundary = np.mod(indRnd, blfBatchSize) == 0

        qTxNodesList = []      # nodes that will try to send a Q message
        rTxNodesList = []      # nodes that will try to send an R message
        qChosenFeedList = []   # feed position chosen by each Q sender
        rChosenFeedList = []   # feed position chosen by each R sender

        # Feed positions scheduled for deletion.  They are collected but the
        # deletion step itself is disabled (see the end of the round).
        feedDelDict = {x: [] for x in range(numNodes)}

        # All probabilities and rewards of this round use the beliefs as
        # they were at the start of the round.
        blfMatPrev = copy.deepcopy(blfMat)

        # ---- Pass 1: every node decides which message (if any) to send ----
        for indNode in range(numNodes):
            # 1. Remove all the messages older than tauMax
            feedList = prune_feed(feedList, indNode, timeSlot, tauMax)

            # 2. Source nodes always transmit their own message.
            if indNode == srcQNode:
                qTxNodesList.append(indNode)
                qChosenFeedList.append(0)
                continue
            if indNode == srcRNode:
                rTxNodesList.append(indNode)
                rChosenFeedList.append(0)
                continue

            # Nothing to forward when the feed is empty.
            if len(feedList[indNode][0]) == 0:
                continue

            # Selection probabilities over the stored messages, computed
            # from their reception times and transmission counts.
            [probFeedVec, noTx] = compute_prob(
                feedList[indNode][1], feedList[indNode][3],
                timeSlot, Params['eta'], Params['xi'])

            if noTx == False:
                if len(probFeedVec) == 1:
                    # Single candidate: no random draw is consumed.
                    chosenFeedTmp = [0]
                else:
                    chosenFeedTmp = np.random.choice(
                        len(probFeedVec), 1, p=probFeedVec)

                chosenFeed = chosenFeedTmp[0]
                chosenMsg = feedList[indNode][0][chosenFeed]

                # Collect qTx nodes or rTx nodes
                if chosenMsg == 1:
                    qTxNodesList.append(indNode)
                    qChosenFeedList.append(chosenFeed)
                else:
                    rTxNodesList.append(indNode)
                    rChosenFeedList.append(chosenFeed)

        # ---- Pass 2: Q senders (runs every round) ----
        for (loopInd, indNode) in enumerate(qTxNodesList):

            # Probability of transmitting the Q message = own Q belief share.
            probSendMsgBlf = blfMatPrev[indNode][1] / np.sum(blfMatPrev[indNode])
            # The draw is made for the source as well, so the sequence of
            # NumPy random numbers consumed is independent of who sends.
            sendDraw = np.random.rand(1, 1)
            sendMsg = indNode == srcQNode or bool(sendDraw <= probSendMsgBlf)
            if not sendMsg:
                continue

            # 4.a. Choose a neighbour using the Boltzmann exploration rule.
            if indNode == srcQNode:
                incomingNodeIndex = []
            else:
                chosenFeed = qChosenFeedList[loopInd]            # location of the feed
                incomingNode = feedList[indNode][2][chosenFeed]  # who sent it to indNode
                incomingNodeIndex = np.where(
                    np.array(g[indNode]) == incomingNode)[0][0]

                if isBatchBoundary:
                    feedList[indNode][3][chosenFeed] += 1

            seedInBoltzmann = random.randint(1, 100000)
            [action, actionIdx] = boltzmann_explr_opin(
                Q[indNode], incomingNodeIndex, g[indNode],
                Params['tempVal'], seedInBoltzmann)

            # 4.b. Update the Q belief of the chosen recipient.
            if isBatchBoundary:
                blfMat[action][1] = blfMat[action][1] * rememberingFactor + 1

            # 4.c. Immediate reward from the recipient's start-of-round
            #      beliefs, then the Q-learning update with the best action
            #      value of the recipient as bootstrap target.
            if Params['qLrnEn'] == True:
                opinion = blfMatPrev[action][1] / np.sum(blfMatPrev[action])
                rwdImm = opinion * (1 - opinion) / blfMatPrev[action][1]
                Qmax = np.max(Q[action])
                Q[indNode][actionIdx] = (
                    (1 - learnRate) * Q[indNode][actionIdx]
                    + learnRate * (rwdImm + gammVal * Qmax))

            if isBatchBoundary:
                # 4.e. Schedule the forwarded feed entry for deletion.
                if indNode != srcQNode:
                    feedDelDict[indNode].append(chosenFeed)

                # 4.f. Append the message to the feed of the recipient.
                # NOTE(review): placed under the batch-boundary guard to
                # mirror the R loop (which is guarded as a whole) -- confirm
                # that Q messages are not meant to be delivered every round.
                feedList[action][0].append(1)          # message
                feedList[action][1].append(timeSlot)   # time
                feedList[action][2].append(indNode)    # incoming node
                feedList[action][3].append(0)          # number of times msg is Tx

        # ---- Pass 3: R senders (only on batch boundaries) ----
        if isBatchBoundary:
            for (loopInd, indNode) in enumerate(rTxNodesList):

                # Probability of transmitting the R message = own R belief share.
                probSendMsgBlf = blfMatPrev[indNode][0] / np.sum(blfMatPrev[indNode])
                sendDraw = np.random.rand(1, 1)
                sendMsg = indNode == srcRNode or bool(sendDraw <= probSendMsgBlf)
                if not sendMsg:
                    continue

                # 4.a. Choose a neighbour using the Boltzmann exploration rule.
                if indNode == srcRNode:
                    incomingNodeIndex = []
                else:
                    chosenFeed = rChosenFeedList[loopInd]   # location of the feed
                    incomingNode = feedList[indNode][2][chosenFeed]
                    incomingNodeIndex = np.where(
                        np.array(g[indNode]) == incomingNode)[0][0]
                    feedList[indNode][3][chosenFeed] += 1

                seedInBoltzmann = random.randint(1, 100000)
                # R senders do not learn: every neighbour gets the same value.
                [action, actionIdx] = boltzmann_explr_opin(
                    [1] * numNbrDict[indNode], incomingNodeIndex, g[indNode],
                    Params['tempVal'], seedInBoltzmann)

                # 4.b. Update the R belief of the chosen recipient.
                blfMat[action][0] = blfMat[action][0] * rememberingFactor + 1

                if indNode != srcRNode:
                    feedDelDict[indNode].append(chosenFeed)

                # 4.f. Append the message to the feed of the recipient.
                feedList[action][0].append(0)          # message
                feedList[action][1].append(timeSlot)   # time
                feedList[action][2].append(indNode)    # incoming node
                feedList[action][3].append(0)          # number of times msg is Tx

        # Deleting forwarded feeds is currently disabled:
        # for indNode in range(numNodes):
        #     for ind in sorted(feedDelDict[indNode], reverse=True):
        #         del feedList[indNode][0][ind]
        #         del feedList[indNode][1][ind]
        #         del feedList[indNode][2][ind]
        #         del feedList[indNode][3][ind]

        # ---- Bookkeeping at the end of a batch ----
        if isBatchBoundary:
            beliefTotals = np.sum(blfMat, 1)
            muiList = np.divide(blfMat[:, 1], beliefTotals)
            opinionList.append(np.sum(muiList))
            muiOpponList = np.divide(blfMat[:, 0], beliefTotals)
            opinionListOppon.append(np.sum(muiOpponList))

            # Store a copy: blfMat is updated in place, so appending the
            # array itself would make every stored entry alias the final
            # belief matrix instead of recording its history.
            # NOTE(review): recorded once per batch, aligned with
            # opinionList -- confirm a per-round record is not expected.
            blfMatAllRnd.append(blfMat.copy())

        # Optional live view (disabled):
        # if np.mod(indRnd, 50*blfBatchSize) == 0 or indRnd == numRounds-1:
        #     visualize_graph(gnx, g, blfMat, pos, Params)
        #     plt.pause(0.01)

    # ------------------------------------------------------------------
    # Save and plot
    # ------------------------------------------------------------------
    dumpVarNames = ['opinionList', 'opinionListOppon', 'blfMatAllRnd', 'cntraltyEigVec']
    dumpVarVals = [opinionList, opinionListOppon, blfMatAllRnd, cntraltyEigVec]

    save_results_qopin(sim, Params, dumpVarNames, dumpVarVals)

    # Figure 2: summed opinion shares over the recorded batches.
    plt.figure(2)
    plt.plot(opinionList, linewidth=4, markersize=10, label='opinion-1')
    plt.plot(opinionListOppon, linewidth=4, markersize=10, label='opinion-2')
    plt.ylabel("Sum opinions", fontsize=24)
    plt.xlabel("Round", fontsize=24)
    _finish_opinion_figure(Params['qLrnEn'])

    # Figure 3: distribution of the last recorded per-node opinion shares.
    plt.figure(3)
    plt.hist(muiList, bins=10, label='opinion-1')
    plt.hist(muiOpponList, bins=10, label='opinion-2')
    plt.ylabel("Frequency", fontsize=24)
    plt.xlabel("Opinion", fontsize=24)
    _finish_opinion_figure(Params['qLrnEn'])


def _finish_opinion_figure(qLrnEn):
    """Add title, grid, tick sizes and legend to the current figure and show it.

    Args:
        qLrnEn: value of ``Params['qLrnEn']``; the title reads
            'with Q-learning' when it compares equal to ``True`` and
            'without Q-learning' otherwise.
    """
    if qLrnEn == True:
        plt.title('with Q-learning', fontsize=24)
    else:
        plt.title('without Q-learning', fontsize=24)

    print(" ")
    plt.grid(True)
    plt.xticks(fontsize=24)
    plt.yticks(fontsize=24)
    plt.legend(fontsize=20)
    # Both calls are kept in their original order.
    plt.show()
    plt.show(block=False)
| 325 | + plt.show(block=False) |
| 326 | + |
0 commit comments