// AffineQuantizer.h -- taken from a fork of pytorch/pytorch
// (original listing: 130 lines, 111 loc, 3.57 KB).
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/quantized/AffineQuantizerBase.h>
namespace at {
namespace native {

// ===========================================================================
// Affine (de)quantization entry points and dispatch stubs.
//
// Everything in this namespace block is a declaration; none of the
// definitions or kernel registrations live in this header.
//
// Naming conventions used throughout (inferred from the identifiers --
// NOTE(review): confirm against the out-of-line definitions):
//   * `rtensor` is the real-valued tensor, `qtensor` the quantized one.
//   * The quantize_* functions read `rtensor` and write `qtensor`; the
//     dequantize_* functions read `qtensor` and write `rtensor`. This matches
//     which argument is `const Tensor&` and which is `Tensor&` below.
//   * Entry points return `Tensor&`, presumably a reference to the output
//     tensor argument.
//
// Each `*_fn` alias is a plain function-pointer kernel signature returning
// void. DECLARE_DISPATCH (from ATen/native/DispatchStub.h) declares the stub
// object of the matching name for that signature.
// ===========================================================================

// ---------------------------------------------------------------------------
// Per-tensor affine: a single `double` scale and `int64_t` zero point.
// ---------------------------------------------------------------------------

Tensor& quantize_tensor_per_tensor_affine(
    const Tensor& rtensor,
    Tensor& qtensor,
    double scale,
    int64_t zero_point);

Tensor& dequantize_tensor_per_tensor_affine(
    const Tensor& qtensor,
    Tensor& rtensor,
    double scale,
    int64_t zero_point);

using quantize_tensor_per_tensor_affine_fn = void (*)(
    const Tensor& rtensor,
    Tensor& qtensor,
    double scale,
    int64_t zero_point);
DECLARE_DISPATCH(
    quantize_tensor_per_tensor_affine_fn,
    quantize_tensor_per_tensor_affine_stub);

using dequantize_tensor_per_tensor_affine_fn = void (*)(
    const Tensor& qtensor,
    Tensor& rtensor,
    double scale,
    int64_t zero_point);
DECLARE_DISPATCH(
    dequantize_tensor_per_tensor_affine_fn,
    dequantize_tensor_per_tensor_affine_stub);

// ---------------------------------------------------------------------------
// Per-channel affine: `scales` / `zero_points` tensors plus the channel
// `axis`.
//
// NOTE(review): the entry points take `scales` and `zero_points` by value,
// while the kernel signatures take them as `const Tensor&`. Any change here
// must be mirrored in the out-of-line definitions.
// ---------------------------------------------------------------------------

Tensor& quantize_tensor_per_channel_affine(
    const Tensor& rtensor,
    Tensor& qtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

Tensor& dequantize_tensor_per_channel_affine(
    const Tensor& qtensor,
    Tensor& rtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

using quantize_tensor_per_channel_affine_fn = void (*)(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);
DECLARE_DISPATCH(
    quantize_tensor_per_channel_affine_fn,
    quantize_tensor_per_channel_affine_stub);

using dequantize_tensor_per_channel_affine_fn = void (*)(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);
DECLARE_DISPATCH(
    dequantize_tensor_per_channel_affine_fn,
    dequantize_tensor_per_channel_affine_stub);

// ---------------------------------------------------------------------------
// Per-channel "float qparams".
//
// Parameter lists are identical to the per-channel affine variant above.
// The name suggests the quantization parameters (zero points in particular)
// are floating-point -- NOTE(review): confirm in the kernel implementations.
// ---------------------------------------------------------------------------

Tensor& quantize_tensor_per_channel_float_qparams(
    const Tensor& rtensor,
    Tensor& qtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

Tensor& dequantize_tensor_per_channel_float_qparams(
    const Tensor& qtensor,
    Tensor& rtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

using quantize_tensor_per_channel_float_qparams_fn = void (*)(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);
DECLARE_DISPATCH(
    quantize_tensor_per_channel_float_qparams_fn,
    quantize_tensor_per_channel_float_qparams_stub);

using dequantize_tensor_per_channel_float_qparams_fn = void (*)(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);
DECLARE_DISPATCH(
    dequantize_tensor_per_channel_float_qparams_fn,
    dequantize_tensor_per_channel_float_qparams_stub);

// ---------------------------------------------------------------------------
// Per-tensor "sub byte" kernels: stubs only, with no entry point declared in
// this header.
//
// Unlike the per-tensor affine kernels, both `scale` and `zero_point` are
// `float` here (not `double` / `int64_t`).
// ---------------------------------------------------------------------------

using quantize_tensor_per_tensor_affine_sub_byte_fn = void (*)(
    const Tensor& rtensor,
    Tensor& qtensor,
    float scale,
    float zero_point);
DECLARE_DISPATCH(
    quantize_tensor_per_tensor_affine_sub_byte_fn,
    quantize_tensor_per_tensor_affine_sub_byte_stub);

using dequantize_tensor_per_tensor_affine_sub_byte_fn = void (*)(
    const Tensor& qtensor,
    Tensor& rtensor,
    float scale,
    float zero_point);
DECLARE_DISPATCH(
    dequantize_tensor_per_tensor_affine_sub_byte_fn,
    dequantize_tensor_per_tensor_affine_sub_byte_stub);

// ---------------------------------------------------------------------------
// Exported templated helpers (TORCH_API).
//
// Both tensors are taken by value and a Tensor is returned by value.
// `T` is presumably the quantized element type -- NOTE(review): confirm
// against the explicit instantiations in the implementation file.
// ---------------------------------------------------------------------------

template <typename T>
TORCH_API Tensor quantize_tensor(
    Tensor rtensor,
    Tensor qtensor,
    double scale,
    int64_t zero_point);

template <typename T>
TORCH_API Tensor dequantize_tensor(
    Tensor qtensor,
    Tensor rtensor,
    double scale,
    int64_t zero_point);

} // namespace native
} // namespace at