diff --git a/src/occa/internal/lang/modes/serial.cpp b/src/occa/internal/lang/modes/serial.cpp
index 83cc43073..c7d044f9a 100644
--- a/src/occa/internal/lang/modes/serial.cpp
+++ b/src/occa/internal/lang/modes/serial.cpp
@@ -2,6 +2,7 @@
 #include
 #include
+#include
 #include
 #include
@@ -133,7 +134,7 @@ namespace occa {
             (smnt->type() & statementType::declaration)
             && ((declarationStatement*) smnt)->declaresVariable(var)
           ) {
-            defineExclusiveVariableAsArray(var);
+            defineExclusiveVariableAsArray((declarationStatement&) *smnt, var);
             return &varNode;
           }
 
@@ -264,12 +265,161 @@ namespace occa {
         }
       }
 
-      void serialParser::defineExclusiveVariableAsArray(variable_t &var) {
-        // TODO: Dynamic array sizes
-        // Define the variable as a stack array
+      // Count the [@inner] for-loops that enclose forSmnt; forSmnt itself is
+      // not counted, so a loop without [@inner] ancestors is at level 0
+      int serialParser::getInnerLoopLevel(forStatement &forSmnt) {
+        statement_t *smnt = forSmnt.up;
+        int level = 0;
+        while (smnt) {
+          if ((smnt->type() & statementType::for_)
+              && smnt->hasAttribute("inner")) {
+            ++level;
+          }
+          smnt = smnt->up;
+        }
+        return level;
+      }
+
+      // Return the [@inner] for-loop with the highest getInnerLoopLevel among
+      // the loops selected from forSmnt, or NULL if none is selected.
+      // On ties, the loop visited first is kept
+      forStatement* serialParser::getInnerMostInnerLoop(forStatement &forSmnt) {
+        int maxLevel = -1;
+        forStatement *innerMostInnerLoop = NULL;
+
+        statementArray::from(forSmnt)
+          .flatFilterByAttribute("inner")
+          .filterByStatementType(statementType::for_)
+          .forEach([&](statement_t *smnt) {
+            forStatement &innerSmnt = (forStatement&) *smnt;
+            const int level = getInnerLoopLevel(innerSmnt);
+            if (level > maxLevel) {
+              maxLevel = level;
+              innerMostInnerLoop = &innerSmnt;
+            }
+          });
+
+        return innerMostInnerLoop;
+      }
+
+      // Turn the declaration of the exclusive variable var into a fixed-size
+      // array declaration.
+      //
+      // The array length is, in order of preference:
+      //   1. the product of the [@inner] loop iteration counts, if all of them
+      //      can be determined here
+      //   2. the product of the [@max_inner_dims] arguments on the outer-most
+      //      [@outer] loop enclosing declSmnt, if that attribute is present
+      //   3. 1024
+      //
+      // If an error is reported through printError, success is set to false and
+      // var is left unchanged
+      void serialParser::defineExclusiveVariableAsArray(declarationStatement &declSmnt,
+                                                        variable_t &var) {
+        // Find outer-most outer loop
+        statement_t *smnt = declSmnt.up;
+        forStatement *outerMostOuterLoop = NULL;
+        while (smnt) {
+          // Only for-loops may be cast to forStatement below
+          if ((smnt->type() & statementType::for_)
+              && smnt->hasAttribute("outer")) {
+            outerMostOuterLoop = (forStatement*) smnt;
+          }
+          smnt = smnt->up;
+        }
+        if (!outerMostOuterLoop) {
+          declSmnt.printError("[@exclusive] variables must be declared inside an [@outer] loop");
+          success = false;
+          return;
+        }
+
+        // Check if outer loop has max_inner_dims set
+        bool maxInnerDimsKnown{false};
+        int maxInnerDims[3] = {1,1,1};
+        if (outerMostOuterLoop->hasAttribute("max_inner_dims")) {
+          attributeToken_t& attr = outerMostOuterLoop->attributes["max_inner_dims"];
+          // maxInnerDims only has room for 3 values
+          if (attr.args.size() > 3) {
+            outerMostOuterLoop->printError("[@max_inner_dims] takes at most 3 arguments");
+            success = false;
+            return;
+          }
+          maxInnerDimsKnown = true;
+
+          for(size_t i=0; i < attr.args.size(); ++i) {
+            exprNode* expr = attr.args[i].expr;
+            primitive value = expr->evaluate();
+            maxInnerDims[i] = value;
+          }
+        }
+
+        // Check if inner dimensions are known at compile time;
+        // they are treated as unknown when no [@inner] loop is found
+        int knownInnerDims[3] = {1,1,1};
+        forStatement *innerSmnt = getInnerMostInnerLoop(*outerMostOuterLoop);
+        bool innerDimsKnown = (innerSmnt != NULL);
+
+        if (innerSmnt) {
+          statementArray path = oklForStatement::getOklLoopPath(*innerSmnt);
+          const int pathCount = (int) path.length();
+          for (int i = 0; i < pathCount; ++i) {
+            forStatement &pathSmnt = *((forStatement*) path[i]);
+            oklForStatement oklForSmnt(pathSmnt);
+            if (!pathSmnt.hasAttribute("inner")) {
+              continue;
+            }
+
+            const int innerIndex = oklForSmnt.oklLoopIndex();
+            auto iterationCount = oklForSmnt.getIterationCount();
+            if (iterationCount->canEvaluate()) {
+              knownInnerDims[innerIndex] = (int) iterationCount->evaluate();
+              continue;
+            }
+
+            const std::string s = iterationCount->toString();
+            if (s.find("_occa_tiled_") == std::string::npos) {
+              // Loop bounds are unknown at compile time
+              innerDimsKnown = false;
+              break;
+            }
+            // NOTE(review): takes the integer starting at the first non-zero digit
+            // of the printed iteration count as the @tile size - confirm format
+            size_t tile_size = s.find_first_of("123456789");
+            OCCA_ERROR("@tile size is undefined!",tile_size != std::string::npos);
+            knownInnerDims[innerIndex] = std::stoi(s.substr(tile_size));
+          }
+        }
+
+        const int knownInnerDim = knownInnerDims[0]
+                                * knownInnerDims[1]
+                                * knownInnerDims[2];
+        const int maxInnerDim = maxInnerDims[0]
+                              * maxInnerDims[1]
+                              * maxInnerDims[2];
+
+        if (innerDimsKnown && maxInnerDimsKnown && (knownInnerDim > maxInnerDim)) {
+          outerMostOuterLoop->printError("[@inner] loop dimensions larger than allowed by [@max_inner_dims]");
+          success = false;
+          return;
+        }
+
+        // Determine how long the exclusive array should be
+        int exclusiveArraySize = 1024;
+        if (maxInnerDimsKnown) {
+          exclusiveArraySize = maxInnerDim;
+        }
+        if (innerDimsKnown) {
+          exclusiveArraySize = knownInnerDim;
+        }
+        // Never declare an array with less than one element
+        if (exclusiveArraySize < 1) {
+          exclusiveArraySize = 1;
+        }
+
+        // Make exclusive variable declaration into an array
         // For example:
         //    const int x
-        // -> const int x[256]
+        // -> const int x[1024]
         operatorToken startToken(var.source->origin,
                                  op::bracketStart);
         operatorToken endToken(var.source->origin,
@@ -280,7 +430,7 @@ namespace occa {
           array_t(startToken,
                   endToken,
                   new primitiveNode(var.source,
-                                    256))
+                                    exclusiveArraySize))
         );
       }
 
diff --git a/src/occa/internal/lang/modes/serial.hpp b/src/occa/internal/lang/modes/serial.hpp
index ea24fe93c..29bc019f3 100644
--- a/src/occa/internal/lang/modes/serial.hpp
+++ b/src/occa/internal/lang/modes/serial.hpp
@@ -26,7 +26,12 @@ namespace occa {
         void setupExclusiveDeclaration(declarationStatement &declSmnt);
         void setupExclusiveIndices();
 
-        void defineExclusiveVariableAsArray(variable_t &var);
+        int getInnerLoopLevel(forStatement &forSmnt);
+
+        forStatement* getInnerMostInnerLoop(forStatement &forSmnt);
+
+        void defineExclusiveVariableAsArray(declarationStatement &declSmnt,
+                                            variable_t &var);
 
         exprNode* addExclusiveVariableArrayAccessor(statement_t &smnt,
                                                     exprNode &expr,