Skip to content

Commit 9dadbb2

Browse files
Jingyue Wu, tensorflower-gardener
authored and committed
Take #2: Improve Conv2DBackpropInput to take input_sizes as a 2D shape.
When input_sizes is a 2D shape, the input batch size comes from output_grad and the input channel size comes from the filter. With this change, input_sizes is more likely to be a constant (e.g. even when the batch size is variable) so tf2tensorrt is able to convert more Conv2DBackpropInput to IDeconvolutionLayer. Changes to tf2tensorrt will come in separate CLs. I haven't made tf2xla support input_sizes being a 2D shape. It would error out for now. So we disabled the test added to conv_ops_test.py for XLA. PiperOrigin-RevId: 303217218 Change-Id: I283106657c00f49be41a74c7131bf8be787742a8
1 parent 55d96a7 commit 9dadbb2

File tree

9 files changed

+167
-23
lines changed

9 files changed

+167
-23
lines changed

tensorflow/compiler/tf2xla/kernels/conv_ops.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ class ConvBackpropInputOp : public XlaOpKernel {
107107
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape));
108108
xla::Shape input_shape =
109109
TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape);
110+
OP_REQUIRES(ctx, input_shape.rank() == attrs_.num_spatial_dims + 2,
111+
errors::InvalidArgument(
112+
"The rank of the specified input shape must be "
113+
"num_spatial_dims + 2. Expected ",
114+
attrs_.num_spatial_dims + 2, " got ", input_shape.rank()));
110115

111116
xla::StatusOr<xla::XlaOp> in_backprop =
112117
MakeXlaBackpropInputConvOp(ctx->op_kernel().type_string(), input_shape,

tensorflow/core/framework/common_shape_fns.cc

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,78 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
822822
return Status::OK();
823823
}
824824

