Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[one-optimize] Fuse Mul with FullyConnected layer #13528

Closed
wants to merge 48 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
28d8089
[one-optimize] Fuse Mul with FullyConnected layer
jiwaszki Jul 26, 2024
87774c5
Move mul_with_fully_connected pass after the mul_with_div
jiwaszki Jul 29, 2024
4c7f5a9
Remove weights constant check
jiwaszki Jul 29, 2024
cfcb68b
Change order of updating the nodes, more consuming one is now later
jiwaszki Jul 29, 2024
e8b06b5
Fix values updating and add luci tests
jiwaszki Aug 1, 2024
3bf2649
Fix codestyle
jiwaszki Aug 2, 2024
d386541
Rename pass
jiwaszki Aug 5, 2024
4180734
Add luci tests with models
jiwaszki Aug 5, 2024
761d303
Fix scalar vs multi-dim case
jiwaszki Aug 5, 2024
fa4733b
Separate bias and weights updating, remove checks
jiwaszki Aug 6, 2024
40dcacf
[luci/pass] Introduce FuseMulWithFullyConnectedPass
jiwaszki Aug 6, 2024
b717234
[one-cmds] Add an option for FuseMulWithFullyConnectedPass
jiwaszki Aug 6, 2024
a568d25
[circle2circle] Dredd test for FuseMulWithFullyConnectedPass
jiwaszki Aug 7, 2024
d3246e3
[luci/pass] Value test for FuseMulWithFullyConnectedPass
jiwaszki Aug 7, 2024
f661561
Change constness of args, move tests and move FuseMulWithFC after Fus…
jiwaszki Aug 7, 2024
85d9783
Fix codestyle
jiwaszki Aug 7, 2024
51dd43c
Fix order of cmds
jiwaszki Aug 7, 2024
e3b354e
Remove default arguments
jiwaszki Aug 8, 2024
ffc36e9
Remove default args
jiwaszki Aug 8, 2024
835126a
Merge remote-tracking branch 'upstream/master' into jiwaszki/fuse_mul_fc
jiwaszki Aug 9, 2024
31e25ed
Refactor solution and apply comments
jiwaszki Aug 9, 2024
396d733
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki Aug 9, 2024
d5ec1d8
Merge branch 'jiwaszki/fuse_mul_fc_one_cmds' into jiwaszki/fuse_mul_fc
jiwaszki Aug 9, 2024
8b17f47
Add handling of no bias case to pass
jiwaszki Aug 9, 2024
8d90e50
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki Aug 9, 2024
715cdf7
Remove random newline
jiwaszki Aug 9, 2024
9e22b26
Apply comments, refactor tests and add proper handling of OUTPUTEXCLUDE
jiwaszki Aug 12, 2024
62a09a0
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki Aug 12, 2024
1b6c71f
Resolve one-cmds duplication
jiwaszki Aug 12, 2024
8977ef9
Handle rank 0 and 1
jiwaszki Aug 12, 2024
dbed1b9
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki Aug 12, 2024
53aa943
Add new testcase
jiwaszki Aug 12, 2024
0c2bb71
Add new testcase
jiwaszki Aug 12, 2024
e3ff517
[res/tfl_recipes] Add new Net_FullyConnected_Mul
jiwaszki Aug 12, 2024
d6e8b4a
Merge branch 'jiwaszki/fuse_mul_fc_new_tfl_recipes' into jiwaszki/fus…
jiwaszki Aug 12, 2024
b4ebd44
Merge branch 'jiwaszki/fuse_mul_fc_c2c_dredd' into jiwaszki/fuse_mul_fc
jiwaszki Aug 12, 2024
cddd353
Merge branch 'jiwaszki/fuse_mul_fc_luci_test' into jiwaszki/fuse_mul_fc
jiwaszki Aug 12, 2024
678869c
Change name of operand from B to scale
jiwaszki Aug 13, 2024
af3119d
Merge branch 'jiwaszki/fuse_mul_fc_new_tfl_recipes' into jiwaszki/fus…
jiwaszki Aug 13, 2024
b085181
Update names from scalar to single element
jiwaszki Aug 13, 2024
1aa79cc
Update tests
jiwaszki Aug 13, 2024
f02cb88
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki Aug 13, 2024
1bb278d
Fix codestyle
jiwaszki Aug 13, 2024
27dec03
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki Aug 13, 2024
79a2213
Search from mul, update tests
jiwaszki Aug 14, 2024
7ea759a
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki Aug 14, 2024
550e798
Annotate requirement of one successor and refactor checks
jiwaszki Aug 19, 2024
bda96d8
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki Aug 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions compiler/circle2circle-dredd-recipe-test/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ Add(Net_DwConv_BN_000 PASS fuse_batchnorm_with_dwconv)
Add(Net_DwConv_BN_001 PASS fuse_batchnorm_with_dwconv)
Add(Net_FC_Gelu_FC_000 PASS replace_with_fc_gelu_fc)
Add(Net_FullyConnected_Add_000 PASS fold_fully_connected)
Add(Net_FullyConnected_Mul_000 PASS fuse_mul_with_fullyconnected)
Add(Net_FullyConnected_Mul_001 PASS fuse_mul_with_fullyconnected)
Add(Net_FullyConnected_Mul_002 PASS fuse_mul_with_fullyconnected)
Add(Net_FullyConnected_Mul_003 PASS fuse_mul_with_fullyconnected)
Add(Net_Gelu_000 PASS fuse_gelu)
Add(Net_Gelu_001 PASS fuse_gelu)
Add(Net_Horizontal_FullyConnected_Add_000 PASS fuse_horizontal_fc_layers)
Expand Down
4 changes: 4 additions & 0 deletions compiler/circle2circle/src/Circle2Circle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ int entry(int argc, char **argv)
"This will fuse Mul operation with a preceding Conv if possible.");
add_switch(arser, "--fuse_mul_with_div",
"This will fuse Mul operation with a Div operation whose numerator is const.");
add_switch(arser, "--fuse_mul_with_fullyconnected",
"This will fuse Mul operator with a preceding FullyConnected operator.");
add_switch(arser, "--fuse_slice_with_tconv",
"This will fuse Slice operation with a preceding TConv if possible.");
add_switch(arser, "--fuse_transpose_with_mean",
Expand Down Expand Up @@ -326,6 +328,8 @@ int entry(int argc, char **argv)
options->enable(Algorithms::FuseMulWithConv);
if (arser.get<bool>("--fuse_mul_with_div"))
options->enable(Algorithms::FuseMulWithDiv);
if (arser.get<bool>("--fuse_mul_with_fullyconnected"))
options->enable(Algorithms::FuseMulWithFullyConnected);
if (arser.get<bool>("--make_batchnorm_gamma_positive"))
options->enable(Algorithms::MakeBatchNormGammaPositive);
if (arser.get<bool>("--fuse_preactivation_batchnorm"))
Expand Down
4 changes: 4 additions & 0 deletions compiler/luci-pass-value-py-test/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ eval(Net_Dequantize_Add_000 fold_dequantize)
eval(Net_DwConv_BN_000 fuse_batchnorm_with_dwconv)
eval(Net_DwConv_BN_001 fuse_batchnorm_with_dwconv)
eval(Net_FullyConnected_Add_000 fold_fully_connected)
eval(Net_FullyConnected_Mul_000 fuse_mul_with_fullyconnected)
eval(Net_FullyConnected_Mul_001 fuse_mul_with_fullyconnected)
eval(Net_FullyConnected_Mul_002 fuse_mul_with_fullyconnected)
eval(Net_FullyConnected_Mul_003 fuse_mul_with_fullyconnected)
eval(Net_Horizontal_FullyConnected_Add_000 fuse_horizontal_fc_layers)
eval(Net_InstanceNorm_001 fuse_instnorm)
eval(Net_InstanceNorm_002 fuse_instnorm)
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/pass/include/luci/CircleOptimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class CircleOptimizer final
FuseMeanWithMean,
FuseMulWithConv,
FuseMulWithDiv,
FuseMulWithFullyConnected,
FuseTransposeWithMean,
ResolveCustomOpAdd,
ResolveCustomOpBatchMatMul,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_FUSE_MUL_WITH_FULLYCONNECTED_PASS_H__
#define __LUCI_FUSE_MUL_WITH_FULLYCONNECTED_PASS_H__

#include <logo/Pass.h>

namespace luci
{

/**
* @brief Class to fuse Mul into CircleFullyConnected
*/
struct FuseMulWithFullyConnectedPass final : public logo::Pass
{
const char *name(void) const final { return "luci::FuseMulWithFullyConnectedPass"; }

bool run(loco::Graph *g) final;
};

} // namespace luci

#endif // __LUCI_FUSE_MUL_WITH_FULLYCONNECTED_PASS_H__
9 changes: 9 additions & 0 deletions compiler/luci/pass/src/CircleOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include "luci/Pass/FuseMeanWithMeanPass.h"
#include "luci/Pass/FuseMulWithConvPass.h"
#include "luci/Pass/FuseMulWithDivPass.h"
#include "luci/Pass/FuseMulWithFullyConnectedPass.h"
#include "luci/Pass/FusePreActivationBatchNormPass.h"
#include "luci/Pass/FusePReluPass.h"
#include "luci/Pass/FuseGeluPass.h"
Expand Down Expand Up @@ -278,6 +279,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());

