diff --git a/mlu_op_test.proto b/mlu_op_test.proto index 26b43cf..943c48a 100755 --- a/mlu_op_test.proto +++ b/mlu_op_test.proto @@ -320,6 +320,7 @@ message Node { optional SyncBatchNormBackwardParam sync_batchnorm_backward_reduce_param = 38324; // param optional StridedSliceParam strided_slice_param = 44; // param optional ConcatParam concat_param = 54; // param + optional OpTensorParam op_tensor_param = 101; // param } @@ -433,6 +434,12 @@ enum IndicesType{ BIT16_INDICES = 1; } +enum OpTensorDesc { + OP_TENSOR_ADD = 1; + OP_TENSOR_SUB = 2; + OP_TENSOR_MUL = 3; +} + message ReduceParam { optional int32 axis = 1 [default = 0]; required ReduceTensorOp mode = 2 [default = REDUCE_TENSOR_MUL]; @@ -971,3 +978,12 @@ message StridedSliceParam { message ConcatParam { required int32 axis = 1 [default = 0]; } + +// param to call mluOpOpTensor() +message OpTensorParam { + optional float alpha1 = 1 [default = 1]; + optional float alpha2 = 2 [default = 1]; + optional float beta = 3 [default = 1]; + optional OpTensorDesc op = 4 [default = OP_TENSOR_ADD]; + optional bool input_same_addr = 5 [default = false]; +}