@@ -118,6 +118,53 @@ static T getPerfectlyNested(Operation *op) {
118118 return nullptr ;
119119}
120120
121+ // VerifyTargetTeamsWorkdistribute method verifies that
122+ // omp.target { teams { workdistribute { ... } } } is well formed
123+ // and fails for function calls that don't have lowering implemented yet.
124+ static bool
125+ VerifyTargetTeamsWorkdistribute (omp::WorkdistributeOp workdistribute) {
126+ OpBuilder rewriter (workdistribute);
127+ auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp ());
128+ if (!teams) {
129+ workdistribute.emitError () << " workdistribute not nested in teams\n " ;
130+ return false ;
131+ }
132+ if (workdistribute.getRegion ().getBlocks ().size () != 1 ) {
133+ workdistribute.emitError () << " workdistribute with multiple blocks\n " ;
134+ return false ;
135+ }
136+ if (teams.getRegion ().getBlocks ().size () != 1 ) {
137+ workdistribute.emitError () << " teams with multiple blocks\n " ;
138+ return false ;
139+ }
140+ omp::TargetOp targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp ());
141+ // return if not omp.target
142+ if (!targetOp)
143+ return true ;
144+
145+ for (auto &op : workdistribute.getOps ()) {
146+ if (auto callOp = dyn_cast<fir::CallOp>(op)) {
147+ if (isRuntimeCall (&op)) {
148+ auto funcName = (*callOp.getCallee ()).getRootReference ().getValue ();
149+ // _FortranAAssign is handled. Other runtime calls are not supported
150+ // in omp.workdistribute yet.
151+ if (funcName == " _FortranAAssign" )
152+ continue ;
153+ else
154+ workdistribute.emitError ()
155+ << " Runtime call " << funcName
156+ << " lowering not supported for workdistribute yet." ;
157+ return false ;
158+ } else {
159+ workdistribute.emitError () << " Non-runtime fir.call lowering not "
160+ " supported in workdistribute yet." ;
161+ return false ;
162+ }
163+ }
164+ }
165+ return true ;
166+ }
167+
121168// FissionWorkdistribute method finds the parallelizable ops
122169// within teams {workdistribute} region and moves them to their
123170// own teams{workdistribute} region.
@@ -154,18 +201,10 @@ static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
154201 OpBuilder rewriter (workdistribute);
155202 auto loc = workdistribute->getLoc ();
156203 auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp ());
157- if (!teams) {
158- emitError (loc, " workdistribute not nested in teams\n " );
159- return false ;
160- }
161- if (workdistribute.getRegion ().getBlocks ().size () != 1 ) {
162- emitError (loc, " workdistribute with multiple blocks\n " );
163- return false ;
164- }
165- if (teams.getRegion ().getBlocks ().size () != 1 ) {
166- emitError (loc, " teams with multiple blocks\n " );
167- return false ;
168- }
204+
205+ omp::TargetOp targetOp;
206+ // Get the target op parent of teams
207+ targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp ());
169208
170209 auto *teamsBlock = &teams.getRegion ().front ();
171210 bool changed = false ;
@@ -1744,6 +1783,11 @@ class LowerWorkdistributePass
17441783 auto moduleOp = getOperation ();
17451784 bool changed = false ;
17461785 SetVector<omp::TargetOp> targetOpsToProcess;
1786+ moduleOp->walk ([&](mlir::omp::WorkdistributeOp workdistribute) {
1787+ bool res = VerifyTargetTeamsWorkdistribute (workdistribute);
1788+ if (!res)
1789+ signalPassFailure ();
1790+ });
17471791 moduleOp->walk ([&](mlir::omp::WorkdistributeOp workdistribute) {
17481792 changed |= FissionWorkdistribute (workdistribute);
17491793 });
0 commit comments