Skip to content

Commit

Permalink
[luci] Introduce MinimumMSE quantization algorithm
Browse files Browse the repository at this point in the history
This commit introduces the MinimumMSE quantization algorithm.

ONE-DCO-1.0-Signed-off-by: Vyacheslav Bazhenov <slavikmipt@gmail.com>
  • Loading branch information
Vyacheslav Bazhenov committed Jan 20, 2025
1 parent cd2f583 commit 28f6b16
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 14 deletions.
19 changes: 19 additions & 0 deletions compiler/luci/pass/include/luci/Pass/QuantizationParameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,25 @@ struct LayerInfo
QuantizationGranularity granularity;
};

// Selects how per-channel weight scales are derived during quantization.
enum class QuantizationAlgorithmType
{
  // Default path: scales computed directly from min/max statistics
  Base = 0,
  // Scales refined by a per-channel search minimizing quantization MSE
  MinimumMSE = 1
};

// Tunable parameters controlling the weight-quantization algorithm.
struct QuantizationAlgorithmParams
{
  // Which algorithm computes the per-channel scaling factors
  QuantizationAlgorithmType type = QuantizationAlgorithmType::Base;

  // Params of Golden-section search algorithm
  // Number of iterations of Golden-section search
  size_t iterations_num = 100;

  // Relative half-width of the per-channel scale search interval:
  // scaling_factor_max = scaling_factor_base * (1 + range)
  // scaling_factor_min = scaling_factor_base * (1 - range)
  float range = 0.1f; // 'f' suffix: avoid silent double-to-float narrowing
};

} // namespace luci

#endif // __LUCI_QUANTIZATION_PARAMETERS_H__
4 changes: 3 additions & 1 deletion compiler/luci/pass/include/luci/Pass/QuantizeWeightsPass.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class QuantizeWeightsPass : public logo::Pass
loco::DataType input_model_dtype = loco::DataType::Unknown;
loco::DataType output_model_dtype = loco::DataType::Unknown;
QuantizationGranularity granularity = QuantizationGranularity::ChannelWise;
QuantizationAlgorithmParams algorithm_params;
};

public:
Expand All @@ -47,13 +48,14 @@ class QuantizeWeightsPass : public logo::Pass

public:
QuantizeWeightsPass(loco::DataType input_model_dtype, loco::DataType output_model_dtype,
QuantizationGranularity granularity)
QuantizationGranularity granularity, QuantizationAlgorithmParams algorithm_params)
{
_ctx = std::make_unique<Context>();
{
_ctx->input_model_dtype = input_model_dtype;
_ctx->output_model_dtype = output_model_dtype;
_ctx->granularity = granularity;
_ctx->algorithm_params = algorithm_params;
}
}
virtual const char *name(void) const { return "luci::QuantizeWeightsPass"; }
Expand Down
178 changes: 172 additions & 6 deletions compiler/luci/pass/src/QuantizeWeightsOnly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,142 @@ void iterate_per_channel(CircleConst *node, int32_t &channel_dim_index, IterFunc
}
}

