Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Infer Symbolic Shape No.125】【BUAA】Add BroadcastTensor, changed 3 files #67744

Closed
wants to merge 86 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
86 commits
Select commit Hold shift + click to select a range
163571f
add 3 files
Whsjrczr Aug 8, 2024
75facec
update .cc
Whsjrczr Aug 9, 2024
5972d24
undo change in .cc file
Whsjrczr Aug 9, 2024
545b90b
update num_samples in .cc
Whsjrczr Aug 9, 2024
0b44b7e
attribute
Whsjrczr Aug 9, 2024
9e0d933
update
Whsjrczr Aug 9, 2024
bf533d3
update
Whsjrczr Aug 9, 2024
2880140
update
Whsjrczr Aug 9, 2024
e2a5d42
update
Whsjrczr Aug 9, 2024
5dc6a76
update
Whsjrczr Aug 9, 2024
b21f08f
update
Whsjrczr Aug 9, 2024
89ff139
update
Whsjrczr Aug 9, 2024
ecc7735
updata
Whsjrczr Aug 9, 2024
b56968e
updata
Whsjrczr Aug 9, 2024
b7a8875
update
Whsjrczr Aug 9, 2024
30c9e3f
update
Whsjrczr Aug 9, 2024
d6b1104
update
Whsjrczr Aug 9, 2024
652573e
update
Whsjrczr Aug 9, 2024
475014d
update
Whsjrczr Aug 9, 2024
97caa5e
update
Whsjrczr Aug 9, 2024
7945ac9
u
Whsjrczr Aug 9, 2024
620525c
batch_function
Whsjrczr Aug 12, 2024
16be43e
bincount
Whsjrczr Aug 12, 2024
47c5d45
update batchfc
Whsjrczr Aug 12, 2024
c4142fe
update batchfc
Whsjrczr Aug 12, 2024
6379805
update EQ
Whsjrczr Aug 12, 2024
e90891f
update {-1}
Whsjrczr Aug 12, 2024
0c9b893
2 api
Whsjrczr Aug 12, 2024
8b3ad1d
update binary with output_size
Whsjrczr Aug 12, 2024
949b95c
undo change
Whsjrczr Aug 12, 2024
f2c974a
{-1}
Whsjrczr Aug 12, 2024
3f1c49f
add Bincount
Whsjrczr Aug 12, 2024
7a5a3ed
update class_center_sample
Whsjrczr Aug 12, 2024
eef3cd6
delete batch_fc
Whsjrczr Aug 12, 2024
0ec2002
update broadcast tensor
Whsjrczr Aug 13, 2024
28a64d7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Whsjrczr Aug 13, 2024
5d447ea
update batchnormop
Whsjrczr Aug 13, 2024
7c24bfe
update batchnormop
Whsjrczr Aug 13, 2024
877670c
changed bn
Whsjrczr Aug 13, 2024
4b2cc16
changed bn
Whsjrczr Aug 13, 2024
79badab
update Bmm
Whsjrczr Aug 14, 2024
4a3ad3f
add common::
Whsjrczr Aug 14, 2024
dfb5ed1
update broadcast with vector<MetaTensor>
Whsjrczr Aug 14, 2024
0231a79
update broadcast with vector<MetaTensor>
Whsjrczr Aug 14, 2024
40b81b1
undo change, delete a line
Whsjrczr Aug 14, 2024
82304ca
add .shape() and .data()
Whsjrczr Aug 14, 2024
1fe5b1d
change into .shape()
Whsjrczr Aug 14, 2024
73c1917
delete <>
Whsjrczr Aug 14, 2024
d5c66bb
print
Whsjrczr Aug 14, 2024
2b1cc89
new print
Whsjrczr Aug 14, 2024
b24600e
new print
Whsjrczr Aug 14, 2024
d17efd5
<>
Whsjrczr Aug 14, 2024
c606847
add
Whsjrczr Aug 14, 2024
9685c7d
unduo _
Whsjrczr Aug 14, 2024
a8f1277
Update DimExpr
Whsjrczr Aug 15, 2024
2774f51
Update unary_infer_sym.cc
Whsjrczr Aug 15, 2024
52ef749
out_unknown
Whsjrczr Aug 16, 2024
e53bc52
Expr -> Exprs
Whsjrczr Aug 16, 2024
b09ad2c
{{out_dims}} -> out_dims
Whsjrczr Aug 16, 2024
a010e0f
delete ()
Whsjrczr Aug 16, 2024
88bb994
Merge branch 'api1' of https://github.com/Whsjrczr/Paddle into develop
Whsjrczr Aug 19, 2024
eacc7f9
Merge branch 'api2' of https://github.com/Whsjrczr/Paddle into develop
Whsjrczr Aug 19, 2024
95bb57b
update dimexpr
Whsjrczr Aug 19, 2024
f58bd89
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Whsjrczr Aug 20, 2024
de5ed23
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Whsjrczr Aug 21, 2024
ad1f042
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Whsjrczr Aug 22, 2024
1fa61f6
restore changes
Whsjrczr Aug 22, 2024
f7a4020
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Whsjrczr Aug 22, 2024
43a0978
undo
Whsjrczr Aug 22, 2024
16a749d
undo some changes
Whsjrczr Aug 22, 2024
da39416
update broadcast
Whsjrczr Aug 25, 2024
19de210
Merge branch 'PaddlePaddle:develop' into develop
Whsjrczr Aug 26, 2024
b622887
merge
Whsjrczr Aug 27, 2024
e12c906
delete
Whsjrczr Aug 27, 2024
1141768
delete
Whsjrczr Aug 27, 2024
06981ce
add type change
Whsjrczr Sep 1, 2024
f21eab4
change shape
Whsjrczr Sep 2, 2024
a180cbc
int -> size_t
Whsjrczr Sep 3, 2024
c78ae10
changed logic
Whsjrczr Sep 4, 2024
8724fc7
j -> bound-j-1
Whsjrczr Sep 5, 2024
e74f178
int -> size_t
Whsjrczr Sep 5, 2024
8ac6e47
changed logic
Whsjrczr Sep 9, 2024
1f80e90
rerun
Whsjrczr Sep 11, 2024
e02ca69
rerun
Whsjrczr Sep 11, 2024
51af153
rerun
Whsjrczr Sep 11, 2024
a49eee5
rerun
Whsjrczr Sep 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -651,12 +651,43 @@ bool BilinearOpInferSymbolicShape(
// return true;
// }

// bool BroadcastTensorsOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }
bool BroadcastTensorsOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &input_shape_or_data_list =
infer_context->GetShapeOrDataForValue(op->operand_source(0))
.dyn_cast<symbol::TensorListShapeOrDataDimExprs>();

// 1. Find Output rank = max(Inputs rank)
size_t target_rank = 0;
for (const auto &input_shape_or_data : input_shape_or_data_list) {
size_t tmp_rank = input_shape_or_data.shape().size();
target_rank = std::max(int64_t(target_rank), int64_t(tmp_rank));
}

// 2. Output dim(axis=x) = max(Inputs dim(axis=x))
std::vector<symbol::DimExpr> out_shape;
symbol::DimExprBuilder builder;
for (size_t i = 0; i < target_rank; i++) {
auto tmp_dim = symbol::DimExpr{1};
for (const auto &input_shape_or_data : input_shape_or_data_list) {
size_t axis = i - target_rank + input_shape_or_data.size();
if (axis >= 0) {
infer_context->AddBroadcastableCstr(input_shape_or_data.shape()[axis],
tmp_dim);
tmp_dim = builder.Broadcast(input_shape_or_data.shape()[axis], tmp_dim);
}
}
out_shape.emplace_back(tmp_dim);
}

symbol::TensorListShapeOrDataDimExprs out_shapes;
for (size_t i = 0; i < input_shape_or_data_list.size(); i++) {
out_shapes.emplace_back(out_shape);
}
infer_context->SetShapeOrDataForValue(
op->result(0), symbol::ShapeOrDataDimExprs{out_shapes});
return true;
}

bool BilinearInterpOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Addmm_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(AddN)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Auc)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(AssignPos)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(BroadcastTensors)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BroadcastTensors)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchFc)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm_)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,7 @@
func: broadcast_tensors
data_type : input
backward: broadcast_tensors_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : c_allgather
args : (Tensor x, int ring_id, int nranks, bool use_calc_stream)
Expand Down