save more foldable constants in file to reduce memory usage #4337

Merged · 1 commit · Nov 8, 2022
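Summary: previously, every foldable constant tensor discovered during shape inference was held in memory for the whole conversion in a std::map<std::string, Attribute>, filled by pass_level0 and consumed by pass_level5. This patch keeps only the constant names in a std::set<std::string> and streams each tensor's raw bytes into a temporary archive (ptbase + ".foldable_constants.zip") with StoreZipWriter as they are found; fold_constants in pass_level5 reads them back one at a time with StoreZipReader, and main() deletes the archive afterwards. Peak memory use therefore no longer grows with the total size of the foldable constants.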
1 change: 1 addition & 0 deletions tools/pnnx/src/ir.h
@@ -17,6 +17,7 @@
 
 #include <initializer_list>
 #include <map>
+#include <set>
 #include <string>
 #include <vector>
 
10 changes: 7 additions & 3 deletions tools/pnnx/src/main.cpp
@@ -368,8 +368,9 @@ int main(int argc, char** argv)
 
     fprintf(stderr, "############# pass_level0\n");
 
-    std::map<std::string, pnnx::Attribute> foldable_constants;
-    pnnx::pass_level0(mod, g, input_tensors, input_tensors2, module_operators, ptpath, device, foldable_constants);
+    std::set<std::string> foldable_constants;
+    std::string foldable_constants_zippath = ptbase + ".foldable_constants.zip";
+    pnnx::pass_level0(mod, g, input_tensors, input_tensors2, module_operators, ptpath, device, foldable_constants, foldable_constants_zippath);
 
     // g->dump();
 
@@ -403,9 +404,12 @@
     {
         fprintf(stderr, "############# pass_level5\n");
 
-        pnnx::pass_level5(pnnx_graph, foldable_constants);
+        pnnx::pass_level5(pnnx_graph, foldable_constants, foldable_constants_zippath);
     }
 
+    // delete foldable_constants_zippath
+    remove(foldable_constants_zippath.c_str());
+
     pnnx_graph.save(pnnxparampath, pnnxbinpath);
 
     pnnx_graph.python(pnnxpypath, pnnxbinpath);
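The main.cpp hunks above define the whole lifecycle of the temporary archive. A minimal sketch of that flow, with the pass bodies elided and a hypothetical convert() wrapper standing in for the real main():

```cpp
#include <cstdio> // remove()
#include <set>
#include <string>

// Hypothetical wrapper for illustration; the real logic lives in pnnx's main().
void convert(const std::string& ptbase)
{
    // only the names stay in memory; the tensor payloads live on disk
    std::set<std::string> foldable_constants;
    std::string foldable_constants_zippath = ptbase + ".foldable_constants.zip";

    // pass_level0 / shape_inference: writes each foldable tensor into the zip
    // pass_level5 / fold_constants:  reads them back one at a time

    // the archive is scratch data; remove it once folding is done
    remove(foldable_constants_zippath.c_str());
}
```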
4 changes: 2 additions & 2 deletions tools/pnnx/src/pass_level0.cpp
@@ -21,7 +21,7 @@
 
 namespace pnnx {
 
-void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, const std::string& device, std::map<std::string, Attribute>& foldable_constants)
+void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, const std::string& device, std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath)
 {
     inline_block(g, module_operators);
 
@@ -31,7 +31,7 @@ void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Grap
 
     if (!input_tensors.empty())
     {
-        shape_inference(mod, g, input_tensors, input_tensors2, module_operators, ptpath, device, foldable_constants);
+        shape_inference(mod, g, input_tensors, input_tensors2, module_operators, ptpath, device, foldable_constants, foldable_constants_zippath);
     }
 }
 
2 changes: 1 addition & 1 deletion tools/pnnx/src/pass_level0.h
@@ -20,7 +20,7 @@
 
 namespace pnnx {
 
-void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, const std::string& device, std::map<std::string, Attribute>& foldable_constants);
+void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, const std::string& device, std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath);
 
 } // namespace pnnx
 
59 changes: 27 additions & 32 deletions tools/pnnx/src/pass_level0/shape_inference.cpp
@@ -15,6 +15,7 @@
 #include "shape_inference.h"
 #include <unordered_set>
 
+#include "storezip.h"
 #include "pass_level0/constant_unpooling.h"
 #include "pass_level0/inline_block.h"
 #include "pass_level0/reset_device.h"
@@ -78,7 +79,7 @@ static bool value_link_output(const torch::jit::Value* v, const std::vector<torc
     return false;
 }
 
-void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, const std::string& device, std::map<std::string, Attribute>& foldable_constants)
+void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, const std::string& device, std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath)
 {
     // collect all intermediate output tensors
     std::vector<std::unordered_set<std::string> > more_value_names;
@@ -142,7 +143,8 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::
         inputs2.push_back(it);
     }
 
-    std::map<torch::jit::Value*, at::Tensor> output_tensors;
+    StoreZipWriter zip;
+    zip.open(foldable_constants_zippath);
 
     for (size_t p = 0; p < more_value_names.size(); p++)
     {
@@ -174,7 +176,7 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::
             if (value_names.find(v->debugName()) != value_names.end())
             {
                 values2.push_back(v);
-                fprintf(stderr, "%s ", v->debugName().c_str());
+                // fprintf(stderr, "%s ", v->debugName().c_str());
             }
         }
     }
@@ -204,7 +206,16 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::
             // check if value that does not depend on inputs
             if (!value_link_input(v, g_inputs, true) && value_link_output(v, g_outputs))
             {
-                output_tensors[v] = t;
+                // output_tensors[v] = t;
+                const int ndim = (int)t.dim();
+                if (ndim > 0)
+                {
+                    // fprintf(stderr, "foldable_constant %s\n", v->debugName().c_str());
+                    foldable_constants.insert(v->debugName());
+
+                    at::Tensor t2 = t.cpu().contiguous();
+                    zip.write_file(v->debugName(), (const char*)t2.data_ptr(), t2.nbytes());
+                }
             }
         }
     }
@@ -242,12 +253,23 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::
             // check if value that does not depend on inputs
             if (!value_link_input(v, g_inputs, false) && value_link_output(v, g_outputs))
            {
-                output_tensors[v] = t;
+                // output_tensors[v] = t;
+                const int ndim = (int)t.dim();
+                if (ndim > 0)
+                {
+                    // fprintf(stderr, "foldable_constant %s\n", v->debugName().c_str());
+                    foldable_constants.insert(v->debugName());
+
+                    at::Tensor t2 = t.cpu().contiguous();
+                    zip.write_file(v->debugName(), (const char*)t2.data_ptr(), t2.nbytes());
+                }
             }
         }
     }
 
 
+    zip.close();
+
     if (input_tensors2.empty())
     {
         for (size_t i = 0; i < input_tensors.size(); i++)
@@ -280,33 +302,6 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::
             graph->inputs()[1 + i]->setType(finaltype);
         }
     }
-
-    for (auto xx : output_tensors)
-    {
-        auto v = xx.first;
-        auto tensor = xx.second;
-
-        bool link_to_output = false;
-        for (size_t i = 0; i < v->uses().size(); i++)
-        {
-            auto node = v->uses()[i].user;
-            for (auto x : node->outputs())
-            {
-                if (output_tensors.find(x) == output_tensors.end())
-                {
-                    link_to_output = true;
-                    break;
-                }
-            }
-        }
-
-        const int ndim = (int)tensor.dim();
-        if (link_to_output && ndim > 0)
-        {
-            fprintf(stderr, "foldable_constant %s\n", v->debugName().c_str());
-            foldable_constants[v->debugName()] = Attribute(tensor);
-        }
-    }
 }
 
 } // namespace pnnx
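Both write sites in shape_inference.cpp reduce to the same pattern: skip scalars, make the tensor a dense host buffer, then store its raw bytes under the value's debug name. A condensed sketch, assuming the StoreZipWriter interface used above (write_file taking a name, a pointer, and a byte count):

```cpp
#include <string>
#include <ATen/ATen.h>
#include "storezip.h"

// Record one foldable constant under its torch::jit::Value debug name.
static void write_foldable_constant(StoreZipWriter& zip, const std::string& name, const at::Tensor& t)
{
    // mirror the ndim > 0 check above: 0-dim tensors are not stored
    if (t.dim() == 0)
        return;

    // cpu() brings the data to host memory; contiguous() guarantees that
    // data_ptr() points at exactly nbytes() of densely packed elements
    at::Tensor t2 = t.cpu().contiguous();
    zip.write_file(name, (const char*)t2.data_ptr(), t2.nbytes());
}
```

Note that only raw bytes are archived; dtype and shape are deliberately not stored, since fold_constants can recover both from the graph operand later.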
2 changes: 1 addition & 1 deletion tools/pnnx/src/pass_level0/shape_inference.h
@@ -18,6 +18,6 @@
 
 namespace pnnx {
 
-void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, const std::string& device, std::map<std::string, Attribute>& foldable_constants);
+void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, const std::string& device, std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath);
 
 } // namespace pnnx
2 changes: 1 addition & 1 deletion tools/pnnx/src/pass_level3.cpp
@@ -35,7 +35,7 @@
 
 namespace pnnx {
 
-void pass_level3(Graph& g, const std::map<std::string, Attribute>& foldable_constants)
+void pass_level3(Graph& g, const std::set<std::string>& foldable_constants)
 {
     assign_unique_name(g);
 
2 changes: 1 addition & 1 deletion tools/pnnx/src/pass_level3.h
@@ -19,7 +19,7 @@
 
 namespace pnnx {
 
-void pass_level3(Graph& g, const std::map<std::string, Attribute>& foldable_constants);
+void pass_level3(Graph& g, const std::set<std::string>& foldable_constants);
 
 } // namespace pnnx
 
6 changes: 3 additions & 3 deletions tools/pnnx/src/pass_level3/fuse_expression.cpp
@@ -115,7 +115,7 @@ static bool operand_maybe_tensor(const Operand* operand)
     return true;
 }
 
-static bool operand_is_foldable(const Operand* operand, const std::map<std::string, Attribute>& foldable_constants)
+static bool operand_is_foldable(const Operand* operand, const std::set<std::string>& foldable_constants)
 {
     if (foldable_constants.find(operand->name) != foldable_constants.end())
         return true;
@@ -134,7 +134,7 @@ static bool operand_is_foldable(const Operand* operand, const std::map<std::stri
     return true;
 }
 
-static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, std::vector<Operand*>& inputs, const std::map<std::string, Attribute>& foldable_constants, bool checksubgraph = true)
+static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, std::vector<Operand*>& inputs, const std::set<std::string>& foldable_constants, bool checksubgraph = true)
 {
     // fprintf(stderr, "fuse_expression %s\n", operand->name.c_str());
 
@@ -412,7 +412,7 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s
     }
 }
 
-void fuse_expression(Graph& graph, const std::map<std::string, Attribute>& foldable_constants)
+void fuse_expression(Graph& graph, const std::set<std::string>& foldable_constants)
 {
     int pnnx_expr_index = 0;
 
2 changes: 1 addition & 1 deletion tools/pnnx/src/pass_level3/fuse_expression.h
@@ -16,6 +16,6 @@
 
 namespace pnnx {
 
-void fuse_expression(Graph& graph, const std::map<std::string, Attribute>& foldable_constants);
+void fuse_expression(Graph& graph, const std::set<std::string>& foldable_constants);
 
 } // namespace pnnx
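For the level-3 passes the map-to-set switch is nearly mechanical: operand_is_foldable() only ever tested membership by name, and std::set supports the identical find/end idiom:

```cpp
#include <set>
#include <string>

// Same membership test as operand_is_foldable() above, minus the tensor payload.
static bool is_foldable(const std::set<std::string>& foldable_constants, const std::string& name)
{
    return foldable_constants.find(name) != foldable_constants.end();
}
```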
4 changes: 2 additions & 2 deletions tools/pnnx/src/pass_level5.cpp
@@ -44,7 +44,7 @@
 
 namespace pnnx {
 
-void pass_level5(Graph& g, const std::map<std::string, Attribute>& foldable_constants)
+void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath)
 {
     eval_expression(g);
 
@@ -92,7 +92,7 @@ void pass_level5(Graph& g, const std::map<std::string, Attribute>& foldable_cons
 
     fuse_channel_shuffle(g);
 
-    fold_constants(g, foldable_constants);
+    fold_constants(g, foldable_constants, foldable_constants_zippath);
 
     fuse_index_expression(g);
 
2 changes: 1 addition & 1 deletion tools/pnnx/src/pass_level5.h
@@ -19,7 +19,7 @@
 
 namespace pnnx {
 
-void pass_level5(Graph& g, const std::map<std::string, Attribute>& foldable_constants);
+void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath);
 
 } // namespace pnnx
 
18 changes: 16 additions & 2 deletions tools/pnnx/src/pass_level5/fold_constants.cpp
@@ -15,12 +15,16 @@
 #include "fold_constants.h"
 #include <unordered_set>
 
+#include "storezip.h"
 #include "pass_level4/dead_code_elimination.h"
 
 namespace pnnx {
 
-void fold_constants(Graph& graph, const std::map<std::string, Attribute>& foldable_constants)
+void fold_constants(Graph& graph, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath)
 {
+    StoreZipReader zip;
+    zip.open(foldable_constants_zippath);
+
     for (size_t i = 0; i < graph.operands.size(); i++)
     {
         Operand* operand = graph.operands[i];
@@ -36,13 +40,23 @@ void fold_constants(Graph& graph, const std::map<std::string, Attribute>& foldab
         // replace producer with attribute
         Operator* op_new = graph.new_operator_before("pnnx.Attribute", std::string("pnnx_fold_") + name, op);
 
-        op_new->attrs[std::string("pnnx_fold_") + name] = foldable_constants.at(name);
+        op_new->attrs[std::string("pnnx_fold_") + name] = Attribute();
+
+        Attribute& t2 = op_new->attrs[std::string("pnnx_fold_") + name];
+        t2.type = operand->type;
+        t2.shape = operand->shape;
+        size_t size = zip.get_file_size(name);
+        t2.data.resize(size);
+        zip.read_file(name, t2.data.data());
 
         op_new->outputs.push_back(operand);
         operand->producer = op_new;
 
         op->outputs.clear();
     }
+
+    zip.close();
+
     // dce
     dead_code_elimination(graph);
 }
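The read side in fold_constants() mirrors the writer: each Attribute is rebuilt from two sources, metadata from the graph operand and raw bytes from the archive. A sketch of that reconstruction, assuming the StoreZipReader calls and the Attribute/Operand fields used in the hunk above:

```cpp
#include <string>
#include "ir.h"       // pnnx::Attribute, pnnx::Operand
#include "storezip.h"

// Rebuild one folded constant: type and shape come from the graph,
// the payload comes from the temporary zip archive.
static pnnx::Attribute read_foldable_constant(StoreZipReader& zip, const pnnx::Operand* operand)
{
    pnnx::Attribute a;
    a.type = operand->type;
    a.shape = operand->shape;

    size_t size = zip.get_file_size(operand->name);
    a.data.resize(size);
    zip.read_file(operand->name, a.data.data());
    return a;
}
```

Since nothing but bytes is stored, the operand's recorded type and shape must already be correct when folding runs; this is also why the writer archives a cpu().contiguous() copy whose byte count matches exactly.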
2 changes: 1 addition & 1 deletion tools/pnnx/src/pass_level5/fold_constants.h
@@ -16,6 +16,6 @@
 
 namespace pnnx {
 
-void fold_constants(Graph& graph, const std::map<std::string, Attribute>& foldable_constants);
+void fold_constants(Graph& graph, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath);
 
 } // namespace pnnx