// Symmetric per-channel weight quantization whose per-channel scale is refined
// by a Golden-section search minimizing the quantization mean-squared error.
//
// Starting from the scale computed by compute_sym_scale, each channel's scale
// is searched in [base * (1 - range), base * (1 + range)]; the midpoint of the
// final interval is kept only for channels where it yields a lower MSE than
// the base scale. Finally the tensor is converted in-place to out_type.
//
// node              : FLOAT32 constant to quantize (mutated in place)
// min/max           : per-channel statistics (inputs)
// scaling_factor    : out — chosen per-channel scale
// nudged_min/max    : out — clamping range per channel
// channel_dim_index : out — channel axis found by iterate_per_channel
// params            : search iteration count and relative interval width
template <loco::DataType out_type>
void sym_wquant_per_channel_minimum_mse(CircleConst *node, std::vector<float> &min,
                                        std::vector<float> &max, std::vector<float> &scaling_factor,
                                        std::vector<float> &nudged_min,
                                        std::vector<float> &nudged_max, int32_t &channel_dim_index,
                                        const QuantizationAlgorithmParams &params)
{
  assert(node->dtype() == loco::DataType::FLOAT32);
  assert(out_type == loco::DataType::S4 || out_type == loco::DataType::S8 ||
         out_type == loco::DataType::S16);

  const auto kPhi = 1.618033988749894848204586834365638118; // Golden ratio
  const auto kSearchIterations = params.iterations_num;
  const auto kRangeCoefficient = params.range;

  const int32_t kMaxScale = max_for_sym_quant(out_type);
  const int32_t kMinScale = -kMaxScale;

  uint32_t size = node->size<loco::DataType::FLOAT32>();
  std::vector<int32_t> quantized_values(size);

  // Base scales (and nudged clamp range) from the min/max statistics
  for (size_t i = 0; i < min.size(); ++i)
  {
    compute_sym_scale(min[i], max[i], scaling_factor[i], nudged_min[i], nudged_max[i], out_type);
  }

  // Quantize one element with the channel's current scale (clamped to the
  // nudged range first, matching the base algorithm)
  auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) {
    int channel_idx = indices[channel_dim_index];
    const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
    auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
    data = data < nudged_min[channel_idx] ? nudged_min[channel_idx] : data;
    data = data > nudged_max[channel_idx] ? nudged_max[channel_idx] : data;
    quantized_values[cal_offset(dimension, indices)] =
      static_cast<int32_t>(std::round(data * scaling_factor_inv));
  };

  std::vector<double> channel_mse(min.size());

  // Accumulate the squared dequantization error into channel_mse
  auto calculate_mse = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) {
    int channel_idx = indices[channel_dim_index];
    auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
    data = data < nudged_min[channel_idx] ? nudged_min[channel_idx] : data;
    data = data > nudged_max[channel_idx] ? nudged_max[channel_idx] : data;
    double diff =
      data - quantized_values[cal_offset(dimension, indices)] * scaling_factor[channel_idx];
    channel_mse[channel_idx] += diff * diff;
  };

  // Quantize with the current scaling_factor and return the per-channel MSE.
  // channel_mse is zeroed on EVERY evaluation — the previous code did not
  // reset it before the post-loop evaluations, so the "optimal" and "base"
  // MSEs accumulated stale values and the final comparison was corrupted.
  auto evaluate_mse = [&]() {
    channel_mse.assign(channel_mse.size(), 0.0);
    iterate_per_channel(node, channel_dim_index, quantize);
    iterate_per_channel(node, channel_dim_index, calculate_mse);
    return channel_mse;
  };

  std::vector<float> scaling_factor_base = scaling_factor;
  std::vector<std::pair<float, float>> golden_start_end(min.size());

  // Per-channel search interval [base * (1 - range), base * (1 + range)]
  for (size_t i = 0; i < golden_start_end.size(); ++i)
  {
    golden_start_end[i].first = scaling_factor_base[i] * (1.0 - kRangeCoefficient);
    golden_start_end[i].second = scaling_factor_base[i] * (1.0 + kRangeCoefficient);
  }

  for (size_t i = 0; i < kSearchIterations; ++i)
  {
    // Probe x1 = b - (b - a) / phi for every channel
    for (size_t j = 0; j < scaling_factor.size(); ++j)
    {
      scaling_factor[j] = golden_start_end[j].second -
                          (golden_start_end[j].second - golden_start_end[j].first) / kPhi;
    }
    auto channel_mse_x1 = evaluate_mse();

    // Probe x2 = a + (b - a) / phi
    for (size_t j = 0; j < scaling_factor.size(); ++j)
    {
      scaling_factor[j] =
        golden_start_end[j].first + (golden_start_end[j].second - golden_start_end[j].first) / kPhi;
    }
    auto channel_mse_x2 = evaluate_mse();

    // Shrink each channel's interval towards the probe with the smaller MSE
    for (size_t k = 0; k < channel_mse_x1.size(); ++k)
    {
      if (channel_mse_x1[k] > channel_mse_x2[k])
      {
        golden_start_end[k].first = golden_start_end[k].second -
                                    (golden_start_end[k].second - golden_start_end[k].first) / kPhi;
      }
      else
      {
        golden_start_end[k].second =
          golden_start_end[k].first +
          (golden_start_end[k].second - golden_start_end[k].first) / kPhi;
      }
    }
  }

  // Candidate scale: midpoint of the final search interval
  for (size_t i = 0; i < golden_start_end.size(); ++i)
  {
    scaling_factor[i] = (golden_start_end[i].first + golden_start_end[i].second) / 2;
  }
  auto channel_mse_opt = evaluate_mse();

  // Reference MSE with the unmodified base scales
  scaling_factor = scaling_factor_base;
  auto channel_mse_base = evaluate_mse();

  // Keep the searched scale only where it actually beats the base scale
  for (size_t i = 0; i < channel_mse_base.size(); ++i)
  {
    if (channel_mse_opt[i] < channel_mse_base[i])
      scaling_factor[i] = (golden_start_end[i].first + golden_start_end[i].second) / 2;
  }
  // Final quantization with the per-channel winner (base or searched scale)
  iterate_per_channel(node, channel_dim_index, quantize);

  node->dtype(out_type);      // change the type of tensor
  node->size<out_type>(size); // resize tensor
  for (uint32_t i = 0; i < size; ++i)
  {
    node->at<out_type>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
  }
}

