File tree Expand file tree Collapse file tree 1 file changed +14
-1
lines changed
paddle/cinn/hlir/dialect/operator/transforms/lowering_pass Expand file tree Collapse file tree 1 file changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -253,6 +253,17 @@ CreateGroupShapeOrDataExprs(
253253 InferSymbolicShapeForOperation (op, &local_shape_analysis);
254254 }
255255
256+ auto broadcast_contains = [](const symbol::DimExpr& dimexpr,
257+ const symbol::DimExpr& target) {
258+ auto broadcast =
259+ std::get_if<symbol::Broadcast<symbol::DimExpr>>(&dimexpr.variant ());
260+ if (broadcast == nullptr ) return false ;
261+ for (const auto & operand : *(broadcast->operands )) {
262+ if (operand == target) return true ;
263+ }
264+ return false ;
265+ };
266+
256267 // Add shape constraints after infer.
257268 auto & mut_substitute_dimexpr_map = group->mut_substitute_dimexpr_map ();
258269 for (auto * op : group->ops ()) {
@@ -264,7 +275,9 @@ CreateGroupShapeOrDataExprs(
264275 if (global_result_shape.size () != local_result_shape.size ()) continue ;
265276 for (size_t i = 0 ; i < global_result_shape.size (); ++i) {
266277 if (global_result_shape[i] != local_result_shape[i] &&
267- !global_result_shape[i].isa <std::int64_t >()) {
278+ !global_result_shape[i].isa <std::int64_t >() &&
279+ !broadcast_contains (local_result_shape[i],
280+ global_result_shape[i])) {
268281 mut_substitute_dimexpr_map[global_result_shape[i]] =
269282 local_result_shape[i];
270283 }
You can’t perform that action at this time.
0 commit comments