From b0a7e13a84be17e2d14a866b3a0a5c3d4de542d8 Mon Sep 17 00:00:00 2001 From: Hiroshi ABE Date: Thu, 10 May 2018 12:32:29 +0900 Subject: [PATCH] fix IndexError when running main.py without data.mat --- main.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index cf1be03..a8b027d 100644 --- a/main.py +++ b/main.py @@ -74,6 +74,9 @@ def train_neural_network(model_size, file_name): data_file_name = 'data.mat' if os.path.isfile(data_file_name): data = sio.loadmat(data_file_name) + for name, value in data.items(): + if name[0]=='b' and value.ndim==2 and value.shape[0]==1: + data[name] = value.reshape(-1) else: data = train_neural_network(model_size=[784, 300, 10], file_name=data_file_name) @@ -104,7 +107,7 @@ def train_neural_network(model_size, file_name): elapsed_time = time.time() - start print('tensorflow-based execution time:', elapsed_time) refined_W[:, i] = w_tf[:-1] - refined_b[0, i] = w_tf[-1] + refined_b[i] = w_tf[-1] total_time += elapsed_time