
Improve fake_dequantize_op. #12877

Merged 2 commits on Aug 28, 2018
37 changes: 25 additions & 12 deletions paddle/fluid/operators/fake_dequantize_op.cc
@@ -18,15 +18,32 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

+template <typename T>
+struct DequantizeFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& dev_ctx,
+                  const framework::Tensor* in, const framework::Tensor* scale,
+                  T max_range, framework::Tensor* out) {
+    auto in_e = framework::EigenVector<T>::Flatten(*in);
+    const T* scale_factor = scale->data<T>();
+    auto out_e = framework::EigenVector<T>::Flatten(*out);
+
+    auto& dev = *dev_ctx.eigen_device();
+    out_e.device(dev) = (scale_factor[0] / max_range) * in_e;
[Inline review thread on the line above]
Contributor: Do we need to cast max_range to type T here? The Python unit tests don't seem to include a float64 case.
Contributor (author): Updated, and added a float64 unit test as well.

+  }
+};
+
+template struct DequantizeFunctor<platform::CPUDeviceContext, float>;
+template struct DequantizeFunctor<platform::CPUDeviceContext, double>;
+
 class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
  public:
-  FakeDequantizeMaxAbsOp(const std::string &type,
-                         const framework::VariableNameMap &inputs,
-                         const framework::VariableNameMap &outputs,
-                         const framework::AttributeMap &attrs)
+  FakeDequantizeMaxAbsOp(const std::string& type,
+                         const framework::VariableNameMap& inputs,
+                         const framework::VariableNameMap& outputs,
+                         const framework::AttributeMap& attrs)
       : OperatorWithKernel(type, inputs, outputs, attrs) {}

-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of FakeDequantizeMaxAbsOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
@@ -42,21 +59,17 @@ class FakeDequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X",
"(Tensor) The input with float-32/64 type is the "
"low precision tensor.");
AddInput("Scale", "(float) The scale in quantization stage.");
AddOutput("Out",
"(Tensor) The output is the dequantized high "
"precision tensor.");
AddAttr<int>("num_bits",
"(int) `num_bits` is the quantization level bits, "
"such as 2, 5, 8.");
AddAttr<float>("scale",
"(float) The maximum absolute value of low precision tensor."
"It is usually calculated by the fake_quantize_max_abs_op.");
AddAttr<float>("max_range", "(float) The max range in quantization stage.");
AddComment(R"DOC(
FakeDequantizeMaxAbsOp operator.

This calculation is an opposite operation of FakeQuantizeMaxAbsOp:

$$Out = \frac{scale*X}{2^{num_bits} - 1}$$
$$Out = \frac{scale*X}{ max_range }$$

)DOC");
}
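For reference, the functor above computes Out = X * scale / max_range, the inverse of max-abs fake quantization. A minimal NumPy sketch of the same arithmetic (illustration only, not part of this PR; the array values are made up):

import numpy as np

# Quantize: scale is the max absolute value; max_range = 2^(num_bits - 1) - 1.
x = np.array([-0.5, 0.25, 1.0], dtype=np.float32)
max_range = 127.0                     # 8-bit signed range, as in the unit test
scale = np.abs(x).max()               # 1.0 here
xq = np.round(x / scale * max_range)  # integers in [-127, 127]

# Dequantize, mirroring DequantizeFunctor: out = in * scale / max_range.
out = xq * scale / max_range
print(out)  # ~[-0.5, 0.25, 1.0], up to rounding error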
36 changes: 36 additions & 0 deletions paddle/fluid/operators/fake_dequantize_op.cu
@@ -14,6 +14,42 @@ limitations under the License. */

#include "paddle/fluid/operators/fake_dequantize_op.h"

namespace paddle {
namespace operators {

template <typename T>
__global__ void KeDequantize(const T* in, const T* scale, T max_range, int num,
T* out) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < num) {
out[idx] = in[idx] * scale[0] / max_range;
}
}

template <typename T>
struct DequantizeFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor* scale,
T max_range, framework::Tensor* out) {
const T* in_data = in->data<T>();
const T* scale_factor = scale->data<T>();
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());

int num = in->numel();
int block = 512;
int grid = (num + block - 1) / block;

KeDequantize<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, out_data);
}
};

template struct DequantizeFunctor<platform::CUDADeviceContext, float>;
template struct DequantizeFunctor<platform::CUDADeviceContext, double>;

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs,
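The launch configuration above is the standard one-thread-per-element pattern: block is fixed at 512 threads and grid is the ceiling of num / block, with the idx < num guard absorbing the overhang in the last block. The arithmetic, sketched in Python for concreteness (illustrative only):

num = 31 * 65                     # elements in the tensor used by the unit test
block = 512                       # threads per block, as in the kernel launch
grid = (num + block - 1) // block # ceiling division
print(grid)                       # 4 blocks: 4 * 512 = 2048 threads cover 2015 elements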
23 changes: 15 additions & 8 deletions paddle/fluid/operators/fake_dequantize_op.h
@@ -19,22 +19,29 @@ limitations under the License. */

 namespace paddle {
 namespace operators {
+
+template <typename DeviceContext, typename T>
+struct DequantizeFunctor {
+  void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
+                  const framework::Tensor* scale, T max_range,
+                  framework::Tensor* out);
+};

 template <typename DeviceContext, typename T>
 class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
  public:
   virtual void Compute(const framework::ExecutionContext& ctx) const {
     auto* in = ctx.Input<framework::Tensor>("X");
+    auto* scale = ctx.Input<framework::Tensor>("Scale");
     auto* out = ctx.Output<framework::Tensor>("Out");
-    out->mutable_data<T>(in->place());

-    int num_bits = ctx.Attr<int>("num_bits");
-    T scale = static_cast<T>(ctx.Attr<float>("scale"));
-    int range = std::pow(2, num_bits) - 1;
+    float max_range = ctx.Attr<float>("max_range");
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    out->mutable_data<T>(dev_ctx.GetPlace());

-    auto eigen_out = framework::EigenVector<T>::Flatten(*out);
-    auto eigen_in = framework::EigenVector<T>::Flatten(*in);
-    auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
-    eigen_out.device(dev) = (scale / range) * eigen_in;
+    DequantizeFunctor<DeviceContext, T>()(dev_ctx, in, scale,
+                                          static_cast<T>(max_range), out);
   }
 };

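The refactor's main design move: the header declares DequantizeFunctor once, and the .cc and .cu files each specialize it for their device, so the kernel's Compute no longer contains device-specific math. A rough Python analogue of that dispatch pattern (hypothetical names, illustration only):

def cpu_dequantize(values, scale, max_range):
    # Device-specific implementation, analogous to the CPU functor.
    return [v * scale / max_range for v in values]

# Registry mapping a device to its functor; a "CUDA" entry would sit alongside.
DEQUANTIZE_IMPLS = {"CPU": cpu_dequantize}

def fake_dequantize_max_abs(device, values, scale, max_range):
    # Device-agnostic "Compute": look up and invoke the registered functor.
    return DEQUANTIZE_IMPLS[device](values, scale, max_range)

print(fake_dequantize_max_abs("CPU", [-127.0, 64.0], 1.0, 127.0))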
33 changes: 21 additions & 12 deletions python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
@@ -20,41 +20,50 @@
 from op_test import OpTest


-def quantize_max_abs(x, num_bits):
-    range = math.pow(2, num_bits) - 1
+def quantize_max_abs(x, max_range):
     scale = np.max(np.abs(x).flatten())
-    y = np.round(x / scale * range)
+    y = np.round(x / scale * max_range)
     return y, scale


-def dequantize_max_abs(x, num_bits, scale):
-    range = math.pow(2, num_bits) - 1
-    y = (scale / range) * x
+def dequantize_max_abs(x, scale, max_range):
+    y = (scale / max_range) * x
     return y


 class TestFakeDequantizeMaxAbsOp(OpTest):
     def set_args(self):
         self.num_bits = 8
+        self.max_range = math.pow(2, self.num_bits - 1) - 1
+        self.data_type = "float32"

     def setUp(self):
         self.set_args()
         self.op_type = "fake_dequantize_max_abs"
-        x = np.random.randn(31, 65).astype("float32")
-        yq, scale = quantize_max_abs(x, self.num_bits)
-        ydq = dequantize_max_abs(yq, self.num_bits, scale)
+        x = np.random.randn(31, 65).astype(self.data_type)
+        yq, scale = quantize_max_abs(x, self.max_range)
+        ydq = dequantize_max_abs(yq, scale, self.max_range)

-        self.inputs = {'X': yq}
-        self.attrs = {'num_bits': self.num_bits, 'scale': float(scale)}
+        self.inputs = {'X': yq, 'Scale': np.array(scale).astype(self.data_type)}
+        self.attrs = {'max_range': self.max_range}
         self.outputs = {'Out': ydq}

     def test_check_output(self):
         self.check_output()


-class TestFakeDequantizeMaxAbsOp5Bits(OpTest):
+class TestFakeDequantizeMaxAbsOpDouble(TestFakeDequantizeMaxAbsOp):
+    def set_args(self):
+        self.num_bits = 8
+        self.max_range = math.pow(2, self.num_bits - 1) - 1
+        self.data_type = "float64"
+
+
+class TestFakeDequantizeMaxAbsOp5Bits(TestFakeDequantizeMaxAbsOp):
     def set_args(self):
         self.num_bits = 5
+        self.max_range = math.pow(2, self.num_bits - 1) - 1
+        self.data_type = "float32"


 if __name__ == "__main__":
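For intuition on what the updated test checks: with num_bits = 8, max_range = 2^(8-1) - 1 = 127, and quantize followed by dequantize should reproduce the input up to half a quantization step. A standalone round-trip sketch using the same helper logic (not part of the test suite):

import math
import numpy as np

def quantize_max_abs(x, max_range):
    scale = np.max(np.abs(x).flatten())
    y = np.round(x / scale * max_range)
    return y, scale

def dequantize_max_abs(x, scale, max_range):
    return (scale / max_range) * x

max_range = math.pow(2, 8 - 1) - 1   # 127.0 for num_bits = 8
x = np.random.randn(31, 65).astype("float64")
yq, scale = quantize_max_abs(x, max_range)
ydq = dequantize_max_abs(yq, scale, max_range)

# Round-trip error is bounded by half a step: |ydq - x| <= scale / (2 * 127).
assert np.max(np.abs(ydq - x)) <= scale / (2 * max_range) + 1e-12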