825+
// Shape function for Conv2DBackpropInput.
//
// Inputs, by index:
//   0: input_sizes  - shape tensor with either 4 values (the full input
//                     shape) or 2 values (only the two spatial sizes).
//   1: filter       - must be rank 4.
//   2: out_backprop - must be rank 4.
//
// Output 0 (in_backprop) is always rank 4. Its batch dimension comes from
// out_backprop (or from input_sizes when only the latter knows it), its
// spatial dimensions come from input_sizes, and its depth comes from
// input_sizes when 4 values are given or from dimension 2 of the filter when
// 2 values are given.
//
// Returns InvalidArgument when data_format cannot be parsed or when
// input_sizes describes a shape whose rank is neither 4 nor 2; propagates any
// rank/merge error reported by the inference context.
Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) {
  // Fall back to NHWC when the attribute is not present on the node.
  string data_format_str;
  if (!c->GetAttr("data_format", &data_format_str).ok()) {
    data_format_str = "NHWC";
  }
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }

  // For the rest of this function, output_grad_* describes out_backprop and
  // input_grad_* describes in_backprop.
  ShapeHandle output_grad_shape = c->input(2);
  TF_RETURN_IF_ERROR(c->WithRank(output_grad_shape, 4, &output_grad_shape));
  ShapeHandle filter_shape = c->input(1);
  TF_RETURN_IF_ERROR(c->WithRank(filter_shape, 4, &filter_shape));

  DimensionHandle batch_size_dim;
  DimensionHandle output_grad_depth_dim;
  gtl::InlinedVector<DimensionHandle, 2> output_grad_spatial_dims(2);
  TF_RETURN_IF_ERROR(DimensionsFromShape(
      output_grad_shape, data_format, &batch_size_dim,
      absl::MakeSpan(output_grad_spatial_dims), &output_grad_depth_dim, c));
  // The depth of out_backprop must agree with the last filter dimension; the
  // merged value itself is not needed.
  DimensionHandle unused;
  TF_RETURN_IF_ERROR(
      c->Merge(output_grad_depth_dim, c->Dim(filter_shape, 3), &unused));

  ShapeHandle specified_input_grad_shape;
  TF_RETURN_IF_ERROR(
      c->MakeShapeFromShapeTensor(0, &specified_input_grad_shape));
  // When nothing is known about input_sizes, treat it as a 4D shape with
  // unknown dimensions.
  if (c->Rank(specified_input_grad_shape) == InferenceContext::kUnknownRank) {
    TF_RETURN_IF_ERROR(c->WithRank(specified_input_grad_shape, 4,
                                   &specified_input_grad_shape));
  }

  // input_grad_depth_dim doesn't equal c->Dim(filter_shape,2) when the number
  // of groups is larger than 1. If input_sizes is a 4D shape, we collect
  // input_grad_depth_dim from input_sizes; otherwise we compute it as
  // c->Dim(filter_shape,2).
  DimensionHandle input_grad_depth_dim;
  gtl::InlinedVector<DimensionHandle, 2> specified_input_grad_spatial_dims(2);
  int specified_input_grad_rank = c->Rank(specified_input_grad_shape);
  if (specified_input_grad_rank == 4) {
    DimensionHandle specified_batch_size_dim;
    TF_RETURN_IF_ERROR(DimensionsFromShape(
        specified_input_grad_shape, data_format, &specified_batch_size_dim,
        absl::MakeSpan(specified_input_grad_spatial_dims),
        &input_grad_depth_dim, c));
    // The two batch sizes must be compatible.
    TF_RETURN_IF_ERROR(
        c->Merge(specified_batch_size_dim, batch_size_dim, &unused));
    // If out_backprop does not know its batch size but input_sizes does, keep
    // the known value instead of reporting an unknown batch dimension.
    if (!c->ValueKnown(batch_size_dim) &&
        c->ValueKnown(specified_batch_size_dim)) {
      batch_size_dim = specified_batch_size_dim;
    }
  } else if (specified_input_grad_rank == 2) {
    specified_input_grad_spatial_dims[0] =
        c->Dim(specified_input_grad_shape, 0);
    specified_input_grad_spatial_dims[1] =
        c->Dim(specified_input_grad_shape, 1);
    input_grad_depth_dim = c->Dim(filter_shape, 2);
  } else {
    return errors::InvalidArgument(
        "Conv2DBackpropInput requires input_sizes to contain 4 values or 2 "
        "values, but got: ",
        specified_input_grad_rank);
  }

  ShapeHandle input_grad_shape;
  TF_RETURN_IF_ERROR(ShapeFromDimensions(
      batch_size_dim, specified_input_grad_spatial_dims, input_grad_depth_dim,
      data_format, c, &input_grad_shape));
  c->set_output(0, input_grad_shape);
  return Status::OK();
}
896+
825897
namespace {
826898

827899
Status DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext* c,

tensorflow/core/framework/common_shape_fns.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ Status DepthwiseConv2DNativeShapeWithExplicitPadding(
138138
// explicit padding.
139139
Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);
140140

141+
// Shape function for Conv2DBackpropInput.
// Accepts an input_sizes shape tensor holding either 4 values (full input
// shape) or 2 values (spatial sizes only); the output is always rank 4.
142+
Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c);
143+
141144
// Shape function for AvgPool-like operations.
142145
Status AvgPoolShape(shape_inference::InferenceContext* c);
143146

tensorflow/core/kernels/conv_grad_input_ops.cc

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -450,14 +450,12 @@ class Conv2DBackpropInputOp : public OpKernel {
450450
const Tensor& input_sizes = context->input(0);
451451
const Tensor& filter = context->input(1);
452452
const Tensor& out_backprop = context->input(2);
453-
OP_REQUIRES(
454-
context, TensorShapeUtils::IsVector(input_sizes.shape()),
455-
errors::InvalidArgument(
456-
"Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
457-
input_sizes.dims()));
453+
458454
TensorShape input_shape;
459-
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
460-
input_sizes.vec<int32>(), &input_shape));
455+
OP_REQUIRES_OK(context,
456+
Conv2DBackpropComputeInputShape(input_sizes, filter.shape(),
457+
out_backprop.shape(),
458+
data_format_, &input_shape));
461459

462460
Tensor* in_backprop = nullptr;
463461
OP_REQUIRES_OK(context,
@@ -549,14 +547,12 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
549547
const Tensor& input_sizes = context->input(0);
550548
const Tensor& filter = context->input(1);
551549
const Tensor& out_backprop = context->input(2);
552-
OP_REQUIRES(
553-
context, TensorShapeUtils::IsVector(input_sizes.shape()),
554-
errors::InvalidArgument(
555-
"Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
556-
input_sizes.dims()));
550+
557551
TensorShape input_shape;
558-
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
559-
input_sizes.vec<int32>(), &input_shape));
552+
OP_REQUIRES_OK(context,
553+
Conv2DBackpropComputeInputShape(input_sizes, filter.shape(),
554+
out_backprop.shape(),
555+
data_format_, &input_shape));
560556

