-
Notifications
You must be signed in to change notification settings - Fork 11
/
mixmodels.py
executable file
·100 lines (75 loc) · 2.44 KB
/
mixmodels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/python3
import sys
import os
import argparse
import traceback
import random
import math
import time
import logging
import h5py
import contextlib
import tensorflow as tf
import numpy as np
import data
from board import Board
# Command and args-------------------------------------------------------------------
# Expected usage:
#   mixmodels.py -model-file A -model-file B [...] -output-file OUT
description = """
Linearly average different neural net weights.
"""
parser = argparse.ArgumentParser(description=description)
# action='append' collects every occurrence of -model-file into a list, so the
# flag is repeated once per checkpoint prefix that should take part in the average.
parser.add_argument(
    '-model-file',
    help='model file prefix to load',
    required=True,
    action='append',
)
parser.add_argument(
    '-output-file',
    help='model file prefix to write to',
    required=True,
)
args = vars(parser.parse_args())
# argparse maps the dashes in option names to underscores in the parsed keys.
model_files, output_file = args["model_file"], args["output_file"]
def log(s):
    """Print ``s`` (plus a newline) to standard output and flush immediately.

    Flushing right away keeps progress messages visible promptly even when
    stdout is piped or redirected rather than attached to a terminal.
    """
    stream = sys.stdout
    print(s, file=stream, flush=True)
# Model ----------------------------------------------------------------
log("Building model")
# The `model` name is never referenced below, so this import exists only for
# its side effects; presumably it constructs the TF graph whose global
# variables are enumerated next (verify against model.py).
import model
def volume(variable):
    """Return the number of scalar elements in ``variable``.

    ``variable`` is anything with a ``get_shape()`` method whose result can be
    iterated dimension by dimension (e.g. a TF variable's ``TensorShape``).
    Each dimension may be a TF1 ``Dimension`` (size in ``.value``) or a plain
    integer, which is what newer TensorFlow yields when iterating a shape.

    Raises:
        ValueError: if any dimension's size is unknown (``None``).
    """
    shape = variable.get_shape()
    variable_parameters = 1
    for dim in shape:
        # Dimension objects expose their size via `.value`; plain ints are the size.
        size = getattr(dim, "value", dim)
        if size is None:
            raise ValueError("Cannot count parameters of a variable with unknown shape %s" % (shape,))
        variable_parameters *= size
    return variable_parameters
# Index every global variable by its graph name (used later to fetch and
# assign weights) and report how many scalar parameters each one holds.
variables = {}
total_parameters = 0
for var in tf.global_variables():
    n_params = volume(var)
    total_parameters += n_params
    var_name = var.name
    variables[var_name] = var
    log("Global variable %s, %d parameters" % (var_name, n_params))
log("Built model, %d total parameters" % total_parameters)
# Averaging ------------------------------------------------------------
# Restore each checkpoint in turn, sum every global variable across them,
# divide by the number of checkpoints, and save the uniform average.
print("Testing", flush=True)
saver = tf.train.Saver(
    max_to_keep=10000,
    save_relative_paths=True,
)

count = 0
accum_weights = {}
tfconfig = tf.ConfigProto(log_device_placement=False)
with tf.Session(config=tfconfig) as session:
    for model_file in model_files:
        saver.restore(session, model_file)
        print("Processing: " + model_file)
        count += 1
        # One session.run with a dict fetch returns a dict with the same keys,
        # instead of a separate run call per variable.
        fetched = session.run(variables)
        for name, value in fetched.items():
            # Accumulate in float64 so summing many float32 checkpoints does not
            # lose precision (the result is cast back to each variable's dtype below).
            value = np.asarray(value, dtype=np.float64)
            if name in accum_weights:
                accum_weights[name] += value
            else:
                accum_weights[name] = value

print("Normalizing...")
for name in accum_weights:
    accum_weights[name] /= count

# Feed the averages through placeholders rather than passing numpy arrays to
# assign() directly: the latter embeds every weight as a Const node in the
# graph, which saver.save then also serializes into the .meta file (doubling
# the output size, and risking the 2GB GraphDef limit on large models).
placeholders = {
    name: tf.placeholder(variables[name].dtype.base_dtype, shape=variables[name].get_shape())
    for name in accum_weights
}
assign_ops = {name: variables[name].assign(placeholders[name]) for name in accum_weights}
# Cast each float64 average back to its variable's own dtype before feeding.
feed = {
    placeholders[name]: accum_weights[name].astype(variables[name].dtype.base_dtype.as_numpy_dtype)
    for name in accum_weights
}
with tf.Session(config=tfconfig) as session:
    # Every global variable receives an assignment here, so no separate
    # initializer run is needed in this fresh session.
    session.run(assign_ops, feed_dict=feed)
    print("Saving to " + output_file)
    saver.save(session, output_file)
print("Done")