|
22 | 22 | namespace Fortran { |
23 | 23 | namespace lower { |
24 | 24 | namespace omp { |
| 25 | +bool DataSharingProcessor::OMPConstructSymbolVisitor::isSymbolDefineBy( |
| 26 | + const semantics::Symbol *symbol, lower::pft::Evaluation &eval) const { |
| 27 | + return eval.visit( |
| 28 | + common::visitors{[&](const parser::OpenMPConstruct &functionParserNode) { |
| 29 | + return symDefMap.count(symbol) && |
| 30 | + symDefMap.at(symbol) == &functionParserNode; |
| 31 | + }, |
| 32 | + [](const auto &functionParserNode) { return false; }}); |
| 33 | +} |
| 34 | + |
| 35 | +DataSharingProcessor::DataSharingProcessor( |
| 36 | + lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, |
| 37 | + const List<Clause> &clauses, lower::pft::Evaluation &eval, |
| 38 | + bool shouldCollectPreDeterminedSymbols, bool useDelayedPrivatization, |
| 39 | + lower::SymMap *symTable) |
| 40 | + : hasLastPrivateOp(false), converter(converter), semaCtx(semaCtx), |
| 41 | + firOpBuilder(converter.getFirOpBuilder()), clauses(clauses), eval(eval), |
| 42 | + shouldCollectPreDeterminedSymbols(shouldCollectPreDeterminedSymbols), |
| 43 | + useDelayedPrivatization(useDelayedPrivatization), symTable(symTable), |
| 44 | + visitor() { |
| 45 | + eval.visit([&](const auto &functionParserNode) { |
| 46 | + parser::Walk(functionParserNode, visitor); |
| 47 | + }); |
| 48 | +} |
25 | 49 |
|
26 | 50 | void DataSharingProcessor::processStep1( |
27 | 51 | mlir::omp::PrivateClauseOps *clauseOps, |
@@ -285,38 +309,9 @@ void DataSharingProcessor::collectSymbolsInNestedRegions( |
285 | 309 | // Recursively look for OpenMP constructs within `nestedEval`'s region |
286 | 310 | collectSymbolsInNestedRegions(nestedEval, flag, symbolsInNestedRegions); |
287 | 311 | else { |
288 | | - bool isOrderedConstruct = [&]() { |
289 | | - if (auto *ompConstruct = |
290 | | - nestedEval.getIf<parser::OpenMPConstruct>()) { |
291 | | - if (auto *ompBlockConstruct = |
292 | | - std::get_if<parser::OpenMPBlockConstruct>( |
293 | | - &ompConstruct->u)) { |
294 | | - const auto &beginBlockDirective = |
295 | | - std::get<parser::OmpBeginBlockDirective>( |
296 | | - ompBlockConstruct->t); |
297 | | - const auto origDirective = |
298 | | - std::get<parser::OmpBlockDirective>(beginBlockDirective.t).v; |
299 | | - |
300 | | - return origDirective == llvm::omp::Directive::OMPD_ordered; |
301 | | - } |
302 | | - } |
303 | | - |
304 | | - return false; |
305 | | - }(); |
306 | | - |
307 | | - bool isCriticalConstruct = [&]() { |
308 | | - if (auto *ompConstruct = |
309 | | - nestedEval.getIf<parser::OpenMPConstruct>()) { |
310 | | - return std::get_if<parser::OpenMPCriticalConstruct>( |
311 | | - &ompConstruct->u) != nullptr; |
312 | | - } |
313 | | - return false; |
314 | | - }(); |
315 | | - |
316 | | - if (!isOrderedConstruct && !isCriticalConstruct) |
317 | | - converter.collectSymbolSet(nestedEval, symbolsInNestedRegions, flag, |
318 | | - /*collectSymbols=*/true, |
319 | | - /*collectHostAssociatedSymbols=*/false); |
| 312 | + converter.collectSymbolSet(nestedEval, symbolsInNestedRegions, flag, |
| 313 | + /*collectSymbols=*/true, |
| 314 | + /*collectHostAssociatedSymbols=*/false); |
320 | 315 | } |
321 | 316 | } |
322 | 317 | } |
@@ -356,6 +351,11 @@ void DataSharingProcessor::collectSymbols( |
356 | 351 |
|
357 | 352 | llvm::SetVector<const semantics::Symbol *> symbolsInNestedRegions; |
358 | 353 | collectSymbolsInNestedRegions(eval, flag, symbolsInNestedRegions); |
| 354 | + |
| 355 | + for (auto *symbol : allSymbols) |
| 356 | + if (visitor.isSymbolDefineBy(symbol, eval)) |
| 357 | + symbolsInNestedRegions.remove(symbol); |
| 358 | + |
359 | 359 | // Filter-out symbols that must not be privatized. |
360 | 360 | bool collectImplicit = flag == semantics::Symbol::Flag::OmpImplicit; |
361 | 361 | bool collectPreDetermined = flag == semantics::Symbol::Flag::OmpPreDetermined; |
|
0 commit comments