Merge pull request #307 from mil-tokyo/placeholder
Support dynamic hyper parameters
Kiikurage authored Jun 22, 2017
2 parents 51d132d + 1e32994 commit 2967532
Showing 151 changed files with 5,984 additions and 2,659 deletions.
15 changes: 9 additions & 6 deletions bin/convert_caffe.py
@@ -3,7 +3,6 @@
 """
 
 import argparse
-import ast
 import os
 import sys
 from os import path
@@ -16,6 +15,8 @@
 from webdnn.backend.interface.generator import generate_descriptor
 from webdnn.graph.converters.chainer import ChainerConverter
 from webdnn.graph.graph import Graph
+from webdnn.graph.shape import Shape
+from webdnn.util import console
 
 
 def parse_input_blob(args):
@@ -28,7 +29,9 @@ def parse_input_blob(args):
     else:
         if not args.input_shape:
             raise ValueError("input_npy or input_shapes must be specified to determine input")
-        input_shape = ast.literal_eval(args.input_shape)
+        input_shape, placeholders = Shape.parse(args.input_shape)
+        if len(placeholders) > 0:
+            raise ValueError("caffe converter does not support an input with placeholder")
         input_blob = chainer.Variable(np.zeros(input_shape, dtype=np.float32))
     return input_blob, input_filled
 
@@ -57,10 +60,10 @@ def main():
     input_blob, input_filled = parse_input_blob(args)
     output_names = args.output_names.split(",")
 
-    sys.stderr.write("Loading caffe model... (usually takes several minutes)\n")
+    console.stderr("[convert_caffe] Loading caffe model... (usually takes several minutes)")
     link = chainer.links.caffe.CaffeFunction(args.caffemodel)
 
-    sys.stderr.write("Generating feedforward graph\n")
+    console.stderr("[convert_caffe] Generating feedforward graph")
     if chainer.__version__ >= "2.":
         chainer.using_config("train", False)
         output_blobs = list(
@@ -83,15 +86,15 @@ def main():
     output_arrays = {output_name: output_blob.data for output_name, output_blob in zip(output_names, output_blobs)}
     np.savez(path.join(output_dir, "example_output.npz"), **output_arrays)
 
-    sys.stderr.write("Generating descriptors\n")
+    console.stderr("[convert_caffe] Generating descriptors")
     any_backend_failed = False
     for backend in args.backend.split(","):
         try:
             graph_exec_data = generate_descriptor(backend, graph, constant_encoder_name=args.encoding)
             graph_exec_data.save(output_dir)
         except Exception as ex:
             any_backend_failed = True
-            sys.stderr.write(f"Failed generating descriptor for backend {backend}: {str(ex)}\n")
+            console.error(f"[convert_caffe] Failed generating descriptor for backend {backend}: {str(ex)}")
 
     if any_backend_failed:
         sys.exit(1)
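The substantive change above is that ast.literal_eval(args.input_shape) is replaced by Shape.parse(args.input_shape), which returns the shape together with any placeholder (dynamic) dimensions found in the string, and the caffe path rejects shapes that contain placeholders. As a rough illustration of that contract only, here is a minimal stand-in parser, assuming a placeholder is written as a bare identifier such as N inside the shape string; this is a hypothetical sketch, not webdnn's actual Shape.parse:

import re

def parse_shape(text):
    """Parse a shape string such as "(1, 3, 224, 224)" or "(N, 3, 224, 224)".

    Returns (shape, placeholders): integer tokens become ints, while bare
    identifiers are kept as placeholder names (assumed syntax, for illustration).
    """
    tokens = [t.strip() for t in text.strip().strip("()").split(",") if t.strip()]
    shape, placeholders = [], []
    for token in tokens:
        if re.fullmatch(r"-?\d+", token):
            shape.append(int(token))
        else:
            shape.append(token)
            placeholders.append(token)
    return tuple(shape), placeholders

# A fully concrete shape passes through unchanged.
shape, placeholders = parse_shape("(1, 3, 224, 224)")
assert placeholders == []

# Mirrors the new convert_caffe.py check: any placeholder dimension is rejected.
shape, placeholders = parse_shape("(N, 3, 224, 224)")
if len(placeholders) > 0:
    print(f"rejected: placeholder dimensions {placeholders} are not supported by the caffe converter")

The same tuple-plus-placeholders return value is what lets convert_keras.py below simply discard the placeholder list with "input_shape, _ = Shape.parse(...)" while convert_caffe.py refuses it.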
29 changes: 19 additions & 10 deletions bin/convert_keras.py
@@ -3,16 +3,17 @@
 """
 
 import argparse
-import ast
 import os
-import sys
+import traceback
 from os import path
 
 import h5py
 
 from webdnn.backend.interface.generator import generate_descriptor
 from webdnn.graph.converters.keras import KerasConverter
 from webdnn.graph.graph import Graph
+from webdnn.graph.shape import Shape
+from webdnn.util import flags, console
 
 
 def main():
@@ -29,8 +30,9 @@ def main():
     parser.add_argument("--encoding", help="name of weight encoder")
     args = parser.parse_args()
 
-    sys.stderr.write("Generating feedforward graph\n")
-    input_shape = ast.literal_eval(args.input_shape)
+    console.stderr(f"[{path.basename(__file__)}] Generating feedforward graph")
+
+    input_shape, _ = Shape.parse(args.input_shape)
     input_shapes = [input_shape]
     model = h5py.File(args.kerasmodel, "r")
     converter = KerasConverter()
@@ -42,20 +44,27 @@ def main():
     output_dir = path.join(path.dirname(args.kerasmodel), "webdnn_graph_descriptor")
     os.makedirs(output_dir, exist_ok=True)
 
-    sys.stderr.write("Generating descriptors\n")
+    console.stderr(f"[{path.basename(__file__)}] Generating graph descriptor")
+
     any_backend_failed = False
-    last_backend_exception = None
-    for backend in args.backend.split(","):
+    backends = args.backend.split(",")
+    for i, backend in enumerate(backends):
+        console.stderr(f"[{path.basename(__file__)}] Backend: {console.colorize(backend, console.Color.Cyan)}")
         try:
             graph_exec_data = generate_descriptor(backend, graph, constant_encoder_name=args.encoding)
             graph_exec_data.save(output_dir)
         except Exception as ex:
+            if flags.DEBUG:
+                raise ex
+
             any_backend_failed = True
-            last_backend_exception = ex
-            sys.stderr.write(f"Failed generating descriptor for backend {backend}: {str(ex)}\n")
+            console.error(f"[{path.basename(__file__)}] Failed generating descriptor for {backend} backend")
+            console.stderr(traceback.format_exc())
+            continue
 
     if any_backend_failed:
-        raise last_backend_exception
+        exit(1)
+        # raise last_backend_exception
 
 
 if __name__ == "__main__":
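The reworked loop in convert_keras.py swaps the old "remember the last exception and re-raise it" logic for a per-backend report: in debug mode the exception propagates immediately, otherwise the traceback is printed and the remaining backends are still attempted, with a non-zero exit at the end if anything failed. A stripped-down sketch of that control flow, with a dummy generate_descriptor and a DEBUG constant standing in for webdnn's generator and flags.DEBUG / console helpers:

import sys
import traceback

DEBUG = False  # stand-in for webdnn.util.flags.DEBUG


def generate_descriptor(backend: str) -> str:
    """Dummy stand-in for webdnn's generate_descriptor; fails for one backend."""
    if backend == "webgpu":
        raise RuntimeError("WebGPU toolchain not available")
    return f"descriptor for {backend}"


def main(backend_arg: str = "webgpu,webassembly,fallback") -> None:
    any_backend_failed = False
    for backend in backend_arg.split(","):
        print(f"Backend: {backend}", file=sys.stderr)
        try:
            descriptor = generate_descriptor(backend)
            print(f"saved {descriptor}", file=sys.stderr)
        except Exception:
            if DEBUG:
                raise  # in debug mode, fail fast with the full stack trace

            any_backend_failed = True
            print(f"Failed generating descriptor for {backend} backend", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            continue

    if any_backend_failed:
        sys.exit(1)


if __name__ == "__main__":
    main()

Compared with re-raising only the last exception, this reports every failing backend and still lets the remaining ones produce their descriptors, which matches the commented-out "# raise last_backend_exception" left behind in the real script.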