
Commit 83c5225

liufengdb authored and tensorflower-gardener committed
Add the tf.FakeQuantWithMinMaxVarsPerChannel op
PiperOrigin-RevId: 268082252
1 parent 2417ede · commit 83c5225

3 files changed: 110 additions, 0 deletions

tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td

Lines changed: 46 additions & 0 deletions
@@ -1120,6 +1120,52 @@ values.
   }];
 }
 
+def TF_FakeQuantWithMinMaxVarsPerChannelOp : TF_Op<"FakeQuantWithMinMaxVarsPerChannel", [NoSideEffect]> {
+  let summary = [{
+Fake-quantize the 'inputs' tensor of type float and one of the shapes: `[d]`,
+  }];
+
+  let description = [{
+`[b, d]` `[b, h, w, d]` via per-channel floats `min` and `max` of shape `[d]`
+to 'outputs' tensor of same shape as `inputs`.
+
+`[min; max]` define the clamping range for the `inputs` data.
+`inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]`
+when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and
+then de-quantized and output as floats in `[min; max]` interval.
+`num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive.
+
+Before quantization, `min` and `max` values are adjusted with the following
+logic.
+It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values,
+the behavior can be unexpected:
+If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`.
+If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`.
+If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1)`,
+`min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`.
+
+This operation has a gradient and thus allows for training `min` and `max`
+values.
+  }];
+
+  let arguments = (ins
+    F32Tensor:$inputs,
+    F32Tensor:$min,
+    F32Tensor:$max,
+
+    DefaultValuedAttr<I64Attr, "8">:$num_bits,
+    DefaultValuedAttr<BoolAttr, "false">:$narrow_range
+  );
+
+  let results = (outs
+    F32Tensor:$outputs
+  );
+
+  let verifier = [{
+    return Verify(*this);
+  }];
+}
+
 def TF_FillOp : TF_Op<"Fill", [NoSideEffect]> {
   let summary = "Creates a tensor filled with a scalar value.";
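
The adjustment ("nudging") rules quoted in the description above read more clearly as code. Below is a minimal C++ sketch of the three cases, assuming min < max and a valid num_bits; the names AdjustedRange and AdjustRange are illustrative, not TensorFlow source.

#include <cmath>
#include <cstdint>

// Adjusted clamping range in which 0.0 lands exactly on a quantized step.
struct AdjustedRange {
  float min_adj;
  float max_adj;
};

// Implements the three cases from the op description. Assumes min < max
// and 2 <= num_bits <= 16. Illustrative sketch, not TensorFlow source.
AdjustedRange AdjustRange(float min, float max, int64_t num_bits) {
  if (min > 0.0f) {  // 0 < min < max: slide the range down to start at 0.
    return {0.0f, max - min};
  }
  if (max < 0.0f) {  // min < max < 0: slide the range up to end at 0.
    return {min - max, 0.0f};
  }
  // min <= 0 <= max: nudge min onto a quantization step so that 0 maps to
  // an exact integer level, then shift max by the same amount.
  const float scale = (max - min) / static_cast<float>((1 << num_bits) - 1);
  const float min_adj = scale * std::round(min / scale);
  return {min_adj, max + min_adj - min};
}

For example, with min = -0.1, max = 0.9 and num_bits = 8: scale = 1.0 / 255, min_adj = (1.0 / 255) * round(-25.5) ≈ -0.102, and max_adj ≈ 0.898, so the adjusted interval keeps its width of 1.0 and places 0 exactly on step 26.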

tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc

Lines changed: 33 additions & 0 deletions
@@ -34,6 +34,7 @@ limitations under the License.
 #include "mlir/IR/Matchers.h"  // TF:local_config_mlir
 #include "mlir/IR/OpImplementation.h"  // TF:local_config_mlir
 #include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
+#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
 #include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
 #include "mlir/IR/Types.h"  // TF:local_config_mlir
 #include "mlir/IR/Value.h"  // TF:local_config_mlir
@@ -308,6 +309,38 @@ static LogicalResult Verify(FakeQuantWithMinMaxVarsOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// FakeQuantWithMinMaxVarsPerChannelOp
+//===----------------------------------------------------------------------===//
+static LogicalResult Verify(FakeQuantWithMinMaxVarsPerChannelOp op) {
+  if (!isOfRankedFloatTensorType(op.min(), 1))
+    return op.emitOpError("requires min to be a 1d float tensor");
+
+  if (!isOfRankedFloatTensorType(op.max(), 1))
+    return op.emitOpError("requires max to be a 1d float tensor");
+
+  Value *inputs = op.inputs();
+  if (!HasRankAtLeast(inputs, 1) ||
+      inputs->getType().isa<UnrankedTensorType>()) {
+    return op.emitError("requires inputs to be at least 1d float tensor");
+  }
+
+  auto inputsType = inputs->getType().cast<ShapedType>();
+  int depth = inputsType.getDimSize(inputsType.getRank() - 1);
+  if (op.min()->getType().cast<ShapedType>().getDimSize(0) != depth ||
+      op.max()->getType().cast<ShapedType>().getDimSize(0) != depth) {
+    return op.emitOpError(
+        "requires min and max to have same size as last dimension of inputs");
+  }
+
+  int64_t num_bits = op.num_bits().getSExtValue();
+  if (num_bits < 2 || num_bits > 16) {
+    return op.emitOpError(
+        "requires num_bits to be between 2 and 16, inclusive");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // FusedBatchNormOp
 //===----------------------------------------------------------------------===//
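
The verifier above pins down the op's shape contract: `min` and `max` must be 1-D with exactly one entry per channel of the last dimension of `inputs`. A rough C++ sketch of the per-channel semantics this implies, reusing the hypothetical AdjustRange helper from the earlier sketch and assuming flat row-major storage (again, not the actual TensorFlow kernel):

#include <cmath>
#include <cstddef>
#include <vector>

// Fake-quantizes `values` in place. `depth` is the size of the last
// dimension, so `min` and `max` must hold exactly `depth` entries -- the
// same invariant Verify() enforces. With row-major storage, the channel
// index cycles with period `depth`.
void FakeQuantPerChannel(std::vector<float>& values, int depth,
                         const std::vector<float>& min,
                         const std::vector<float>& max, int num_bits) {
  for (std::size_t i = 0; i < values.size(); ++i) {
    const int d = static_cast<int>(i % depth);
    const AdjustedRange r = AdjustRange(min[d], max[d], num_bits);
    const float scale =
        (r.max_adj - r.min_adj) / static_cast<float>((1 << num_bits) - 1);
    // Clamp into the adjusted range, snap to the nearest quantized step,
    // then de-quantize back to float.
    const float clamped =
        std::fmin(std::fmax(values[i], r.min_adj), r.max_adj);
    values[i] = std::round((clamped - r.min_adj) / scale) * scale + r.min_adj;
  }
}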

tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir

Lines changed: 31 additions & 0 deletions
@@ -459,6 +459,37 @@ func @testInvalidFakeQuantWithMinMaxVarsWrongMaxType(tensor<8x8x8x8xf32>, tensor
 
 // -----
 
+// Test valid tf.FakeQuantWithMinMaxVarsPerChannel
+// CHECK-LABEL: func @FakeQuantWithMinMaxVarsPerChannel
+func @FakeQuantWithMinMaxVarsPerChannel(tensor<1x2x3x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32> {
+^bb0(%arg0: tensor<1x2x3x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>):
+  // CHECK: "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) : (tensor<1x2x3x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32>
+  %0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) : (tensor<1x2x3x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32>
+  return %0 : tensor<1x2x3x8xf32>
+}
+
+// -----
+
+// Test invalid tf.FakeQuantWithMinMaxVarsPerChannel
+func @FakeQuantWithMinMaxVarsPerChannel_ranked_inputs(tensor<f32>, tensor<8xf32>, tensor<8xf32>) -> tensor<f32> {
+^bb0(%arg0: tensor<f32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>):
+  // expected-error @+1 {{requires inputs to be at least 1d float tensor}}
+  %0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<8xf32>, tensor<8xf32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// Test invalid tf.FakeQuantWithMinMaxVarsPerChannel
+func @FakeQuantWithMinMaxVarsPerChannel_mismatch_min_max(tensor<1x2x3x8xf32>, tensor<1xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32> {
+^bb0(%arg0: tensor<1x2x3x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<8xf32>):
+  // expected-error @+1 {{requires min and max to have same size as last dimension of inputs}}
+  %0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) : (tensor<1x2x3x8xf32>, tensor<1xf32>, tensor<8xf32>) -> tensor<1x2x3x8xf32>
+  return %0 : tensor<1x2x3x8xf32>
+}
+
+// -----
+
 // Test valid tf.FusedBatchNorm
 // CHECK-LABEL: func @testFusedBatchNorm
 func @testFusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> {
