Commit 0a7d8b4

fufeisi authored and pytorchmergebot committed
Create a quantized in-place version of the CUDA ReLU function, relu_quantized_cuda_. (pytorch#85670)
Summary: This commit and pytorch#85669 allow the relu function to run on a quantized tensor on CUDA, i.e. torch.relu(qa) for a quantized tensor qa on a CUDA device.

Test Plan: python test/test_quantization.py

Previous PR that has been reverted: pytorch#85502.

Pull Request resolved: pytorch#85670
Approved by: https://github.com/dzdang, https://github.com/z-a-f
1 parent eb650ab commit 0a7d8b4
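In user-facing terms, the change enables the following (a minimal sketch; the shape and quantization parameters here are illustrative, not taken from the test suite):

import torch

# Quantize a float tensor and move it to the GPU.
x = torch.randn(4, 4)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)
qx = qx.cuda()

# Out-of-place ReLU on a quantized CUDA tensor (enabled by pytorch#85669).
qy = torch.relu(qx)

# In-place ReLU on a quantized CUDA tensor (enabled by this commit).
qx.relu_()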

File tree

3 files changed: +22 additions, −12 deletions

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
@@ -4330,6 +4330,7 @@
     MPS: relu_mps_
     MkldnnCPU: mkldnn_relu_
     QuantizedCPU: relu_quantized_cpu_
+    QuantizedCUDA: relu_quantized_cuda_
     NestedTensorCPU, NestedTensorCUDA: NestedTensor_relu_
   autogen: relu.out

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+#include <ATen/ATen.h>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/cuda/Loops.cuh>
+
+namespace at {
+namespace native {
+
+Tensor& relu_quantized_cuda_(Tensor& self) {
+  const auto zero_point = self.q_zero_point();
+  AT_DISPATCH_QINT_TYPES(
+      self.scalar_type(), "qrelu_cuda", [&]() {
+        auto iter = TensorIterator::unary_op(self, self);
+        gpu_kernel(iter, [zero_point] GPU_LAMBDA(scalar_t value) -> scalar_t {
+          return scalar_t(std::max<underlying_t>(value.val_, zero_point));
+        });
+      });
+  return self;
+}
+
+} // namespace native
+} // namespace at
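The kernel clamps the raw integer values at zero_point rather than at 0: under the affine scheme real = scale * (q - zero_point), the integer zero_point is the representation of real 0, so max(q, zero_point) in the integer domain is exactly ReLU in the real domain. A small CPU-side Python sketch of that identity (values illustrative):

import torch

x = torch.tensor([-1.0, -0.2, 0.0, 0.3, 1.0])
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=10, dtype=torch.qint8)

# Clamp the stored int8 values at the zero point...
clamped = qx.int_repr().clamp(min=10)
# ...and dequantize manually; this matches ReLU applied to the real values.
actual = (clamped.float() - 10) * 0.1
expected = torch.relu(qx.dequantize())
assert torch.allclose(actual, expected)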

test/quantization/core/test_quantized_op.py

Lines changed: 0 additions & 12 deletions
@@ -255,18 +255,6 @@ def test_qrelu(self):
         ]
         devices = ["cpu", "cuda"] if TEST_CUDA else ["cpu"]
         for device in devices:
-            # Only test the non-in-place version relu quantized cuda,
-            # will remove this when creating in-place version relu quantized cuda.
-            if device == 'cuda':
-                relu_test_configs = [
-                    {
-                        'quantized_fn': [
-                            torch.relu,
-                            torch.nn.functional.relu,
-                        ],
-                        'reference_fn': torch.nn.functional.relu
-                    },
-                ]
             shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
             dtypes = (torch.quint8, torch.qint8)
             scales = (0.05, 0.1)
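With the CUDA special case deleted, the single relu_test_configs list defined earlier in the test is exercised on every available device. A hypothetical, self-contained sketch of that unified loop (the config structure mirrors the deleted lines; the actual config contents live above the shown hunk):

import torch

TEST_CUDA = torch.cuda.is_available()

relu_test_configs = [
    {
        'quantized_fn': [torch.relu, torch.nn.functional.relu],
        'reference_fn': torch.nn.functional.relu,
    },
]

for device in ["cpu", "cuda"] if TEST_CUDA else ["cpu"]:
    for config in relu_test_configs:
        x = torch.randn(4, 4)
        qx = torch.quantize_per_tensor(
            x, scale=0.1, zero_point=128, dtype=torch.quint8).to(device)
        for quantized_fn in config['quantized_fn']:
            # ReLU in the quantized domain should match the reference
            # ReLU applied to the dequantized (real) values.
            qy = quantized_fn(qx)
            ref = config['reference_fn'](qx.dequantize())
            assert torch.allclose(qy.dequantize(), ref)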
