Skip to content

Commit

Permalink
Enforce: 2 and 4 dims, remove information about out in format
Browse files Browse the repository at this point in the history
  • Loading branch information
mozga-intel committed Apr 3, 2018
1 parent 32f8ac7 commit 46e14bb
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 18 deletions.
6 changes: 3 additions & 3 deletions paddle/fluid/operators/fc_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,10 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto input = ctx.Input<Tensor>("Input");
auto w = ctx.Input<Tensor>("W");

PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 2,
PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4,
"Input must be with 2 or 4 dimensions, i.e. NCHW");
PADDLE_ENFORCE(w->dims().size() == 2,
"Weights must be with 2 dimensions, i.e. NC");
PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4,
"Weights must be with 2 or 4 dimensions, i.e. OI or OIHW");

bool with_bias = ctx.Attr<bool>("bias_attr");
MKLDNNMD<Tensor> md(input, w, with_bias);
Expand Down
21 changes: 6 additions & 15 deletions paddle/fluid/operators/fc_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fc_op.h"
#include <vector>

namespace paddle {
namespace operators {
Expand All @@ -29,11 +30,11 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
auto w_dims = ctx->GetInputDim("W");
std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});

PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 2,
PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
"Fully Connected input should be 2-D or 4-D tensor.");

PADDLE_ENFORCE(w_dims.size() == 2,
"Fully Connected input should be 2-D tensor.");
PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4,
"Fully Connected input should be 2-D or 4-D tensor.");

ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->ShareLoD("Input", "Out");
Expand Down Expand Up @@ -73,19 +74,9 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(

FCOpMaker::FCOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"Input",
"(Tensor) The input tensor of fully connected operator. "
"The format of input tensor is NCHW, where N is batch size, C is the "
"number of channels, H is the height of the feature, "
"and W is the width of the feature.");
AddInput("Input", "(Tensor) The input tensor of fully connected operator. ");
AddInput("W", "(Tensor), The second input tensor of fc op.");
AddOutput("Out",
"(Tensor) The output tensor of fully connected operator. "
"The format of output tensor is also NCHW, "
"where N is batch size, C is the number of channels, "
"H is the height of the feature, "
"and W is the width of the feature.");
AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
Expand Down

0 comments on commit 46e14bb

Please sign in to comment.