Commit 2640415

fufeisimehtanirav authored and committed
Create a quantized non-in-place version of the CUDA ReLU function (#85669)

Summary: This PR and #85670 allow the relu function to run on a quantized tensor on CUDA, i.e. torch.relu(qa) for a quantized tensor qa on CUDA.

Test Plan: python test/test_quantization.py

Previous PR that has been reverted: #85502.

Pull Request resolved: #85669
Approved by: https://github.com/dzdang
1 parent 7f52c62 commit 2640415
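
As a usage illustration (not part of the commit; the tensor shape and quantization parameters below are arbitrary), this is the call pattern the new QuantizedCUDA kernel enables, assuming a CUDA-enabled build with this patch:

import torch

# Quantize a random tensor directly on the GPU.
x = torch.randn(4, 4, device="cuda")
qa = torch.quantize_per_tensor(x, scale=0.1, zero_point=5, dtype=torch.quint8)

# Before this change these calls only dispatched for QuantizedCPU;
# with the QuantizedCUDA entry they now work on a quantized CUDA tensor.
out = torch.relu(qa)
out_functional = torch.nn.functional.relu(qa)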

File tree

3 files changed (+36, -4 lines):
aten/src/ATen/native/native_functions.yaml
aten/src/ATen/native/quantized/cuda/Activation.cpp
test/quantization/core/test_quantized_op.py

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
@@ -4319,6 +4319,7 @@
     MPS: relu_mps
     MkldnnCPU: mkldnn_relu
     QuantizedCPU: relu_quantized_cpu
+    QuantizedCUDA: relu_quantized_cuda
     NestedTensorCPU, NestedTensorCUDA: NestedTensor_relu

 - func: relu_(Tensor(a!) self) -> Tensor(a!)

aten/src/ATen/native/quantized/cuda/Activation.cpp

Lines changed: 9 additions & 0 deletions
@@ -1,5 +1,6 @@
 #include <c10/util/Exception.h>
 #include <ATen/ATen.h>
+#include <ATen/Functions.h>

 namespace at {
 namespace native {
@@ -17,5 +18,13 @@ Tensor gelu_quantized_cuda(const Tensor& qx, c10::string_view approximate) {
   return at::quantize_per_tensor(result_fp32, qx.q_scale(), qx.q_zero_point(), qx.scalar_type());
 }

+Tensor relu_quantized_cuda(const Tensor& self) {
+  auto zero_point = self.q_zero_point();
+  auto int_repr = self.int_repr();
+  auto mask = (int_repr > zero_point);
+  const auto relu_int_repr = at::where(mask, int_repr, zero_point);
+  return at::_make_per_tensor_quantized_tensor(relu_int_repr, self.q_scale(), zero_point);
+}
+
 } // namespace at::native
 } // namespace at
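
The kernel works on the integer representation: any value at or below the zero point (which dequantizes to 0.0 or less) is clamped to the zero point, and the result is repackaged with the original scale and zero point. A rough Python sketch of the same logic, for illustration only (the helper name is made up, not part of the commit):

import torch

def relu_quantized_sketch(qx):
    # Clamp the integer representation at the zero point, then rebuild
    # a per-tensor quantized tensor with the original scale/zero point.
    zero_point = qx.q_zero_point()
    int_repr = qx.int_repr()
    zp = torch.tensor(zero_point, dtype=int_repr.dtype, device=int_repr.device)
    relu_int_repr = torch.where(int_repr > zero_point, int_repr, zp)
    return torch._make_per_tensor_quantized_tensor(relu_int_repr, qx.q_scale(), zero_point)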

test/quantization/core/test_quantized_op.py

Lines changed: 26 additions & 4 deletions
@@ -168,6 +168,8 @@ def _test_activation_function(self, X, fn_name, test_configs):
         X, (scale, zero_point, torch_type) = X
         if not isinstance(X, torch.Tensor):
             X = torch.from_numpy(X)
+        if (X.device.type == 'cuda') and (torch.backends.quantized.engine == 'qnnpack'):
+            return
         # Quantizes the reference to account for max error.
         # q_min and q_max only depend on the initial torch_type.
         q_min, q_max = torch.iinfo(torch_type).min, torch.iinfo(torch_type).max
@@ -229,9 +231,7 @@ def _test_activation_function(self, X, fn_name, test_configs):

     """Tests the correctness of the quantized::relu op."""
     @override_qengines
-    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
-                       qparams=hu.qparams()))
-    def test_qrelu(self, X):
+    def test_qrelu(self):
         relu_test_configs = [
             {
                 'quantized_fn': [
@@ -253,7 +253,29 @@ def test_qrelu(self, X):
                 }
             }
         ]
-        self._test_activation_function(X, 'relu', relu_test_configs)
+        devices = ["cpu", "cuda"] if TEST_CUDA else ["cpu"]
+        for device in devices:
+            # Only test the non-in-place version of quantized CUDA relu;
+            # this will be removed once the in-place version is created.
+            if device == 'cuda':
+                relu_test_configs = [
+                    {
+                        'quantized_fn': [
+                            torch.relu,
+                            torch.nn.functional.relu,
+                        ],
+                        'reference_fn': torch.nn.functional.relu
+                    },
+                ]
+            shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
+            dtypes = (torch.quint8, torch.qint8)
+            scales = (0.05, 0.1)
+            zero_points = (0, 5)
+            test_cases = itertools.product(shapes, dtypes, scales, zero_points)
+            for shape, dtype, scale, zero_point in test_cases:
+                X = torch.randn(*shape, device=device)
+                X = (X, (scale, zero_point, dtype))
+                self._test_activation_function(X, 'relu', relu_test_configs)

     """Tests the correctness of the quantized::relu6 op."""
     def test_qrelu6(self):
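
For reference, one point of the test grid above can be reproduced by hand roughly as follows (a sketch, not code from the commit; it mirrors what _test_activation_function checks by comparing against a requantized reference):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, 4, device=device)
qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=0, dtype=torch.quint8)

# Reference: dequantize, apply fp32 relu, requantize with the same params.
expected = torch.quantize_per_tensor(
    torch.nn.functional.relu(qx.dequantize()),
    scale=0.05, zero_point=0, dtype=torch.quint8)
actual = torch.relu(qx)
assert torch.equal(actual.int_repr(), expected.int_repr())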
