-
Notifications
You must be signed in to change notification settings - Fork 0
/
ensemble_tests.py
141 lines (104 loc) · 12.8 KB
/
ensemble_tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import multiprocessing
from multiprocessing import Pool, Queue
from time import time
import gym
from atari_wrappers import WarpFrame
import random
from vae import VAE
from copy import deepcopy
from operator import attrgetter
from model import Model, empty_class
from ga import wrap_env, W, H, distribute
from merger import merge_all
def test(list_models=None, n=50, processes=4):
    """Evaluate saved checkpoints in parallel and print per-checkpoint statistics.

    Each checkpoint path is loaded into a fresh
    ``Model(output_size=3, no_rew_early_stop=100000)``. The model's weights and
    pickled attributes are enqueued ``n`` times on a job queue. That queue is
    consumed by a pool of ``processes`` workers that each run ``distribute`` as
    their initializer. Once ``len(list_models) * n`` results have arrived, each
    result is attached to its model via ``add_results``. The raw results for the
    model's current generation are then printed, followed by their mean and
    standard deviation.

    Args:
        list_models: Checkpoint paths to evaluate. ``None`` (the default) uses
            the built-in list below.
        n: Number of evaluation jobs enqueued per checkpoint.
        processes: Number of pool worker processes.
    """
    if list_models is None:
        list_models = [
            './saved/0254result_578.2577006758438',
            #'./saved/0297result_372.95431161691926',
            #'./saved/0236result_643.5598879468932',
            #'./saved/0081result_422.07389722683615',
            #'./saved/0177result_421.18715488163565',
            #'./saved/0198result_466.7534881787524',
            #'./saved/0179result_431.7866706964803',
            #'./saved/0293result_285.57063288357404',
            #'./saved/0124result_503.35098973857504',
            './saved/0301result_395.1973162218356',
        ]
        #'./saved/0060result_425.47869352868946',
        # './saved/0072result_211.4064480425836',
        #'./saved/0088result_343.65973280883276' #out: 12

    b_models = [Model(output_size=3, no_rew_early_stop=100000) for _ in list_models]
    for model, path in zip(b_models, list_models):
        # expect_partial() silences warnings about checkpoint values that are
        # not restored -- presumably optimizer state not needed for evaluation.
        model.load_weights(path).expect_partial()

    results_queue = Queue()
    training_queue = Queue()
    # `distribute` is each worker's initializer. It presumably loops forever,
    # pulling jobs from training_queue and pushing results onto results_queue.
    # TODO confirm in ga.py.
    pool = Pool(processes, distribute, (training_queue, results_queue))
    try:
        for i, model in enumerate(b_models):
            info = {'weights': model.get_weights(), 'attr': model.get_pickle_obj(), 'id': i}
            for _ in range(n):
                training_queue.put(info)
        # Block until every enqueued job has reported back.
        results = [results_queue.get() for _ in range(len(b_models) * n)]
    finally:
        # No shutdown sentinel is ever sent to the workers, so close() + join()
        # could block if `distribute` never returns. terminate() releases the
        # worker processes both on success and on error.
        pool.terminate()

    for result in results:
        b_models[result['id']].add_results(result['result'])

    for path, model in zip(list_models, b_models):
        scores = model.results[str(model.generation)]
        print('list_models', path)
        print(scores)
        print('avg:', np.mean(scores))
        print('std', np.std(scores))
class Weights:
    """Container for a fixed list of weight-shape descriptors.

    Each descriptor is a dict with a ``'shape'`` tuple and a ``'count'``. The
    shapes are presumably dense-layer kernel shapes (TODO confirm against
    model.py).
    """

    def __init__(self):
        # Stored on the instance: previously this was a discarded local, so
        # the descriptors were unreachable after construction.
        self.weights_forms = [
            {'shape': (128, 3), 'count': 2},
            {'shape': (256, 128), 'count': 2},
        ]
if __name__ == '__main__':
    # Without the 'spawn' start method the pool workers hang (per the original
    # note), presumably because TensorFlow state does not survive fork().
    # The method must be chosen before test() creates its Pool.
    multiprocessing.set_start_method(method='spawn')
    test()
