Skip to content

Commit a962568

Browse files
authored
Merge pull request #2 from tensorflow/master
Merge tensorflow contributions
2 parents 753bd31 + 6783e38 commit a962568

File tree

24 files changed

+912
-425
lines changed

24 files changed

+912
-425
lines changed

tensorflow/compiler/mlir/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ cc_library(
2929
srcs = ["op_or_arg_name_mapper.cc"],
3030
hdrs = ["op_or_arg_name_mapper.h"],
3131
deps = [
32+
"@com_google_absl//absl/strings",
3233
"@llvm//:support",
3334
"@local_config_mlir//:IR",
3435
],

tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc

Lines changed: 30 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,14 @@ static Type GetQuantizedType(Builder builder, Type input_type,
4545
quant::ExpressedToQuantizedConverter::forInputType(input_type);
4646

4747
quant::QuantizedType quantizedEleType;
48-
if (min.size() == 1 && max.size() == 1) {
48+
if (min.size() == 1 && max.size() == 1 && quant_dim == -1) {
4949
quantizedEleType = quant::fakeQuantAttrsToType(
5050
builder.getUnknownLoc(), storage_type_width, min[0], max[0],
5151
narrow_range, converter.expressedType, is_signed);
5252
} else if (min.size() == max.size()) {
5353
auto shape = input_type.dyn_cast<ShapedType>();
54-
if (!shape || min.size() != shape.getDimSize(quant_dim)) {
54+
if (!shape || shape.getRank() <= quant_dim ||
55+
min.size() != shape.getDimSize(quant_dim)) {
5556
return {};
5657
}
5758
// TODO(b/141508873): the quantization dim is set to the last dimension.
@@ -92,33 +93,39 @@ TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
9293
Type final_type =
9394
GetQuantizedType(builder, input_type, min_value, max_value, quant_dim,
9495
num_bits.getInt(), narrow_range.getValue(), is_signed);
96+
if (!final_type) return {};
9597
return TypeAttr::get(final_type);
9698
}
9799

98-
// TODO(fengliuai): expose the `quant_dim` argument.
99-
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
100-
Attribute max, IntegerAttr num_bits,
101-
BoolAttr narrow_range, bool is_signed) {
102-
// When input_type isn't a ranked shaped type, it shouldn't be per-axis
103-
// quantizatied, and `quant_dim` shouldn't be used, otherwise, set it to the
104-
// last dimension.
105-
int quant_dim = 0;
106-
if (auto shape = input_type.dyn_cast<RankedTensorType>()) {
107-
quant_dim = shape.getRank() - 1;
100+
// Repeats the content of `data` multiple times to resize to `target_size`.
101+
// Note that this only broadcasts across one dimension.
102+
template <typename T>
103+
static bool BroadcastVector(int target_size, SmallVectorImpl<T>& data) {
104+
int size = data.size();
105+
if (size != target_size) {
106+
if (target_size % size != 0) return true;
107+
data.reserve(target_size);
108+
for (int i = 1, e = target_size / size; i != e; ++i) {
109+
data.insert(data.end(), data.begin(), data.begin() + size);
110+
}
108111
}
109-
return GetQuantizedTypeAttr(builder, input_type, min, max, quant_dim,
110-
num_bits, narrow_range, is_signed);
112+
return false;
111113
}
112114

113115
// Changes the axis of the input per-channel quantized type to match the
114116
// dimension of the target type. Returns nullptr if it fails.
115117
static quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast(
116-
quant::UniformQuantizedPerAxisType qtype, Type target, int axis) {
117-
auto shaped = target.dyn_cast<ShapedType>();
118+
quant::UniformQuantizedPerAxisType qtype, Type target, int quant_dim) {
119+
auto shaped = target.dyn_cast<RankedTensorType>();
118120
if (!shaped) return {};
119121

120-
// Broadcast the scales and zero points to match the length of the axis-th
121-
// dimension of the target type. Currently, it covers two cases:
122+
SmallVector<double, 4> scales(qtype.getScales().begin(),
123+
qtype.getScales().end());
124+
SmallVector<int64_t, 4> zero_points(qtype.getZeroPoints().begin(),
125+
qtype.getZeroPoints().end());
126+
// Broadcast the scales and zero points to match the target size, which is
127+
// usually the axis-th dimension of the target type. Currently, it covers two
128+
// cases:
122129
// - for Transpose, the data layout is changed so the `dim[axis]` still equals
123130
// to the `scales_size`. The broadcast skips;
124131
// - for Reshape, the data layout isn't changed but the innermost dimension is
@@ -127,33 +134,13 @@ static quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast(
127134
//
128135
// TODO(b/141709944): after the fix, the `scales` can be for dim[2], thus we
129136
// have to repeat each elements in the `scales` locally dim[3] times.
130-
auto scales = qtype.getScales();
131-
auto zero_points = qtype.getZeroPoints();
132-
int target_size = shaped.getDimSize(axis);
133-
int scales_size = scales.size();
134-
int zero_points_size = zero_points.size();
135-
136-
SmallVector<double, 4> new_scales;
137-
SmallVector<int64_t, 4> new_zero_points;
138-
if (scales_size != target_size) {
139-
if (target_size % scales_size != 0) return {};
140-
for (int i = 0, e = target_size / scales_size; i != e; ++i) {
141-
new_scales.insert(new_scales.end(), scales.begin(), scales.end());
142-
}
143-
scales = new_scales;
137+
if (BroadcastVector<double>(shaped.getDimSize(quant_dim), scales) ||
138+
BroadcastVector<int64_t>(shaped.getDimSize(quant_dim), zero_points)) {
139+
return {};
144140
}
145-
if (zero_points_size != target_size) {
146-
if (target_size % zero_points_size != 0) return {};
147-
for (int i = 0, e = target_size / zero_points_size; i != e; ++i) {
148-
new_zero_points.insert(new_zero_points.end(), zero_points.begin(),
149-
zero_points.end());
150-
}
151-
zero_points = new_zero_points;
152-
}
153-
154141
return quant::UniformQuantizedPerAxisType::get(
155142
qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
156-
scales, zero_points, axis, qtype.getStorageTypeMin(),
143+
scales, zero_points, quant_dim, qtype.getStorageTypeMin(),
157144
qtype.getStorageTypeMax());
158145
}
159146

@@ -208,7 +195,7 @@ Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, unsigned num_bits,
208195
}
209196
}
210197
auto type =
211-
GetQuantizedType(builder, attr.getType(), min, max, /*quant_dim=*/0,
198+
GetQuantizedType(builder, attr.getType(), min, max, /*quant_dim=*/-1,
212199
num_bits, narrow_range, is_signed);
213200
if (auto ele_type = type.dyn_cast_or_null<TensorType>())
214201
return ele_type.getElementType();

tensorflow/compiler/mlir/lite/quantization/quantization_utils.h

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -307,22 +307,16 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
307307
// `narrow_range` is set to true for weights and `is_signed` is set to true
308308
// if it is using signed int symmetric quantization.
309309
//
310-
// Note that this method doesn't modify min and max, so they needs to be
311-
// adjusted before calling this method if symmetric quantized type needs to be
312-
// returned.
310+
// Note that this method may broadcast min and max to match the dimension length
311+
// of `input_type`, if the `quant_dim` is valid. On the other hand, the
312+
// symmetry of min and max is not adjusted by this method. The QAT workflow
313+
// should set min/max correctly (and use `narrow_range`=true, `is_signed`=true)
314+
// if symmetric quantization is required.
313315
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
314316
Attribute max, int quant_dim,
315317
IntegerAttr num_bits, BoolAttr narrow_range,
316318
bool is_signed);
317319

318-
// Same above, but the `channel_dim` is hardcoded to the last dimension to match
319-
// the behavior of tf.FakeQuantWithMinMaxVarsPerChannel. This method is called
320-
// when converting tf.FakeQuant* ops to MLIR's quant parameter representation,
321-
// aka. quant::QuantType.
322-
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
323-
Attribute max, IntegerAttr num_bits,
324-
BoolAttr narrow_range, bool is_signed);
325-
326320
// Casts the `target` type to a quantized type by using the quantization
327321
// parameters from the type in the `source` type attribute.
328322
// Examples:

tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ glob_lit_tests(
88
":test_utilities",
99
],
1010
driver = "@local_config_mlir//:run_lit.sh",
11-
test_file_exts = ["pbtxt"],
11+
test_file_exts = [
12+
"pbtxt",
13+
"py",
14+
],
1215
)
1316

1417
# Bundle together all the debug info files that are used by the tests.
@@ -24,8 +27,30 @@ filegroup(
2427
name = "test_utilities",
2528
testonly = True,
2629
data = [
30+
":concrete_function_error",
31+
":saved_model_error",
2732
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
2833
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
2934
"@llvm//:FileCheck",
3035
],
3136
)
37+
38+
py_binary(
39+
name = "saved_model_error",
40+
srcs = ["saved_model_error.py"],
41+
main = "saved_model_error.py",
42+
python_version = "PY3",
43+
deps = [
44+
"//tensorflow:tensorflow_py",
45+
],
46+
)
47+
48+
py_binary(
49+
name = "concrete_function_error",
50+
srcs = ["concrete_function_error.py"],
51+
main = "concrete_function_error.py",
52+
python_version = "PY3",
53+
deps = [
54+
"//tensorflow:tensorflow_py",
55+
],
56+
)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Test file to display the error message and verify it with FileCheck."""
16+
17+
# RUN: %p/concrete_function_error | FileCheck %s
18+
19+
from __future__ import absolute_import
20+
from __future__ import division
21+
from __future__ import print_function
22+
23+
import sys
24+
from absl import app
25+
26+
from tensorflow import enable_v2_behavior
27+
import tensorflow.compat.v2 as tf
28+
29+
enable_v2_behavior()
30+
31+
32+
class TestGraphDebugInfo(object):
33+
"""Test stack trace can be displayed."""
34+
35+
def testConcreteFunctionDebugInfo(self):
36+
"""Create a concrete func with unsupported ops, and convert it."""
37+
@tf.function(
38+
input_signature=[tf.TensorSpec(shape=[3, 3], dtype=tf.float32)])
39+
def model(x):
40+
y = tf.math.reciprocal(x) # Not supported
41+
return y + y
42+
43+
func = model.get_concrete_function()
44+
converter = tf.lite.TFLiteConverter.from_concrete_functions([func])
45+
converter.experimental_new_converter = True
46+
converter.convert()
47+
48+
# pylint: disable=line-too-long
49+
50+
# CHECK-LABEL: testConcreteFunctionDebugInfo
51+
# CHECK: error: 'tf.Reciprocal' op is neither a custom op nor a flex op
52+
# CHECK: attrs=attr_protos, op_def=op_def)
53+
# CHECK: ^
54+
# CHECK: {{.*tensorflow/python/ops/gen_math_ops.py:[0-9]+:[0-9]+: note: called from}}
55+
# CHECK: "Reciprocal", x=x, name=name)
56+
# CHECK: ^
57+
# CHECK: {{.*tensorflow/compiler/mlir/lite/tests/debuginfo/concrete_function_error.py:[0-9]+:[0-9]+: note: called from}}
58+
# CHECK: y = tf.math.reciprocal(x) # Not supported
59+
# CHECK: ^
60+
# CHECK: <unknown>:0: error: failed while converting: 'main'
61+
62+
# pylint: enable=line-too-long
63+
64+
65+
def main(argv):
66+
if len(argv) > 1:
67+
raise app.UsageError('Too many command-line arguments.')
68+
69+
try:
70+
TestGraphDebugInfo().testConcreteFunctionDebugInfo()
71+
except Exception as e: # pylint: disable=broad-except
72+
sys.stdout.write('testConcreteFunctionDebugInfo')
73+
sys.stdout.write(str(e))
74+
75+
76+
if __name__ == '__main__':
77+
app.run(main)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Test file to display the error message and verify it with FileCheck."""
16+
17+
# RUN: %p/saved_model_error | FileCheck %s
18+
19+
from __future__ import absolute_import
20+
from __future__ import division
21+
from __future__ import print_function
22+
23+
import sys
24+
from absl import app
25+
26+
from tensorflow import enable_v2_behavior
27+
import tensorflow.compat.v2 as tf
28+
29+
enable_v2_behavior()
30+
31+
32+
class TestModule(tf.Module):
33+
"""The test model has an unsupported op."""
34+
35+
@tf.function(input_signature=[tf.TensorSpec(shape=[3, 3], dtype=tf.float32)])
36+
def model(self, x):
37+
y = tf.math.reciprocal(x) # Not supported
38+
return y + y
39+
40+
41+
class TestGraphDebugInfo(object):
42+
"""Test stack trace can be displayed."""
43+
44+
def testSavedModelDebugInfo(self):
45+
"""Save a saved model with unsupported ops, and then load and convert it."""
46+
# save the model
47+
test_model = TestModule()
48+
saved_model_path = '/tmp/test.saved_model'
49+
save_options = tf.saved_model.SaveOptions(save_debug_info=True)
50+
tf.saved_model.save(test_model, saved_model_path, options=save_options)
51+
52+
# load the model and convert
53+
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
54+
converter.experimental_new_converter = True
55+
converter.convert()
56+
57+
# pylint: disable=line-too-long
58+
59+
# CHECK-LABEL: testSavedModelDebugInfo
60+
# CHECK: error: 'tf.Reciprocal' op is neither a custom op nor a flex op
61+
# CHECK: attrs=attr_protos, op_def=op_def)
62+
# CHECK: ^
63+
# CHECK: {{.*tensorflow/python/ops/gen_math_ops.py:[0-9]+:[0-9]+: note: called from}}
64+
# CHECK: "Reciprocal", x=x, name=name)
65+
# CHECK: ^
66+
# CHECK: {{.*tensorflow/compiler/mlir/lite/tests/debuginfo/saved_model_error.py:[0-9]+:[0-9]+: note: called from}}
67+
# CHECK: y = tf.math.reciprocal(x) # Not supported
68+
# CHECK: ^
69+
# CHECK: <unknown>:0: error: failed while converting: 'main'
70+
71+
# pylint: enable=line-too-long
72+
73+
74+
def main(argv):
75+
"""test driver method writes the error message to stdout."""
76+
if len(argv) > 1:
77+
raise app.UsageError('Too many command-line arguments.')
78+
79+
try:
80+
TestGraphDebugInfo().testSavedModelDebugInfo()
81+
except Exception as e: # pylint: disable=broad-except
82+
sys.stdout.write('testSavedModelDebugInfo')
83+
sys.stdout.write(str(e))
84+
85+
86+
if __name__ == '__main__':
87+
app.run(main)

tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,8 @@ func @perChannelFakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x3
296296

297297
// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<16xf32>
298298
// CHECK: %[[CONSTANT0:.*]] = constant dense<0.000000e+00> : tensor<16x3x3x3xf32>
299-
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<16x3x3x3x!quant.uniform<u8:f32:0, {1.000000e+00
299+
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<16x3x3x3x!quant.uniform<u8:f32:0,
300+
// CHECK-SAME: {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00}>>
300301
// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]])
301302
// CHECK: %[[CONV:.*]] = "tfl.conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]])
302303
// CHECK: return %[[CONV]] : tensor<256x30x30x16xf32>
@@ -336,7 +337,10 @@ func @perChannelFakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor
336337

337338
// CHECK: %[[CONSTANT:.*]] = constant dense<0.000000e+00> : tensor<48xf32>
338339
// CHECK: %[[CONSTANT0:.*]] = constant dense<0.000000e+00> : tensor<1x3x3x48xf32>
339-
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<1x3x3x48x!quant.uniform<u8:f32:3, {1.000000e+00
340+
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CONSTANT0]]) {qtype = tensor<1x3x3x48x!quant.uniform<u8:f32:3,
341+
// CHECK-SAME: {1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,
342+
// CHECK-SAME: 1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,
343+
// CHECK-SAME: 1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00,1.000000e+00}>>}
340344
// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]])
341345
// CHECK: %[[CONV:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[DEQUANTIZE]], %[[CONSTANT]])
342346
// CHECK: return %[[CONV]]

tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
3939
// Use the tensor type information from $0 and convert min $1, max $2 and
4040
// numBits $3 and narrowRange $4 to a QuantizedType.
4141
def ConvertToQuantTypeFromAttrs : NativeCodeCall<
42-
"GetQuantizedTypeAttr($_builder, $0->getType(), $1, $2, $3, $4, /*is_signed=*/false)">;
42+
"GetQuantizedTypeAttr($_builder, $0->getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">;
4343

4444
// Converts an integer attribute $0 to 32-bit with builder.
4545
def convertIntAttrTo32Bit : NativeCodeCall<

0 commit comments

Comments
 (0)