561557
ConvBackpropDimensions dims;
562558
OP_REQUIRES_OK(context,

tensorflow/core/kernels/conv_grad_shape_utils.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,35 @@ Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims,
166166
dims);
167167
}
168168

169+
// Computes the shape of in_backprop for Conv2DBackpropInput.
//
// `input_sizes` must be a 1-D int32 tensor containing either:
//   * 4 values - used verbatim as the dimensions of `*input_shape`; or
//   * 2 values - the spatial sizes (height, width). The batch size is then
//     read from the 'N' dimension of `out_backprop_shape` and the depth from
//     dimension 2 of `filter_shape`, and the result is laid out according to
//     `data_format`.
//
// Returns InvalidArgument when `input_sizes` is not a vector, holds a number
// of values other than 4 or 2, or (in the 2-value form) when the filter or
// out_backprop shape is not 4-dimensional or a spatial size is negative.
Status Conv2DBackpropComputeInputShape(const Tensor& input_sizes,
                                       const TensorShape& filter_shape,
                                       const TensorShape& out_backprop_shape,
                                       const TensorFormat& data_format,
                                       TensorShape* input_shape) {
  if (!TensorShapeUtils::IsVector(input_sizes.shape())) {
    return errors::InvalidArgument(
        "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
        input_sizes.dims());
  }

  if (input_sizes.dim_size(0) == 4) {
    return TensorShapeUtils::MakeShape(input_sizes.vec<int32>(), input_shape);
  }

  if (input_sizes.dim_size(0) == 2) {
    // The batch and depth are read by index from the other two shapes, so
    // reject malformed ranks here rather than hitting an out-of-range check
    // inside TensorShape.
    if (filter_shape.dims() != 4) {
      return errors::InvalidArgument(
          "Conv2DBackpropInput: filter must be 4-dimensional, not ",
          filter_shape.dims());
    }
    if (out_backprop_shape.dims() != 4) {
      return errors::InvalidArgument(
          "Conv2DBackpropInput: out_backprop must be 4-dimensional, not ",
          out_backprop_shape.dims());
    }

    const auto spatial_sizes = input_sizes.vec<int32>();
    const int64 output_height = spatial_sizes(0);
    const int64 output_width = spatial_sizes(1);
    // Negative sizes cannot form a valid TensorShape; report them as a
    // Status instead of letting shape construction abort.
    if (output_height < 0 || output_width < 0) {
      return errors::InvalidArgument(
          "Conv2DBackpropInput: elements of input_sizes must be >= 0, not ",
          output_height, "x", output_width);
    }

    // Keep the full 64-bit dimension values; narrowing to int could truncate.
    const int64 batch_size =
        GetTensorDim(out_backprop_shape, data_format, 'N');
    const int64 output_depth = filter_shape.dim_size(2);
    *input_shape = ShapeFromFormat(data_format, batch_size, output_height,
                                   output_width, output_depth);
    return Status::OK();
  }

  return errors::InvalidArgument(
      "Conv2DBackpropInput requires input_sizes to "
      "contain 4 values or 2 values, but got: ",
      input_sizes.dim_size(0));
}
199+
169200
} // namespace tensorflow

tensorflow/core/kernels/conv_grad_shape_utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,13 @@ Status ConvBackpropComputeDimensionsV2(
8383
const gtl::ArraySlice<int32>& dilations, const std::vector<int32>& strides,
8484
Padding padding, absl::Span<const int64> explicit_paddings,
8585
TensorFormat data_format, ConvBackpropDimensions* dims);
86+
87+
// Computes the shape of the in_backprop.
// `input_sizes` must be a 1-D tensor with either 4 values (used directly as
// the shape) or 2 values (height and width; the batch size is then taken from
// `out_backprop_shape` and the depth from dimension 2 of `filter_shape`).
// Returns InvalidArgument for any other form of `input_sizes`.
88+
Status Conv2DBackpropComputeInputShape(const Tensor& input_sizes,
89+
const TensorShape& filter_shape,
90+
const TensorShape& out_backprop_shape,
91+
const TensorFormat& data_format,
92+
TensorShape* input_shape);
8693
} // namespace tensorflow
8794

8895
#endif // TENSORFLOW_CORE_KERNELS_CONV_GRAD_SHAPE_UTILS_H_