// TODO Reduce duplicate code with QuantizeDequantizeWeights
template <loco::DataType out_type>
void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max,
Expand Down Expand Up @@ -166,18 +302,48 @@ void QuantizeWeightsOnly::quantize_weights(luci::CircleConst *weights)

if (output_type == loco::DataType::S4)
{
sym_wquant_per_channel<loco::DataType::S4>(weights, min, max, scaling_factor, nudged_min,
nudged_max, channel_dim_index);
switch (algorithm_params.type)
{
case luci::QuantizationAlgorithmType::MinimumMSE:
sym_wquant_per_channel_minimum_mse<loco::DataType::S4>(
weights, min, max, scaling_factor, nudged_min, nudged_max, channel_dim_index,
algorithm_params);
break;
default:
sym_wquant_per_channel<loco::DataType::S4>(weights, min, max, scaling_factor,
nudged_min, nudged_max, channel_dim_index);
break;
}
}
else if (output_type == loco::DataType::S8)
{
sym_wquant_per_channel<loco::DataType::S8>(weights, min, max, scaling_factor, nudged_min,
nudged_max, channel_dim_index);
switch (algorithm_params.type)
{
case luci::QuantizationAlgorithmType::MinimumMSE:
sym_wquant_per_channel_minimum_mse<loco::DataType::S8>(
weights, min, max, scaling_factor, nudged_min, nudged_max, channel_dim_index,
algorithm_params);
break;
default:
sym_wquant_per_channel<loco::DataType::S8>(weights, min, max, scaling_factor,
nudged_min, nudged_max, channel_dim_index);
break;
}
}
else if (output_type == loco::DataType::S16)
{
sym_wquant_per_channel<loco::DataType::S16>(weights, min, max, scaling_factor, nudged_min,
nudged_max, channel_dim_index);
switch (algorithm_params.type)
{
case luci::QuantizationAlgorithmType::MinimumMSE:
sym_wquant_per_channel_minimum_mse<loco::DataType::S16>(
weights, min, max, scaling_factor, nudged_min, nudged_max, channel_dim_index,
algorithm_params);
break;
default:
sym_wquant_per_channel<loco::DataType::S16>(weights, min, max, scaling_factor,
nudged_min, nudged_max, channel_dim_index);
break;
}
}
else
{
Expand Down
6 changes: 4 additions & 2 deletions compiler/luci/pass/src/QuantizeWeightsOnly.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,16 @@ namespace luci
*/
struct QuantizeWeightsOnly final : public luci::CircleNodeMutableVisitor<void>
{
QuantizeWeightsOnly(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
: input_type(input), output_type(output), granularity(gr)
QuantizeWeightsOnly(loco::DataType input, loco::DataType output, QuantizationGranularity gr,
QuantizationAlgorithmParams alg_par)
: input_type(input), output_type(output), granularity(gr), algorithm_params(alg_par)
{
}

loco::DataType input_type;
loco::DataType output_type;
QuantizationGranularity granularity;
QuantizationAlgorithmParams algorithm_params;

private:
void quantize_weights(luci::CircleConst *weights);
Expand Down
3 changes: 2 additions & 1 deletion compiler/luci/pass/src/QuantizeWeightsPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ bool QuantizeWeightsPass::run(loco::Graph *g)
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
QuantizeWeightsOnly qw(_ctx->input_model_dtype, _ctx->output_model_dtype, _ctx->granularity);
QuantizeWeightsOnly qw(_ctx->input_model_dtype, _ctx->output_model_dtype, _ctx->granularity,
_ctx->algorithm_params);
circle_node->accept(&qw);
}

Expand Down
35 changes: 31 additions & 4 deletions compiler/luci/pass/src/QuantizeWeightsPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,46 +78,73 @@ struct QuantizeWeightsPassTest : public ::testing::Test
output->dtype(loco::DataType::FLOAT32);
output->name("output");
}
virtual void SetUp() { MakeGraph(); }
virtual void SetUp() override { MakeGraph(); }
loco::Graph _g;
};

} // namespace

