Skip to content

Commit

Permalink
Avoid backward primitive descriptor (BWD pd) creation during inference
Browse files Browse the repository at this point in the history
  • Loading branch information
jakpiase committed Jul 1, 2021
1 parent 62f2da9 commit 5eb7bd0
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
9 changes: 5 additions & 4 deletions paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class PReluMKLDNNHandler
PReluMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine, platform::Place cpu_place,
const Tensor* x, const Tensor* weights,
const std::string& uniq_name)
const std::string& uniq_name, bool is_test)
: platform::MKLDNNHandlerT<T, dnnl::prelu_forward, dnnl::prelu_backward>(
dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
Expand All @@ -46,8 +46,9 @@ class PReluMKLDNNHandler

this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
x_md, weights_md);
this->AcquireBackwardPrimitiveDescriptor(x_md, weights_md, x_md,
weights_md);
if (!is_test)
this->AcquireBackwardPrimitiveDescriptor(x_md, weights_md, x_md,
weights_md);
}
}

Expand Down Expand Up @@ -87,7 +88,7 @@ class PReluMKLDNNKernel : public framework::OpKernel<T> {
const bool is_test = ctx.Attr<bool>("is_test");

PReluMKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(), x,
alpha, ctx.InputName("X"));
alpha, ctx.InputName("X"), is_test);

auto src_memory_p = handler.AcquireSrcMemory(x);
auto weights_memory_p =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ def test_check_grad(self):
self.check_grad_with_place(
core.CPUPlace(), ["X", "Alpha"],
"Out",
check_dygraph=False,
user_defined_grads=[self.dx, self.dalpha],
user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])

Expand Down
3 changes: 3 additions & 0 deletions python/paddle/fluid/tests/unittests/op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,9 @@ def check_grad_with_place(self,
op_outputs = self.outputs if hasattr(self, "outputs") else dict()
op_attrs = self.attrs if hasattr(self, "attrs") else dict()

if self.is_bfloat16_op():
check_dygraph = False

self._check_grad_helper()
if self.dtype == np.float64 and \
self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST:
Expand Down

0 comments on commit 5eb7bd0

Please sign in to comment.