if (_options->query(Options::Algorithm::FuseMulWithFullyConnected))
{
phase.emplace_back(std::make_unique<FuseMulWithFullyConnectedPass>());
}
if (_options->query(Options::Algorithm::CommonSubExpressionElimination))
{
phase.emplace_back(std::make_unique<luci::CommonSubExpressionEliminationPass>());
Expand Down Expand Up @@ -310,6 +315,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<FuseMulWithDivPass>());
}
if (_options->query(Options::Algorithm::FuseMulWithFullyConnected))
{
phase.emplace_back(std::make_unique<FuseMulWithFullyConnectedPass>());
}
if (_options->query(Options::Algorithm::ResolveCustomOpMaxPoolWithArgmax))
{
phase.emplace_back(std::make_unique<luci::ResolveCustomOpMaxPoolWithArgmaxPass>());
Expand Down
239 changes: 239 additions & 0 deletions compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
/*
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plz add FuseMulWithFullyConnectedPass.test.cpp

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the unit test, you should add count(positive) <= count(negative) where negative method name ends with _NEG

* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/Pass/FuseMulWithFullyConnectedPass.h"

#include "helpers/NodeFiller.h"

#include <luci/IR/CircleNodes.h>
#include <luci/Service/Nodes/CircleConst.h>
#include <luci/Profile/CircleNodeOrigin.h>

namespace
{

#define RETURN_FALSE_UNLESS(cond) \
if (not(cond)) \
return false;

inline bool is_single_element(const luci::CircleConst *node)
{
return ((node->rank() == 1 || node->rank() == 0) && node->size<loco::DataType::FLOAT32>() == 1);
}

inline void update_with_single_element(luci::CircleConst *fused_node,
const luci::CircleConst *multiplication)
{
for (uint32_t i = 0; i < fused_node->size<loco::DataType::FLOAT32>(); i++)
{
fused_node->at<loco::DataType::FLOAT32>(i) *= multiplication->at<loco::DataType::FLOAT32>(0);
}
}

luci::CircleConst *gen_fused_weights(luci::CircleConst *weights,
const luci::CircleConst *multiplication)
{
auto fused_weights = luci::clone(weights);
// Single element multiplication:
if (is_single_element(multiplication))
{
update_with_single_element(fused_weights, multiplication);
}
// N-size multiplication:
else
{
// Go along channels, multiplication size is ensured to be compatible with channels.
auto count = fused_weights->dim(0).value();
auto size = fused_weights->dim(fused_weights->rank() - 1).value();
float val;
for (uint32_t c = 0; c < count; c++)
{
val = multiplication->at<loco::DataType::FLOAT32>(c);
for (uint32_t i = 0; i < size; i++)
{
fused_weights->at<loco::DataType::FLOAT32>(c * size + i) *= val;
}
}
}
return fused_weights;
}

luci::CircleConst *gen_fused_bias(luci::CircleConst *bias, const luci::CircleConst *multiplication)
{
auto fused_bias = luci::clone(bias);
// Single element multiplication:
if (is_single_element(multiplication))
{
update_with_single_element(fused_bias, multiplication);
}
// N-size multiplication:
else
{
// Go along channels, multiplication size is ensured to be compatible with channels.
for (uint32_t i = 0; i < fused_bias->size<loco::DataType::FLOAT32>(); i++)
{
fused_bias->at<loco::DataType::FLOAT32>(i) *= multiplication->at<loco::DataType::FLOAT32>(i);
}
}
return fused_bias;
}

/**
* Fuse Mul to FullyConnected if the multiplied value is a channel(last dimension)-wise constant
*
* BEFORE
* |
* [CircleFullyConnected]
* |
* [CircleMul]
* |
*
* AFTER
* |
* [CircleFullyConnected] [CircleMul] (dead)
* |
*
*/
bool fuse_mul_with_fc(luci::CircleMul *mul)
{
// Sanity check:
RETURN_FALSE_UNLESS(mul);
// Allow Mul node only with FLOAT32 data type:
RETURN_FALSE_UNLESS(mul->dtype() == loco::DataType::FLOAT32);
// Check if any FC node connects to Mul.
// Find the pattern of Mul(FC, CircleConst):
luci::CircleFullyConnected *fc = nullptr;
luci::CircleConst *multiplication = nullptr;
RETURN_FALSE_UNLESS(luci::fill(&fc, &multiplication).with_commutative_args_of(mul));
/**
* Make sure that FullyConnected has only one successor.
*
* If the FullyConnected output is connected to more nodes,
* this pass will replace node with new fused FullyConnected.
* Thus pass success will only introduce extra FullyConnected
* without reducing overall number of nodes.
* Which tends to increase model's size and degrades model's performance.
* Thus one successor is required to benefit from this pass.
*
* Example graph that illustrates the described scenario:
*
* BEFORE
* |
* [CircleFullyConnected]
* |
* +-------+----------------+
* | |
* | |
* [Other Node] [CircleMul]
* | |
*
* AFTER
* |
* [CircleFullyConnected]
* |
* +-------+-----------------------+
* | |
* | |
* [Other Node] [New CircleFullyConnected Fused with Mul]
* | |
*
*/
RETURN_FALSE_UNLESS(loco::succs(fc).size() == 1);
// Allow only FLOAT32 data type:
RETURN_FALSE_UNLESS(fc->dtype() == loco::DataType::FLOAT32);
// Allow only without activation functions as values are going to
// be multiplied before activation function.
RETURN_FALSE_UNLESS(fc->fusedActivationFunction() == luci::FusedActFunc::NONE);
// Check for weights being Constant:
auto weights = dynamic_cast<luci::CircleConst *>(fc->weights());
RETURN_FALSE_UNLESS(weights);
// Get rank of multiplication:
auto rank = multiplication->rank();
// Check that all dimensions are ones, checks broadcast capabilites.
// Last dimesion of multiplication must be compatible with FC.
// N-D case (N>1):
if (multiplication->rank() > 1)
{
// Check channel-wise broadcasting:
for (uint32_t i = 0; i < rank - 1; i++)
RETURN_FALSE_UNLESS(multiplication->dim(i).value() == 1);
// Check the last dimesion of Mul is the same with the first dimension of FullyConnected
RETURN_FALSE_UNLESS(multiplication->dim(rank - 1) == weights->dim(0));
}
// 1-D or scalar case:
else if (multiplication->rank() == 1)
{
RETURN_FALSE_UNLESS(multiplication->size<loco::DataType::FLOAT32>() == 1 ||
multiplication->size<loco::DataType::FLOAT32>() == weights->dim(0));
}
else if (multiplication->rank() == 0)
{
RETURN_FALSE_UNLESS(multiplication->size<loco::DataType::FLOAT32>() == 1);
}

// Only supports:
// (1) constant bias
// (2) no bias
auto bias = loco::must_cast<luci::CircleNode *>(fc->bias());
if (bias->opcode() == luci::CircleOpcode::CIRCLECONST)
{
// Create new bias to be updated with values:
auto const_bias = dynamic_cast<luci::CircleConst *>(fc->bias());
RETURN_FALSE_UNLESS(const_bias)
RETURN_FALSE_UNLESS(const_bias->dtype() == loco::DataType::FLOAT32);
// Create new bias with updated values and replace:
auto fused_bias = gen_fused_bias(const_bias, multiplication);
fc->bias(fused_bias);
}
else if (bias->opcode() != luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE)
{
return false;
}

// Create new weights with updated values and replace:
auto fused_weights = gen_fused_weights(weights, multiplication);
fc->weights(fused_weights);

// Set origin and copy Activation Function if exisitng:
fc->fusedActivationFunction(mul->fusedActivationFunction());
luci::add_origin(fc, luci::get_origin(mul));

replace(mul).with(fc);

return true;
}

} // namespace

namespace luci
{

bool FuseMulWithFullyConnectedPass::run(loco::Graph *g)
{
bool changed = false;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
if (auto mul = dynamic_cast<luci::CircleMul *>(node))
{
if (fuse_mul_with_fc(mul))
changed = true;
}
}

return changed;
}

} // namespace luci
Loading