
Commit f1efe51

vkuzo authored and facebook-github-bot committed
add quantized version of hardswish operator (pytorch#34820)
Summary: Pull Request resolved: pytorch#34820

Adds a quantized version of hardswish, for common quantized operator coverage.

Note:
* we carry over scale and zero_point from the input to the output, because the range of the output is unbounded if x > 0
* we also skip the .out function so the user cannot specify a custom scale + zero_point (flexible on this)

Test Plan:
```
python test/test_quantized.py
```
https://gist.github.com/vkuzo/f9b579315ed7f5fdb24839e3218d8465

Imported from OSS

Differential Revision: D20472905

fbshipit-source-id: 0f2a83e9f5f7b43485fa46caf30e756dc5d492a9
1 parent f3e9fa6 commit f1efe51
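
As the summary notes, the output carries over the input's scale and zero_point. A minimal usage sketch (assumes a PyTorch build that includes this change; not part of the diff):

```
import torch

# Quantize an input, apply quantized hardswish, and observe that the output
# reuses the input's quantization parameters.
x = torch.randn(4)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)

qy = torch.nn.quantized.functional.hardswish(qx)

assert qy.q_scale() == qx.q_scale()            # scale carried over
assert qy.q_zero_point() == qx.q_zero_point()  # zero_point carried over
```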

File tree

7 files changed: +130 -0 lines changed


aten/src/ATen/native/native_functions.yaml (+6)

@@ -5682,9 +5682,15 @@
 - func: hardswish(Tensor self) -> Tensor
   use_c10_dispatcher: full
   python_module: nn
+  dispatch:
+    CPU: hardswish
+    QuantizedCPU: quantized_hardswish

 - func: hardswish_(Tensor(a!) self) -> Tensor(a!)
   python_module: nn
+  dispatch:
+    CPU: hardswish_
+    QuantizedCPU: quantized_hardswish_

 - func: hardswish_backward(Tensor grad_output, Tensor self) -> Tensor
   use_c10_dispatcher: full
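
The `dispatch` entries added above route the same ATen op to a backend-specific kernel. A sketch of the effect (assumes this build; `torch._C._nn.hardswish` is the binding the functional wrapper below calls):

```
import torch

x = torch.randn(4)
qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=0, dtype=torch.quint8)

torch._C._nn.hardswish(x)   # CPU backend -> hardswish
torch._C._nn.hardswish(qx)  # QuantizedCPU backend -> quantized_hardswish
```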

aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp (+41)

@@ -404,6 +404,46 @@ void qclamp_kernel(
   });
 }

+void qhardswish_kernel(const Tensor& qx, Tensor& qy) {
+  const auto i_scale = qx.q_scale();
+  const auto i_zero_point = qx.q_zero_point();
+
+  const auto o_scale = qy.q_scale();
+  const auto o_zero_point = qy.q_zero_point();
+  const float o_inv_scale = 1.0 / o_scale;
+
+  using fVec = Vec256<float>;
+  fVec i_scale_vec(i_scale);
+  fVec i_zero_point_vec(i_zero_point);
+  fVec i_scale_neg_zp_premul_vec = i_scale_vec * i_zero_point_vec.neg();
+  fVec zero_vec(0.0f);
+  fVec three_vec(3.0f);
+  fVec six_vec(6.0f);
+
+  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qhardswish", [&]() {
+    using qVec = Vec256<scalar_t>;
+    auto iter = TensorIterator::unary_op(qy, qx);
+    cpu_kernel_vec(
+        iter,
+        [&](scalar_t value) -> scalar_t {
+          const auto x = at::dequantize_val(i_scale, i_zero_point, value);
+          const auto y = x * std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f;
+          return at::quantize_val<scalar_t>(o_scale, o_zero_point, y);
+        },
+        [&](qVec value) -> qVec {
+          auto value_dx = value.dequantize(i_scale_vec, i_zero_point_vec,
+                                           i_scale_neg_zp_premul_vec);
+          for (int idx = 0; idx < value_dx.size(); idx++) {
+            value_dx[idx] = value_dx[idx] * vec256::minimum(
+                vec256::maximum(value_dx[idx] + three_vec, zero_vec),
+                six_vec
+            ) / six_vec;
+          }
+          return qVec::quantize(value_dx, o_scale, o_zero_point, o_inv_scale);
+        });
+  });
+}
+
 void qtanh_kernel(const Tensor& qx, Tensor& qy) {
   int64_t zero_point = qx.q_zero_point();

@@ -1506,6 +1546,7 @@ REGISTER_DISPATCH(qsigmoid_stub, &qsigmoid_kernel);
 REGISTER_DISPATCH(qhardsigmoid_stub, &qhardsigmoid_kernel);
 REGISTER_DISPATCH(qclamp_stub, &qclamp_kernel);
 REGISTER_DISPATCH(qtanh_stub, &qtanh_kernel);
+REGISTER_DISPATCH(qhardswish_stub, &qhardswish_kernel);
 REGISTER_DISPATCH(qelu_stub, &qelu_kernel);
 REGISTER_DISPATCH(qadd_relu_stub, &qadd_kernel<true>);
 REGISTER_DISPATCH(qadd_stub, &qadd_kernel<false>);
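
Both lambdas above implement the same dequantize, hardswish, requantize scheme; the first handles one element, the second a `Vec256` lane at a time. An illustrative Python model of the scalar path (`dequantize_val`/`quantize_val` here are simplified stand-ins for the ATen helpers, shown for a quint8 range):

```
def dequantize_val(scale, zero_point, q):
    # affine dequantization: x = scale * (q - zero_point)
    return scale * (q - zero_point)

def quantize_val(scale, zero_point, x, qmin=0, qmax=255):
    # affine quantization with clamping to the quint8 range
    q = int(round(x / scale)) + zero_point
    return min(max(q, qmin), qmax)

def qhardswish_scalar(q, i_scale, i_zp, o_scale, o_zp):
    x = dequantize_val(i_scale, i_zp, q)
    y = x * min(max(x + 3.0, 0.0), 6.0) / 6.0  # hardswish(x)
    return quantize_val(o_scale, o_zp, y)

# quint8 value 180 at scale=0.1, zero_point=128 dequantizes to x = 5.2;
# hardswish(5.2) = 5.2, which requantizes back to 180.
assert qhardswish_scalar(180, 0.1, 128, 0.1, 128) == 180
```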
aten/src/ATen/native/quantized/cpu/qhardswish.cpp (new file, +26)

@@ -0,0 +1,26 @@
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/core/op_registration/op_registration.h>
+#include <ATen/quantized/Quantizer.h>
+#include <ATen/native/quantized/cpu/quantized_ops.h>
+
+#include <algorithm>
+
+namespace at {
+namespace native {
+
+DEFINE_DISPATCH(qhardswish_stub);
+
+Tensor quantized_hardswish(const Tensor& qx) {
+  Tensor qy = at::_empty_affine_quantized(qx.sizes(), qx.options(),
+                                          qx.q_scale(), qx.q_zero_point());
+  qhardswish_stub(qx.device().type(), qx, qy);
+  return qy;
+}
+
+Tensor& quantized_hardswish_(Tensor& qx) {
+  qhardswish_stub(qx.device().type(), qx, qx);
+  return qx;
+}
+
+}} // namespace at::native
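
The out-of-place wrapper allocates a fresh affine-quantized output that inherits the input's qparams before invoking the stub; the in-place variant simply passes `qx` as both input and output. A rough Python analogue of the allocation step (`torch._empty_affine_quantized` is an internal factory, used here only for illustration):

```
import torch

qx = torch.quantize_per_tensor(torch.randn(3), 0.1, 128, torch.quint8)
# mirror quantized_hardswish: empty output with the input's scale/zero_point
qy = torch._empty_affine_quantized(qx.size(),
                                   scale=qx.q_scale(),
                                   zero_point=qx.q_zero_point(),
                                   dtype=qx.dtype)
```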

aten/src/ATen/native/quantized/cpu/quantized_ops.h (+2)

@@ -24,6 +24,7 @@ using qbinary_fn =
     void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Tensor& /*other*/);
 using qadd_scalar_fn =
     void (*)(Tensor& /*out*/, const Tensor& /*self*/, Scalar other /*other*/);
+using qhardswish_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
 using qmaxpool_2d_fn = void (*)(
     const Tensor& qx,
     int64_t iC, // input/output channels

@@ -131,6 +132,7 @@ DECLARE_DISPATCH(qbinary_fn, qmul_stub);
 DECLARE_DISPATCH(qbinary_fn, qmul_relu_stub);
 DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_stub);
 DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_relu_stub);
+DECLARE_DISPATCH(qhardswish_fn, qhardswish_stub);
 DECLARE_DISPATCH(qelu_fn, qelu_stub);
 DECLARE_DISPATCH(qmaxpool_2d_fn, qmaxpool_2d_nhwc_stub);
 DECLARE_DISPATCH(qadaptive_avg_pool2d_fn, qadaptive_avg_pool2d_nhwc_stub);

benchmarks/operator_benchmark/pt/qactivation_test.py (+1)

@@ -45,6 +45,7 @@
     ('relu', nnq.ReLU),
     ('relu6', nnq.ReLU6),
     ('functional.hardtanh', nnq.functional.hardtanh),
+    ('functional.hardswish', nnq.functional.hardswish),
     ('functional.elu', nnq.functional.elu),
     ('functional.hardsigmoid', nnq.functional.hardsigmoid),
 ),

test/test_quantized.py (+30)

@@ -364,6 +364,36 @@ def test_hardtanh(self, X, min_val, max_val):
         op_(qY_hat, min_val, max_val, inplace=True)
         self.assertEqual(qY, qY_hat, message="{} hardtanh failed".format(name))

+    """Tests the correctness of the quantized::hardswish op."""
+    @given(X=hu.tensor(shapes=hu.array_shapes(1, 8, 1, 8),
+                       elements=hu.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False),
+                       qparams=hu.qparams()))
+    def test_hardswish(self, X):
+        X, (scale, zero_point, torch_type) = X
+        X = torch.from_numpy(X)
+        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
+                                       dtype=torch_type)
+        dqX = qX.dequantize()
+
+        output_scale = scale
+        output_zero_point = zero_point
+
+        dqY_hat = F.hardswish(dqX)
+        qY_hat = torch.quantize_per_tensor(dqY_hat, scale=output_scale,
+                                           zero_point=output_zero_point,
+                                           dtype=torch_type)
+
+        # regular
+        qY = torch.nn.quantized.functional.hardswish(qX)
+        self.assertEqual(qY, qY_hat,
+                         message="Hardswish failed: {} vs {}".format(qY, qY_hat))
+
+        # inplace
+        qX_copy = qX.clone().detach()
+        torch.nn.quantized.functional.hardswish(qX_copy, inplace=True)
+        self.assertEqual(qX_copy, qY_hat,
+                         message="inplace Hardswish failed: {} vs {}".format(qY, qY_hat))
+
     """Tests the correctness of the scalar addition."""
     @given(A=hu.tensor(shapes=hu.array_shapes(1, 4, 1, 5),
                        elements=hu.floats(-1e6, 1e6, allow_nan=False),

torch/nn/quantized/functional.py (+24)

@@ -368,6 +368,30 @@ def hardtanh(input, min_val=-1., max_val=1., inplace=False):
         return torch._C._nn.hardtanh_(input, min_val, max_val)
     return torch._C._nn.hardtanh(input, min_val, max_val)

+def hardswish(input, inplace=False):
+    r"""Applies the quantized version of the hardswish function, element-wise,
+    as described in the paper:
+
+    `Searching for MobileNetV3`_.
+
+    .. math::
+        \text{Hardswish}(x) = x * \frac{ReLU6(x + 3)}{6}
+
+    Args:
+        input: quantized input
+        inplace: Inplace modification of the input tensor
+
+    See :class:`~torch.nn.Hardswish` for more details.
+
+    .. _`Searching for MobileNetV3`:
+        https://arxiv.org/abs/1905.02244
+    """
+    if not input.is_quantized:
+        raise ValueError("Input to 'quantized.hardswish' must be quantized!")
+    if inplace:
+        return torch._C._nn.hardswish_(input)
+    return torch._C._nn.hardswish(input)
+
 def elu(input, alpha=1., inplace=False, scale=None, zero_point=None):
     r"""
     Applies the quantized ELU function element-wise:
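
A short usage sketch for the new wrapper (assumes this build):

```
import torch

qx = torch.quantize_per_tensor(torch.randn(4), 0.1, 128, torch.quint8)

qy = torch.nn.quantized.functional.hardswish(qx)           # out-of-place
torch.nn.quantized.functional.hardswish(qx, inplace=True)  # modifies qx

try:
    torch.nn.quantized.functional.hardswish(torch.randn(4))
except ValueError:
    pass  # float inputs are rejected, per the is_quantized check above
```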
