-
Notifications
You must be signed in to change notification settings - Fork 159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[one-optimize] Fuse Mul with FullyConnected layer #13528
Closed
Closed
Changes from all commits
Commits
Show all changes
48 commits
Select commit
Hold shift + click to select a range
28d8089
[one-optimize] Fuse Mul with FullyConnected layer
jiwaszki 87774c5
Move mul_with_fully_connected pass after the mul_with_div
jiwaszki 4c7f5a9
Remove weights constant check
jiwaszki cfcb68b
Change order of updating the nodes, more consuming one is now later
jiwaszki e8b06b5
Fix values updating and add luci tests
jiwaszki 3bf2649
Fix codestyle
jiwaszki d386541
Rename pass
jiwaszki 4180734
Add luci tests with models
jiwaszki 761d303
Fix scalar vs multi-dim case
jiwaszki fa4733b
Separate bias and weights updating, remove checks
jiwaszki 40dcacf
[luci/pass] Introduce FuseMulWithFullyConnectedPass
jiwaszki b717234
[one-cmds] Add an option for FuseMulWithFullyConnectedPass
jiwaszki a568d25
[circle2circle] Dredd test for FuseMulWithFullyConnectedPass
jiwaszki d3246e3
[luci/pass] Value test for FuseMulWithFullyConnectedPass
jiwaszki f661561
Change constness of args, move tests and move FuseMulWithFC after Fus…
jiwaszki 85d9783
Fix codestyle
jiwaszki 51dd43c
Fix order of cmds
jiwaszki e3b354e
Remove default arguments
jiwaszki ffc36e9
Remove default args
jiwaszki 835126a
Merge remote-tracking branch 'upstream/master' into jiwaszki/fuse_mul_fc
jiwaszki 31e25ed
Refactor solution and apply comments
jiwaszki 396d733
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki d5ec1d8
Merge branch 'jiwaszki/fuse_mul_fc_one_cmds' into jiwaszki/fuse_mul_fc
jiwaszki 8b17f47
Add handling of no bias case to pass
jiwaszki 8d90e50
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki 715cdf7
Remove random newline
jiwaszki 9e22b26
Apply comments, refactor tests and add proper handling of OUTPUTEXCLUDE
jiwaszki 62a09a0
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki 1b6c71f
Resolve one-cmds duplication
jiwaszki 8977ef9
Handle rank 0 and 1
jiwaszki dbed1b9
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki 53aa943
Add new testcase
jiwaszki 0c2bb71
Add new testcase
jiwaszki e3ff517
[res/tfl_recipes] Add new Net_FullyConnected_Mul
jiwaszki d6e8b4a
Merge branch 'jiwaszki/fuse_mul_fc_new_tfl_recipes' into jiwaszki/fus…
jiwaszki b4ebd44
Merge branch 'jiwaszki/fuse_mul_fc_c2c_dredd' into jiwaszki/fuse_mul_fc
jiwaszki cddd353
Merge branch 'jiwaszki/fuse_mul_fc_luci_test' into jiwaszki/fuse_mul_fc
jiwaszki 678869c
Change name of operand from B to scale
jiwaszki af3119d
Merge branch 'jiwaszki/fuse_mul_fc_new_tfl_recipes' into jiwaszki/fus…
jiwaszki b085181
Update names from scalar to single element
jiwaszki 1aa79cc
Update tests
jiwaszki f02cb88
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki 1bb278d
Fix codestyle
jiwaszki 27dec03
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki 79a2213
Search from mul, update tests
jiwaszki 7ea759a
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki 550e798
Annotate requirement of one successor and refactor checks
jiwaszki bda96d8
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
jiwaszki File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
37 changes: 37 additions & 0 deletions
37
compiler/luci/pass/include/luci/Pass/FuseMulWithFullyConnectedPass.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#ifndef __LUCI_FUSE_MUL_WITH_FULLYCONNECTED_PASS_H__ | ||
#define __LUCI_FUSE_MUL_WITH_FULLYCONNECTED_PASS_H__ | ||
|
||
#include <logo/Pass.h> | ||
|
||
namespace luci | ||
{ | ||
|
||
/** | ||
* @brief Class to fuse Mul into CircleFullyConnected | ||
*/ | ||
struct FuseMulWithFullyConnectedPass final : public logo::Pass | ||
{ | ||
const char *name(void) const final { return "luci::FuseMulWithFullyConnectedPass"; } | ||
|
||
bool run(loco::Graph *g) final; | ||
}; | ||
|
||
} // namespace luci | ||
|
||
#endif // __LUCI_FUSE_MUL_WITH_FULLYCONNECTED_PASS_H__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
239 changes: 239 additions & 0 deletions
239
compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,239 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "luci/Pass/FuseMulWithFullyConnectedPass.h" | ||
|
||
#include "helpers/NodeFiller.h" | ||
|
||
#include <luci/IR/CircleNodes.h> | ||
#include <luci/Service/Nodes/CircleConst.h> | ||
#include <luci/Profile/CircleNodeOrigin.h> | ||
|
||
namespace | ||
{ | ||
|
||
#define RETURN_FALSE_UNLESS(cond) \ | ||
if (not(cond)) \ | ||
return false; | ||
|
||
inline bool is_single_element(const luci::CircleConst *node) | ||
{ | ||
return ((node->rank() == 1 || node->rank() == 0) && node->size<loco::DataType::FLOAT32>() == 1); | ||
} | ||
|
||
inline void update_with_single_element(luci::CircleConst *fused_node, | ||
const luci::CircleConst *multiplication) | ||
{ | ||
for (uint32_t i = 0; i < fused_node->size<loco::DataType::FLOAT32>(); i++) | ||
{ | ||
fused_node->at<loco::DataType::FLOAT32>(i) *= multiplication->at<loco::DataType::FLOAT32>(0); | ||
} | ||
} | ||
|
||
luci::CircleConst *gen_fused_weights(luci::CircleConst *weights, | ||
const luci::CircleConst *multiplication) | ||
{ | ||
auto fused_weights = luci::clone(weights); | ||
// Single element multiplication: | ||
if (is_single_element(multiplication)) | ||
{ | ||
update_with_single_element(fused_weights, multiplication); | ||
} | ||
// N-size multiplication: | ||
else | ||
{ | ||
// Go along channels, multiplication size is ensured to be compatible with channels. | ||
auto count = fused_weights->dim(0).value(); | ||
auto size = fused_weights->dim(fused_weights->rank() - 1).value(); | ||
float val; | ||
for (uint32_t c = 0; c < count; c++) | ||
{ | ||
val = multiplication->at<loco::DataType::FLOAT32>(c); | ||
for (uint32_t i = 0; i < size; i++) | ||
{ | ||
fused_weights->at<loco::DataType::FLOAT32>(c * size + i) *= val; | ||
} | ||
} | ||
} | ||
return fused_weights; | ||
} | ||
|
||
luci::CircleConst *gen_fused_bias(luci::CircleConst *bias, const luci::CircleConst *multiplication) | ||
{ | ||
auto fused_bias = luci::clone(bias); | ||
// Single element multiplication: | ||
if (is_single_element(multiplication)) | ||
{ | ||
update_with_single_element(fused_bias, multiplication); | ||
} | ||
// N-size multiplication: | ||
else | ||
{ | ||
// Go along channels, multiplication size is ensured to be compatible with channels. | ||
for (uint32_t i = 0; i < fused_bias->size<loco::DataType::FLOAT32>(); i++) | ||
{ | ||
fused_bias->at<loco::DataType::FLOAT32>(i) *= multiplication->at<loco::DataType::FLOAT32>(i); | ||
} | ||
} | ||
return fused_bias; | ||
} | ||
|
||
/** | ||
* Fuse Mul to FullyConnected if the multiplied value is a channel(last dimension)-wise constant | ||
* | ||
* BEFORE | ||
* | | ||
* [CircleFullyConnected] | ||
* | | ||
* [CircleMul] | ||
* | | ||
* | ||
* AFTER | ||
* | | ||
* [CircleFullyConnected] [CircleMul] (dead) | ||
* | | ||
* | ||
*/ | ||
bool fuse_mul_with_fc(luci::CircleMul *mul) | ||
{ | ||
// Sanity check: | ||
RETURN_FALSE_UNLESS(mul); | ||
// Allow Mul node only with FLOAT32 data type: | ||
RETURN_FALSE_UNLESS(mul->dtype() == loco::DataType::FLOAT32); | ||
// Check if any FC node connects to Mul. | ||
// Find the pattern of Mul(FC, CircleConst): | ||
luci::CircleFullyConnected *fc = nullptr; | ||
luci::CircleConst *multiplication = nullptr; | ||
RETURN_FALSE_UNLESS(luci::fill(&fc, &multiplication).with_commutative_args_of(mul)); | ||
/** | ||
* Make sure that FullyConnected has only one successor. | ||
* | ||
* If the FullyConnected output is connected to more nodes, | ||
* this pass will replace node with new fused FullyConnected. | ||
* Thus pass success will only introduce extra FullyConnected | ||
* without reducing overall number of nodes. | ||
* Which tends to increase model's size and degrades model's performance. | ||
* Thus one successor is required to benefit from this pass. | ||
* | ||
* Example graph that illustrates the described scenario: | ||
* | ||
* BEFORE | ||
* | | ||
* [CircleFullyConnected] | ||
* | | ||
* +-------+----------------+ | ||
* | | | ||
* | | | ||
* [Other Node] [CircleMul] | ||
* | | | ||
* | ||
* AFTER | ||
* | | ||
* [CircleFullyConnected] | ||
* | | ||
* +-------+-----------------------+ | ||
* | | | ||
* | | | ||
* [Other Node] [New CircleFullyConnected Fused with Mul] | ||
* | | | ||
* | ||
*/ | ||
RETURN_FALSE_UNLESS(loco::succs(fc).size() == 1); | ||
// Allow only FLOAT32 data type: | ||
RETURN_FALSE_UNLESS(fc->dtype() == loco::DataType::FLOAT32); | ||
// Allow only without activation functions as values are going to | ||
// be multiplied before activation function. | ||
RETURN_FALSE_UNLESS(fc->fusedActivationFunction() == luci::FusedActFunc::NONE); | ||
// Check for weights being Constant: | ||
auto weights = dynamic_cast<luci::CircleConst *>(fc->weights()); | ||
RETURN_FALSE_UNLESS(weights); | ||
// Get rank of multiplication: | ||
auto rank = multiplication->rank(); | ||
// Check that all dimensions are ones, checks broadcast capabilites. | ||
// Last dimesion of multiplication must be compatible with FC. | ||
// N-D case (N>1): | ||
if (multiplication->rank() > 1) | ||
{ | ||
// Check channel-wise broadcasting: | ||
for (uint32_t i = 0; i < rank - 1; i++) | ||
RETURN_FALSE_UNLESS(multiplication->dim(i).value() == 1); | ||
// Check the last dimesion of Mul is the same with the first dimension of FullyConnected | ||
RETURN_FALSE_UNLESS(multiplication->dim(rank - 1) == weights->dim(0)); | ||
} | ||
// 1-D or scalar case: | ||
else if (multiplication->rank() == 1) | ||
{ | ||
RETURN_FALSE_UNLESS(multiplication->size<loco::DataType::FLOAT32>() == 1 || | ||
multiplication->size<loco::DataType::FLOAT32>() == weights->dim(0)); | ||
} | ||
else if (multiplication->rank() == 0) | ||
{ | ||
RETURN_FALSE_UNLESS(multiplication->size<loco::DataType::FLOAT32>() == 1); | ||
} | ||
|
||
// Only supports: | ||
// (1) constant bias | ||
// (2) no bias | ||
auto bias = loco::must_cast<luci::CircleNode *>(fc->bias()); | ||
if (bias->opcode() == luci::CircleOpcode::CIRCLECONST) | ||
{ | ||
// Create new bias to be updated with values: | ||
auto const_bias = dynamic_cast<luci::CircleConst *>(fc->bias()); | ||
RETURN_FALSE_UNLESS(const_bias) | ||
RETURN_FALSE_UNLESS(const_bias->dtype() == loco::DataType::FLOAT32); | ||
// Create new bias with updated values and replace: | ||
auto fused_bias = gen_fused_bias(const_bias, multiplication); | ||
fc->bias(fused_bias); | ||
} | ||
else if (bias->opcode() != luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE) | ||
{ | ||
return false; | ||
} | ||
|
||
// Create new weights with updated values and replace: | ||
auto fused_weights = gen_fused_weights(weights, multiplication); | ||
fc->weights(fused_weights); | ||
|
||
// Set origin and copy Activation Function if exisitng: | ||
fc->fusedActivationFunction(mul->fusedActivationFunction()); | ||
luci::add_origin(fc, luci::get_origin(mul)); | ||
|
||
replace(mul).with(fc); | ||
|
||
return true; | ||
} | ||
|
||
} // namespace | ||
|
||
namespace luci | ||
{ | ||
|
||
bool FuseMulWithFullyConnectedPass::run(loco::Graph *g) | ||
{ | ||
bool changed = false; | ||
for (auto node : loco::active_nodes(loco::output_nodes(g))) | ||
{ | ||
if (auto mul = dynamic_cast<luci::CircleMul *>(node)) | ||
{ | ||
if (fuse_mul_with_fc(mul)) | ||
changed = true; | ||
} | ||
} | ||
|
||
return changed; | ||
} | ||
|
||
} // namespace luci |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
plz add
FuseMulWithFullyConnectedPass.test.cpp
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in the unit test, you should add
count(positive) <= count(negative)
where negative method name ends with_NEG