'''
list_models ./saved/0254result_578.2577006758438
ListWrapper([822.8070175438443, 410.0671140939598, 416.0142348754319, 293.83561643835935, 784.4765342960193, 667.12328767122, 625.3521126760429, 352.76872964168274, 500.68259385663964, 596.3696369636888, 701.9801980197918, 532.1070234113587, 594.9152542372748, 383.2214765100552, 874.3589743589629, 708.7248322147518, 501.4760147601361, 640.350877192968, 614.2857142857007, 687.1621621621475, 536.015325670484, 761.6600790513731, 316.10738255033783, 367.8899082568708, 670.072992700716, 568.7898089171817, 466.0377358490424, 360.0550964187224, 683.4645669291216, 662.5899280575389, 537.4622356495332, 258.30618892506993, 902.8999999999916, 611.9205298013144, 578.7003610108175, 866.3299663299512, 828.825622775789, 701.498127340812, 553.7216828478837, 389.65517241377984, 477.31958762885364, 601.9230769230666, 676.0736196318928, 627.2727272727138, 693.706293706282, 491.8367346938666, 697.1014492753554, 790.5660377358373, 715.1815181518065, 409.91501416429577])
avg: 590.2195635178067
std 159.63372450910683
list_models ./saved/0297result_372.95431161691926
ListWrapper([765.6126482213335, 821.2328767123122, 363.02250803857277, 493.6599423630996, 664.4927536231776, 780.4347826086852, 444.8504983388563, 508.856088560874, 653.9936102236306, 716.0535117056729, 904.3999999999909, 908.1999999999912, 552.1739130434676, 545.2702702702576, 710.8108108107947, 859.4095940959315, 558.4615384615299, 562.8242074927851, 237.46130030959148, 653.9432176656036, 570.0336700336562, 643.5064935064806, 423.8095238095132, 629.9999999999882, 299.99999999998766, 331.15942028984216, 562.6139817629077, 765.1315789473535, 906.5999999999918, 646.7532467532341, 357.79220779219645, 606.4846416382117, 370.9677419354725, 802.8776978417113, 491.7602996254561, 115.43408360129033, 680.7308970099562, 746.4052287581546, 561.0738255033444, 630.6397306397172, 869.6969696969544, 741.1552346570301, 498.6622073578484, 279.06137184114215, 685.9778597785856, 524.9999999999968, 547.2602739725913, 443.62416107381495, 442.37288135592246, 85.89743589744057])
avg: 579.352934752519
std 195.1614295426738
list_models ./saved/0236result_643.5598879468932
ListWrapper([610.2272727272617, 672.5752508361071, 654.9668874172111, 662.9629629629495, 719.7278911564548, 628.4768211920446, 614.7540983606472, 796.428571428557, 640.2135231316639, 611.1111111111003, 749.462365591386, 431.84713375794854, 615.6249999999912, 685.714285714273, 641.9354838709586, 567.8200692041447, 530.3630363036206, 781.4814814814683, 618.6311787072157, 796.6789667896585, 670.4402515723127, 771.1864406779533, 834.8659003831326, 499.9999999999882, 353.03867403314234, 366.8769716088209, 635.9154929577342, 657.2815533980466, 700.7246376811502, 655.2447552447464, 697.8723404255232, 655.0335570469689, 824.6031746031646, 407.75193798449595, 602.1276595744606, 285.71428571428163, 702.1582733812817, 677.4086378737444, 693.8144329896767, 797.058823529396, 485.5855855855787, 551.8771331057926, 570.5539358600519, 324.05063291138197, 676.595744680842, 306.59340659339404, 458.06451612902345, 762.8158844765255, 642.3728813559217, 757.1428571428429])
avg: 621.1154754053208
std 136.85578781916342
list_models ./saved/0081result_422.07389722683615
ListWrapper([264.2857142857007, 896.2686567164006, 357.62711864405355, 861.9771863117694, 833.5664335664188, 278.20512820511647, 415.0501672240672, 562.8787878787747, 669.4915254237152, 168.96551724138388, 123.79603399433894, 878.1818181818023, 540.9395973154262, 796.3210702340981, 332.52595155709196, 388.09523809523887, 514.8409893992816, 703.3472803347127, 487.096774193539, 614.7651006711292, 296.1661341852916, 454.7445255474321, 522.8956228956094, 396.6442953020034, 884.9056603773455, 664.3097643097494, 119.51219512195635, -13.35740072202166, 498.5915492957628, 877.7777777777627, 680.2547770700476, 851.8072289156494, 686.1842105263026, 694.6127946127792, 332.34323432342063, 237.31343283581404, 599.6047430829918, 661.7449664429388, 544.9275362318795, 699.9999999999844, 860.7142857142708, 613.7546468401339, 709.187279151931, 695.3020134228049, 392.7007299269954, 580.7692307692164, 221.31147540983332, 428.05280528051827, 702.2813688212757, 687.0967741935384])
avg: 545.4075949427456
std 229.21599190928256
list_models ./saved/0177result_421.18715488163565
ListWrapper([475.6578947368342, 199.65156794425565, 258.7301587301493, 502.8368794326165, 312.58741258741395, 384.53608247422153, 514.6496815286516, 450.33557046979325, 755.3719008264359, 493.8461538461473, 522.7758007117351, 269.9999999999913, 627.5985663082315, 566.666666666656, 575.862068965503, 537.931034482749, 492.0577617328408, 288.8888888888846, 141.00719424460934, 432.37410071942105, 398.50746268656417, 351.61290322580487, 539.534883720924, 629.5081967212968, 518.5567010309207, 489.2857142857024, 502.78745644598297, 616.3120567375781, 460.2240896358499, 423.5109717868294, 413.33333333332666, 458.8235294117555, 222.9813664596321, 431.8352059925046, 523.5294117646984, 541.4342629481976, 445.7746478873187, 470.4467353951819, 483.0618892508091, 391.96141479098736, 435.4609929077942, 447.2972972972875, 358.46153846153277, 442.8571428571327, 589.5306859205692, 495.9595959595893, 546.8531468531404, 407.09219858155706, 183.07692307692744, 504.89510489509746])
avg: 450.5574449123927
std 121.0688831495816
list_models ./saved/0198result_466.7534881787524
ListWrapper([723.5294117646919, 832.4324324324218, 500.70671378091197, 608.9552238805821, 338.3116883116874, 541.3793103448172, 502.7397260273907, 621.3740458015188, 565.5052264808252, 624.6835443037879, 399.9999999999978, 706.7226890756222, 613.2075471698026, 610.3448275861941, 545.6692913385763, 545.390070921975, 584.9816849816756, 439.285714285708, 324.1379310344755, 548.8549618320554, 561.9718309859063, 392.0634920634895, 549.9999999999903, 549.9999999999899, 296.7213114753999, 538.0952380952291, 367.1052631578843, 447.10144927535504, 242.7672955974799, 640.2135231316615, 379.16666666666686, 313.6690647481874, 550.8474576271082, 359.45945945945965, 354.54545454544467, 504.2944785275986, 422.12389380530715, 506.42570281123267, 184.1328413284178, 509.9290780141758, 636.6412213740365, 311.56462585033233, 79.85611510791843, 408.2508250825041, 559.0909090909006, 424.2718446601876, 279.18215613383296, 570.6827309236877, 254.46685878963004, 555.9485530546524])
avg: 478.5760276547677
std 146.41062665558937
list_models ./saved/0179result_431.7866706964803
ListWrapper([532.3024054982733, 413.67781155014296, 406.493506493498, 645.3183520599154, 320.68965517240423, 630.7692307692203, 536.0294117646961, 665.5677655677549, 656.5543071160952, 730.3249097472841, 599.9999999999898, 267.0033670033554, 569.0391459074654, 276.5822784810048, 234.310850439885, 522.710622710612, 527.009646302242, 529.9999999999916, 579.3103448275765, 583.3333333333223, 430.0353356890458, 608.029197080282, 573.2283464566816, 458.3596214510971, 462.87425149699845, 543.3121019108174, 277.5510204081589, 478.9473684210397, 583.1683168316748, 192.2535211267658, 561.9217081850446, 376.66666666666356, 619.3675889327985, 716.479400749054, 759.3749999999873, 382.5737265415486, 641.5730337078538, 638.4105960264812, 571.2802768166027, 628.9377289377186, 454.45544554454483, 520.9150326797289, 650.9025270758021, 408.9605734766914, 497.52321981423154, 497.86476868326656, 627.2727272727146, 528.5714285714173, 451.6014234875363, 564.2066420664125])
avg: 518.6729108170679
std 129.54084606074002
list_models ./saved/0293result_285.57063288357404
ListWrapper([670.3180212014026, 668.7074829931902, 744.5945945945795, 508.5409252668934, 403.3112582781337, 511.11111111109886, 684.6715328467016, 323.5668789808934, 660.655737704907, 327.6527331189578, 483.33333333332433, 789.67971530248, 686.5612648221236, 706.779661016938, 715.699658703056, 872.6562499999866, 530.3630363036197, 647.3684210526172, 634.824281150149, 464.1025641025554, 678.9855072463665, 272.82229965156523, 520.5787781350361, 693.4426229508083, 682.6086956521586, 702.013422818778, 681.4814814814674, 486.2068965517175, 286.0182370820573, 209.2783505154682, 274.1007194244476, 733.3333333333223, 524.59546925565, 538.3763837638262, 590.9090909090818, 587.2964169380971, 771.595330739288, 747.4576271186309, 675.8620689655017, 404.1322314049455, 481.7490494296453, 681.3620071684464, 673.2342007434818, 480.85808580856997, 518.5897435897411, 613.3757961783275, 287.8787878787913, 331.48688046646373, 790.510948905096, 863.9999999999866])
avg: 576.3727785198074
std 165.52501001346576
list_models ./saved/0124result_503.35098973857504
ListWrapper([435.7142857142713, 737.3702422145233, 358.1818181818039, 485.6697819314529, 808.771929824545, 484.2293906809941, 621.0884353741378, 706.2499999999882, 298.0891719745101, 664.9253731343122, 142.81150159744874, 391.9999999999876, 792.8571428571282, 667.1232876712181, 426.31578947366796, 640.0611620794975, 250.99337748344823, 657.1884984025464, 289.83050847456406, 590.647482014374, 255.4817275747372, 674.7440273037383, 389.1640866872953, 63.33333333333818, 853.0201342281725, 608.0536912751556, 555.7971014492622, 438.9610389610292, 159.3856655290153, 222.8070175438594, 645.7627118643937, 529.3706293706175, 543.598615916946, 842.4460431654514, 593.3797909407523, 578.2006920415122, 658.8424437298921, 377.5641025640928, 436.1842105263019, 573.2026143790716, 624.3589743589641, 765.3198653198496, 60.40955631399794, 320.2453987729949, 463.706563706548, 443.3333333333193, 351.2195121951103, 450.488599348521, 752.5179856114964, 398.4326018808697])
avg: 501.5890249662145
std 198.40220947932858
list_models ./saved/0301result_395.1973162218356
ListWrapper([270.62937062936265, 139.28571428571934, 711.8081180811695, 607.317073170718, 684.3749999999889, 826.8292682926754, 16.279069767443204, 579.3002915451827, 527.5167785234853, 554.4117647058713, 658.7412587412471, 460.40268456374855, 582.5938566552736, 683.6879432624016, 849.6124031007672, 880.3921568627283, 325.6055363321688, 494.02985074626173, 716.0535117056726, 566.6666666666541, 796.4285714285563, 611.3402061855539, 729.3515358361626, 351.92307692306827, 546.1038961038855, 392.58160237387705, 370.58823529410535, 690.7801418439617, 766.6666666666524, 849.6124031007678, 482.1917808219074, 757.6388888888746, 654.4483985765019, 373.3542319749102, 823.6363636363486, 595.5128205128121, 657.7854671280196, 474.0740740740604, 650.7692307692222, 778.0487804877929, 338.3116883116801, 778.0487804877977, -28.80258899676456, 823.3576642335613, 721.6783216783106, 717.880794701978, 820.8333333333173, 299.3902439024292, 532.5878594249097, 643.1506849314954])
avg: 582.6962294454858
std 211.08311662675658
'''