tensorflow/core/ops/nn_ops.cc

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -357,13 +357,7 @@ REGISTER_OP("Conv2DBackpropInput")
357357
.Attr(GetExplicitPaddingsAttrString())
358358
.Attr(GetConvnetDataFormatAttrString())
359359
.Attr("dilations: list(int) = [1, 1, 1, 1]")
360-
.SetShapeFn([](InferenceContext* c) {
361-
ShapeHandle s;
362-
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
363-
TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
364-
c->set_output(0, s);
365-
return Status::OK();
366-
});
360+
.SetShapeFn(shape_inference::Conv2DBackpropInputShape);
367361

368362
// TODO(jeff): Instead of 'use_cudnn_for_gpu', maybe we should have a
369363
// more general string attribute ('kernel_impl'?) that can be used to

tensorflow/core/ops/nn_ops_test.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,25 @@ TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) {
320320
"[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[0];[0]");
321321
}
322322

323+
// Exercises the shape function registered for Conv2DBackpropInput. In each
// case the three ';'-separated input shapes are, in order, input_sizes, filter
// and out_backprop.
TEST(NNOpsTest, Conv2DBackpropInput_ShapeFn) {
324+
ShapeInferenceTestOp op("Conv2DBackpropInput");
325+
326+
// Test rank error.
// An input_sizes vector with 3 elements is neither the 4-value nor the
// 2-value form.
327+
INFER_ERROR("input_sizes to contain 4 values or 2 values", op,
328+
"[3];[?,?,?,?];[?,?,?,?]");
// out_backprop must be rank 4.
329+
INFER_ERROR("Shape must be rank 4 but is rank 3", op,
330+
"[4];[?,?,?,?];[?,?,?]");
331+
332+
// When input_sizes is a 4D shape and the convolution is grouped, the channel
333+
// size of the input grad doesn't always equal the input channel size of the
334+
// filter. So, when input_sizes is a 4D shape, the channel size of the input
335+
// grad is determined by the content of input_sizes.
336+
INFER_OK(op, "[4];[?,?,2,?];[1,?,?,?]", "[d2_0,?,?,?]");
337+
// When input_sizes is a 2D shape, the channel size of the input grad always
338+
// matches the filter shape.
339+
INFER_OK(op, "[2];[?,?,2,?];[1,?,?,?]", "[d2_0,?,?,d1_2]");
340+
}
341+
323342
TEST(NNOpsTest, Conv3DBackpropInput_ShapeFn) {
324343
ShapeInferenceTestOp op("Conv3DBackpropInput");
325344

tensorflow/python/kernel_tests/conv_ops_test.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -836,8 +836,9 @@ def _RunAndVerifyBackpropInput(self,
836836
x2 = self._CreateNumpyTensor(output_sizes)
837837
dilations = list(dilations)
838838
with test_util.device(use_gpu):
839-
if data_format == "NCHW":
840-
input_sizes = test_util.NHWCToNCHW(input_sizes)
839+
if len(input_sizes) == 4:
840+
if data_format == "NCHW":
841+
input_sizes = test_util.NHWCToNCHW(input_sizes)
841842
t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)])
842843
t1 = constant_op.constant(x1, shape=filter_sizes)
843844
t2 = constant_op.constant(x2, shape=output_sizes)
@@ -1007,6 +1008,22 @@ def testConv2DKernelSizeMatchesInputSizeBackpropInput(self):
10071008
use_gpu=use_gpu,
10081009
err=1e-5)
10091010

1011+
@test_util.run_in_graph_and_eager_modes
1012+
@test_util.disable_xla("XLA requires input_sizes to be a 4D shape.")
1013+
def testConv2DInputSizesContainsOnlySpatialDimensionsBackpropInput(self):
# Runs the backprop-input check with input_sizes given as only two values
# (the spatial sizes) while filter_sizes and output_sizes stay 4-D, for every
# (data_format, use_gpu) pair returned by GetTestConfigs().
1014+
expected_output = [5.0, 11.0, 17.0, 23.0]
1015+
for (data_format, use_gpu) in GetTestConfigs():
1016+
self._RunAndVerifyBackpropInput(
1017+
input_sizes=[2, 2],
1018+
filter_sizes=[2, 2, 1, 2],
1019+
output_sizes=[1, 1, 1, 2],
1020+
strides=[1, 1],
1021+
padding="VALID",
1022+
expected=expected_output,
1023+
data_format=data_format,
1024+
use_gpu=use_gpu,
1025+
err=1e-5)
1026+
10101027
# Testing for backprops
10111028
def _RunAndVerifyBackpropFilter(self,
10121029
input_sizes,

0 commit comments

Comments
 (0)