From 5115d3407e99bc636ddcc500a4af594a7b117db0 Mon Sep 17 00:00:00 2001 From: LakshmiKumar23 Date: Mon, 17 May 2021 10:24:45 -0700 Subject: [PATCH 1/2] concat bug fix --- model_compiler/python/nnir.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/model_compiler/python/nnir.py b/model_compiler/python/nnir.py index 092ac5adeb..ab13b949b3 100644 --- a/model_compiler/python/nnir.py +++ b/model_compiler/python/nnir.py @@ -443,6 +443,8 @@ def updateLocals(self): shape = [input.shape[0], 0, input.shape[2], input.shape[3]] for name in node.inputs: lshape = self.tensor_shapes[name] + while(len(lshape) < 4): + lshape.append(1) if shape[0:1] + shape[2:] != lshape[0:1] + lshape[2:]: raise ValueError("concat: mismatch detected: " + node.inputs[0] + ":" + str(shape) + " " + name + ":" + str(lshape)) shape[1] = shape[1] + lshape[1] From 2fd81ce4f395a758c86ac379b9999a4dd67117c6 Mon Sep 17 00:00:00 2001 From: LakshmiKumar23 Date: Tue, 18 May 2021 22:31:43 -0700 Subject: [PATCH 2/2] concat fix - 05/18 --- model_compiler/python/nnir.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/model_compiler/python/nnir.py b/model_compiler/python/nnir.py index 665607692c..3910e46688 100644 --- a/model_compiler/python/nnir.py +++ b/model_compiler/python/nnir.py @@ -467,6 +467,8 @@ def updateLocals(self): shape = [0, input.shape[1], input.shape[2], input.shape[3]] for name in node.inputs: lshape = self.tensor_shapes[name] + while(len(lshape) < 4): + lshape.append(1) if shape[:0] + shape[1:] != lshape[:0] + lshape[1:]: raise ValueError("concat: mismatch detected: " + node.inputs[0] + ":" + str(shape) + " " + name + ":" + str(lshape)) shape[0] = shape[0] + lshape[0] @@ -483,6 +485,8 @@ def updateLocals(self): shape = [input.shape[0], input.shape[1], 0, input.shape[3]] for name in node.inputs: lshape = self.tensor_shapes[name] + while(len(lshape) < 4): + lshape.append(1) if shape[0:2] + shape[3:] != lshape[0:2] + lshape[3:]: raise ValueError("concat: mismatch detected: " + node.inputs[0] + ":" + str(shape) + " " + name + ":" + str(lshape)) shape[2] = shape[2] + lshape[2]