Pnnx print flops memops count #5693

Closed
196 changes: 195 additions & 1 deletion tools/pnnx/src/ir.cpp
@@ -2862,4 +2862,198 @@ const Operand* Graph::get_operand(const std::string& name) const
    return 0;
}

// Convolutional layer
void calculate_conv_flops_and_memory(pnnx::Operator& op)
{
    int input_channels = op.params.at("input_channels").i;
    int input_height = op.params.at("input_height").i;
    int input_width = op.params.at("input_width").i;

    int kernel_size = op.params.at("kernel_size").i;
    int stride = op.params.at("stride").i;
    int padding = op.params.at("padding").i;
    int output_channels = op.params.at("output_channels").i;

    int output_height = (input_height + 2 * padding - kernel_size) / stride + 1;
    int output_width = (input_width + 2 * padding - kernel_size) / stride + 1;

    // one multiply-accumulate per kernel element per output element
    int64_t flops = (int64_t)output_height * output_width * output_channels * input_channels * kernel_size * kernel_size;
    op.flops = flops;

    // two reads per multiply-accumulate (input value and weight) plus one write per output element
    int64_t memory_ops = 2 * flops + (int64_t)output_channels * output_height * output_width;
    op.memory_ops = memory_ops;
}
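// Worked example (hypothetical values): a 3x224x224 input with 64 output channels,
// kernel_size=3, stride=1, padding=1 yields a 64x224x224 output, so
// flops = 224 * 224 * 64 * 3 * 3 * 3 ≈ 86.7M multiply-accumulates.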

// Fully connected layer
void calculate_fc_flops_and_memory(pnnx::Operator& op)
{
    int input_size = op.params.at("input_size").i;
    int output_size = op.params.at("output_size").i;

    // FLOPS = 2 * (input_size * output_size) (multiplies and adds); this ignores the bias,
    // counting it would give FLOPS = 2 * (input_size * output_size) + output_size
    int64_t flops = 2 * (int64_t)input_size * output_size;
    op.flops = flops;
    // memory accesses for the weights, bias, input and output
    int64_t weight_memory_access = (int64_t)input_size * output_size;
    int64_t bias_memory_access = output_size;
    int64_t input_memory_access = input_size;
    int64_t output_memory_access = output_size;

    int64_t memory_ops = weight_memory_access + bias_memory_access + input_memory_access + output_memory_access;
    op.memory_ops = memory_ops;
}
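// Worked example (hypothetical values): input_size=512, output_size=1000 gives
// flops = 2 * 512 * 1000 ≈ 1.02M and
// memory_ops = 512*1000 + 1000 + 512 + 1000 = 514512.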

// Pooling layers
void calculate_max_pool_flops_and_memory(pnnx::Operator& op)
{
    int input_height = op.params.at("input_height").i;
    int input_width = op.params.at("input_width").i;
    int channels = op.params.at("channels").i;
    int kernel_size = op.params.at("kernel_size").i;
    int stride = op.params.at("stride").i;
    int padding = op.params.at("padding").i;
    int output_height = (input_height + 2 * padding - kernel_size) / stride + 1;
    int output_width = (input_width + 2 * padding - kernel_size) / stride + 1;

    // assume each comparison counts as one FLOP
    int64_t flops = (int64_t)output_height * output_width * channels * kernel_size * kernel_size;
    op.flops = flops;
    int64_t memory_ops = (int64_t)input_height * input_width * channels + (int64_t)output_height * output_width * channels;
    op.memory_ops = memory_ops;
}

void calculate_avg_pool_flops_and_memory(pnnx::Operator& op)
{
    int input_height = op.params.at("input_height").i;
    int input_width = op.params.at("input_width").i;
    int channels = op.params.at("channels").i;
    int kernel_size = op.params.at("kernel_size").i;
    int stride = op.params.at("stride").i;
    int padding = op.params.at("padding").i;
    int output_height = (input_height + 2 * padding - kernel_size) / stride + 1;
    int output_width = (input_width + 2 * padding - kernel_size) / stride + 1;

    // one addition per kernel element and one division per output (the division cost is negligible)
    int64_t flops = (int64_t)output_height * output_width * channels * kernel_size * kernel_size;
    op.flops = flops;
    int64_t memory_ops = (int64_t)input_height * input_width * channels + (int64_t)output_height * output_width * channels;
    op.memory_ops = memory_ops;
}
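// Worked example (hypothetical values): a 64x112x112 input with kernel_size=2,
// stride=2, padding=0 pools down to 64x56x56, so
// flops = 56 * 56 * 64 * 2 * 2 ≈ 0.8M comparisons (max) or additions (avg).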

// Activation layer
void calculate_activation_flops_and_memory(pnnx::Operator& op)
{
    int input_size = op.params.at("input_size").i;
    int64_t flops = input_size;
    op.flops = flops;
    int64_t memory_ops = flops * 2; // read the input and write the output
    op.memory_ops = memory_ops;
}

// Batch normalization layer
void calculate_bn_flops_and_memory(pnnx::Operator& op)
{
    int input_channels = op.params.at("input_channels").i;
    int input_height = op.params.at("input_height").i;
    int input_width = op.params.at("input_width").i;
    // FLOPS = 5 * input_channels * (input_height * input_width) (normalization operations)
    int64_t flops = 5 * (int64_t)input_channels * input_height * input_width;
    op.flops = flops;
    // covers reading the input, mean, variance, scale factor and shift, plus writing the output
    int64_t memory_ops = 6 * (int64_t)input_channels * input_height * input_width;
    op.memory_ops = memory_ops;
}
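// Worked example (hypothetical values): a 64x56x56 input gives
// flops = 5 * 64 * 56 * 56 ≈ 1.0M and memory_ops = 6 * 64 * 56 * 56 ≈ 1.2M.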

// Dropout layer
void calculate_dropout_flops_and_memory(pnnx::Operator& op)
{
    int64_t flops = op.params.at("input_size").i;
    op.flops = flops;
    int64_t memory_ops = op.params.at("input_size").i * 2;
    op.memory_ops = memory_ops;
}

// LSTM layer
void calculate_lstm_flops_and_memory(pnnx::Operator& op)
{
    int input_size = op.params.at("input_size").i;
    int hidden_size = op.params.at("hidden_size").i;
    // rough estimate covering the four gate matmuls and the elementwise state updates
    int64_t flops = 10 * (int64_t)input_size * hidden_size;
    op.flops = flops;
    // input read, gate weights, plus hidden and cell state reads/writes
    int64_t memory_ops = input_size + 2 * hidden_size + 4 * (int64_t)hidden_size * (input_size + hidden_size) + 2 * hidden_size;
    op.memory_ops = memory_ops;
}
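// Worked example (hypothetical values): input_size=128, hidden_size=256 gives
// flops = 10 * 128 * 256 ≈ 0.33M per time step under this rough estimate.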

// Embedding layer
void calculate_embedding_flops_and_memory(pnnx::Operator& op)
{
    int input_vocab_size = op.params.at("vocab_size").i;
    int embedding_size = op.params.at("embedding_size").i;
    int64_t flops = (int64_t)input_vocab_size * embedding_size;
    op.flops = flops;
    int64_t memory_ops = input_vocab_size + (int64_t)input_vocab_size * embedding_size;
    op.memory_ops = memory_ops;
}
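// Worked example (hypothetical values): vocab_size=30000, embedding_size=128 gives
// flops = 30000 * 128 = 3.84M under this counting, though a lookup is mostly memory traffic.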

// Layer normalization layer
void calculate_layer_norm_flops_and_memory(pnnx::Operator& op)
{
    int input_size = op.params.at("input_size").i;
    // layer norm normalizes and rescales each element: subtract the mean,
    // divide by the standard deviation, then scale and shift
    int64_t flops = 5 * (int64_t)input_size;
    op.flops = flops;
    int64_t memory_ops = 4 * (int64_t)input_size;
    op.memory_ops = memory_ops;
}
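// Worked example (hypothetical values): input_size=768 gives flops = 5 * 768 = 3840
// and memory_ops = 4 * 768 = 3072 per normalized vector.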

void calculate_flops_and_memory_for_operator(Operator* op)
{
    // the layer-type strings below are assumed labels for each layer kind
    if (op->type == "Convolution")
    {
        calculate_conv_flops_and_memory(*op);
    }
    else if (op->type == "FullyConnected")
    {
        calculate_fc_flops_and_memory(*op);
    }
    else if (op->type == "MaxPooling")
    {
        calculate_max_pool_flops_and_memory(*op);
    }
    else if (op->type == "AveragePooling")
    {
        calculate_avg_pool_flops_and_memory(*op);
    }
    else if (op->type == "Activation")
    {
        calculate_activation_flops_and_memory(*op);
    }
    else if (op->type == "BatchNormalization")
    {
        calculate_bn_flops_and_memory(*op);
    }
    else if (op->type == "Dropout")
    {
        calculate_dropout_flops_and_memory(*op);
    }
    else if (op->type == "LSTM")
    {
        calculate_lstm_flops_and_memory(*op);
    }
    else if (op->type == "Embedding")
    {
        calculate_embedding_flops_and_memory(*op);
    }
    else if (op->type == "LayerNormalization")
    {
        calculate_layer_norm_flops_and_memory(*op);
    }
}

void Graph::calculate_total_flops_and_memory_ops()
{
    int64_t total_flops = 0;
    int64_t total_memory_ops = 0;

    for (Operator* op : ops)
    {
        calculate_flops_and_memory_for_operator(op);
        total_flops += op->flops;
        total_memory_ops += op->memory_ops;
    }
    fprintf(stderr, "Total FLOPS: %.2fM\n", total_flops / 1e6);
    fprintf(stderr, "Total Memory Operations: %.2fM\n", total_memory_ops / 1e6);
}
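// A minimal usage sketch (assumes the Graph::load(parampath, binpath) API;
// adjust to however the graph is actually constructed):
//
//     pnnx::Graph g;
//     g.load("model.pnnx.param", "model.pnnx.bin");
//     g.calculate_total_flops_and_memory_ops(); // prints totals to stderr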

} // namespace pnnx
8 changes: 7 additions & 1 deletion tools/pnnx/src/ir.h
@@ -224,10 +224,13 @@ class Attribute
{
public:
    Attribute()
        : type(0), flops(0), memory_ops(0)
    {
    }

    int64_t flops;
    int64_t memory_ops;

#if BUILD_TORCH2PNNX
    Attribute(const at::Tensor& t);
#endif
@@ -304,6 +307,9 @@ class Operator
    std::map<std::string, Parameter> params;
    std::map<std::string, Attribute> attrs;

    int64_t flops = 0;
    int64_t memory_ops = 0;

private:
    friend class Graph;
    Operator()