// Checks the pass reports a non-null name.
// NOTE the pasted diff carried both the pre- and post-change constructor call
// lines; this resolves to the post-change (params-taking) call.
TEST_F(QuantizeWeightsPassTest, name)
{
  luci::QuantizationAlgorithmParams params;
  params.type = luci::QuantizationAlgorithmType::Base;
  luci::QuantizeWeightsPass pass(loco::DataType::FLOAT32, loco::DataType::S8,
                                 luci::QuantizationGranularity::ChannelWise, params);
  auto const name = pass.name();
  ASSERT_NE(nullptr, name);
}

// Checks the Context-based constructor also yields a pass with a valid name.
TEST_F(QuantizeWeightsPassTest, name_ctx)
{
  auto ctx = std::make_unique<luci::QuantizeWeightsPass::Context>();
  ctx->input_model_dtype = loco::DataType::FLOAT32;
  ctx->output_model_dtype = loco::DataType::S8;
  ctx->granularity = luci::QuantizationGranularity::ChannelWise;
  ctx->algorithm_params.type = luci::QuantizationAlgorithmType::Base;

  luci::QuantizeWeightsPass pass(std::move(ctx));
  ASSERT_NE(nullptr, pass.name());
}

// Runs the pass on the fixture graph with the MinimumMSE algorithm selected.
TEST_F(QuantizeWeightsPassTest, run_minimum_mse_s8)
{
  luci::QuantizationAlgorithmParams algorithm_params;
  algorithm_params.type = luci::QuantizationAlgorithmType::MinimumMSE;

  luci::QuantizeWeightsPass pass(loco::DataType::FLOAT32, loco::DataType::S8,
                                 luci::QuantizationGranularity::ChannelWise, algorithm_params);
  pass.run(&_g);
}

// Negative case: U8 input dtype must be rejected even with MinimumMSE selected.
TEST_F(QuantizeWeightsPassTest, run_input_U8_mse_NEG)
{
  luci::QuantizationAlgorithmParams algorithm_params;
  algorithm_params.type = luci::QuantizationAlgorithmType::MinimumMSE;

  luci::QuantizeWeightsPass pass(loco::DataType::U8, loco::DataType::S8,
                                 luci::QuantizationGranularity::ChannelWise, algorithm_params);
  EXPECT_THROW(pass.run(&_g), std::runtime_error);
}

// Negative case: U8 input dtype must be rejected by the base algorithm.
// Fixes two defects in the pasted span: the diff carried both the pre- and
// post-change constructor call lines (invalid C++), and a local `loco::Graph g`
// was declared but never used (the test runs on the fixture member `_g`).
TEST_F(QuantizeWeightsPassTest, run_input_U8_NEG)
{
  luci::QuantizationAlgorithmParams params;
  params.type = luci::QuantizationAlgorithmType::Base;
  luci::QuantizeWeightsPass pass(loco::DataType::U8, loco::DataType::S8,
                                 luci::QuantizationGranularity::ChannelWise, params);
  EXPECT_THROW(pass.run(&_g), std::runtime_error);
}

// Negative case: FLOAT32 output dtype is not a valid quantization target.
// Fixes two defects in the pasted span: the diff carried both the pre- and
// post-change constructor call lines (invalid C++), and a local `loco::Graph g`
// was declared but never used (the test runs on the fixture member `_g`).
TEST_F(QuantizeWeightsPassTest, run_output_f32_NEG)
{
  luci::QuantizationAlgorithmParams params;
  params.type = luci::QuantizationAlgorithmType::Base;
  luci::QuantizeWeightsPass pass(loco::DataType::FLOAT32, loco::DataType::FLOAT32,
                                 luci::QuantizationGranularity::ChannelWise, params);
  EXPECT_THROW(pass.run(&_g), std::runtime_error);
}

0 comments on commit 28f6b16

Please sign in to comment.