From 27fcd0a74a5eb23b77a6a2d488dcd19a84ab3abe Mon Sep 17 00:00:00 2001 From: darya-ver Date: Tue, 31 Jan 2023 14:10:00 -0800 Subject: [PATCH] New Feature: Halide Program IR Visualizer (#7056) Thanks for the feedback everyone! I will merge this into the ir-viz branch and work on it to get it ready for a PR into main. * initial commit * updates * added curr_loop_depth and changed throws for assert(false) * split into header/cpp file and added test file * adding changes to move to adobe laptop * added git ignore to ignore .vscode * attemping to get add_custom_lowering_pass() to work, not working yet * Can now compile to stmt_viz files * moved files into main src folder and added them to Makefile * fixed lesson_01_basics.cpp * pushed updates - very messy code * got side colors working and hiararchy tree. ready for code cleanup * cleaned up code. ready for split into .h/.cpp files * quick comment change * switched everything into .h/.cpp files * added CostPreProcessor class and removed def of mutate * removed definitions of mutate * added data movement costs and bar at the top * changed location of tooltip so it doesn't overflow left * updated cost function for laod/store based on vector size and type * updated colors of hiararchy tree * logic for deciding context of variables (messy) * cleaned up code. waiting for marcos * added context coloring. cleaned up code a bit * collapse/expand on hiararchy working * got depth expansion working for hiararchy * cleaned up code * cleaned up code and renamed some funcs/vars * fixed let hierarchy code and added down arrow to button * dependency graph stuff (still massive and busy) * Minor fixes -- please review * prod/cons built with if stmts and for loops * exit early if running on a module w/ >1 func * added var dependency button to mail html * fixed `add` benchmark and made error printing better * changed `m_assert` to `internal_error` * cleanedup dependency graph * added error for non concrete bounds in prodcons hier * made arrows change btwn up/down depending on sit. * fixed text for ConsProd tables to have strings * added logic for non-set bounds for for loop * added TODO * added syntax highlight to strings and ints * added dotdotdot logic for collapsed children * fixed small bug where 2nd tree wasn't starting correctly * changed colors of ... nodes based on parent color * added if flowchart * added bools for printing different HTML parts of code * added different background colors per object * cleaned up borders of objects * fixed prodCons spacing and started allocate logic * removed border for ifthenelse table * implemented anchoring for prodcons tables * fixed empty if-stmts * open and close anchor are now right after one another * added filename logic for anchors and add blocks for func args * pass in FindStmtCost instead of reruning traversal * fixed comment * heatmap for prodConsHierarchy * fixed consume values (i think) and changed block colors * fixed allocate "!is_const_one(op->condition" error * fixed StmtSizes::visit(const For *op) (variables were Add) * removed nested-ifs logic (edge case we don't have to worry about) * changed table headers to only show loop interations and no bubbling up * get unique values for loads with ramp only * fixed !is_const_one(op->condition) in Allocate * changed allocate table to Dim-1, Dim-2, etc * (1) store: changed cost (2) load: added global/local (3) allocate: vized memtype (4) prodConsTable: changed to read/write * BOOTSTAP! added navigation pannel at top * line numbers!!! and removed tooltip (for now) * changed style of info buttons * adjusted and added icons for see-code and info buttons * removed a comment * condensed cost color classes * fixed ifthenelse line numbers * fixed if if-else anchor names * changed prodCons from table to div * adding cost colors for prodConsViz to the left side of div * made long conditions "..." in ProdConsViz * adding spacing for prodCons start viz and dependency graph viz * tooltip!!!! (still a little ugly, but functioning!!) * removed tooltip arrow and changed background to white * removed arrow for tooltip and added if-stmt condition tooltip * added more tooltips to prodCons * getStmtHierarchy popup implemented :) !! * moved css var definitions to respective files * added bubble_up() and multiple modules * converted some stringstream to string + reordered module functions in viz * added getStmtHierarchy js working for expanding/collapsing * calculate color ranges once and not every time * should be added to previous commit * changed everything from IRMutator -> IRVisitor * side by side view on main page * added expand code / viz buttons functionality * attempting to switch GetStmtHierarchy to 1 tree with colors on side * should revert this change later, but need to for now (merging with main) * added Reinterpret + fixed double graph in StmtHierarchy * removed omg!!!! for reinterpret * changed border colors of stmtHierarchy + removed print statements * visualized assert + added colors to assert + made all info buttons next to colors * added resize bar * removed navigation code (sticking to 2col layout) * changed colors spans to buttons (removed segfault???) * visualize entire LetStmt and cleaned up GetStmtHierarchy.cpp * added more info in info-buttons * added hover to side colors in stmthierarchy * made collapse buttons resizeBar icons + put see_code_button top right of div * (1) added code to viz buttons (2) display all if-stmts, even if they are empty (3) fixed store highlight cost span * added scrollTo for function buttons within modules * changed info-button style * (1) added hover over for colors in stmtHierarchy (2) removed =default constructor/destructor (3) changed getStmtHierarchy to string html instead of stringstream * made sure updated StmtToHtml code was in StmtToViz code * small style changes * added see code/viz buttons for module functions * (1) loop size -> loop span (2) made function names big in viz (3) load types in name * removed inline style tags * added VectorReduce code for stmt hierarchy * removed scopeName hack to fix previous scope error (hope it's not happening anymore) * fixed scrollTo if code is hidden * removed commented out includes * changed costs to inclusive vs exclusive (still might be a bit broken) * removed print statements * fixed loop_depth = 0 error * (1) tooltips include inclusive and exclusive sizes (2) moved tooltip HTML to FindStmtCost.cpp * reworked tooltip style * (1) made getStmtHierarchy *exclusive* costs (2) viuslizing costs for IfThenElse blocks * visualize For and ProducerConsumer blocks * got collapse of code to show cumulative color cost * removed context span button * removed bubble up code and associated logic (now only read/write for loads and stores) * change some variables to read/write instead of prod/cons * fixed range bug * dense/strided vector load * removed inline TODO comments * added loop var for for loops * made loading MUCHHHH faster!!!! * (1) fixed function box width in viz (2) fixed collapse/expand button for functions in code * compile assembly if stmt_viz flag is given * starting assembly stuff * got assembly code button working (button is still ugly) * (1) made assembly button prettier (2) started information bar at top (need to fill in content of info popup) * added content to information bar button popup * (1) fixed if statement costs (2) added percentages to cost tables instead of values * removed output_file_name from ProducerConsumerHierarchy.h and related code in StmtToViz.cpp * fixed IfThenElse cost if there are nested ifs * removed dependency graph logic and files * made tooltip table input vector of pairs so that we can specify order * made tooltip table input vector of pairs so that we can specify order * added collapse/expand to viz on right * (1) collapseCode works now (2) search works in assembly tab * changed codemirror to ARM assembly highlighting * start of refactor: commenting and cleanup * fixed bug!!!!! i think. i hope !!! * removed Stmt function versions (never run this code on Stmt input, only module) * changed ProdCons stuff to IRViz * removed print line for strided vectors (seems to be working now) * fixed bug!!!!!! changed things back to stringstream, because that wasn't the issue * have helper functions return strings isntead of being void * end of refactor (for now) - changed variables from camelCase to snake_case * fix ... error for collapsing nodes * fixed div issue + tooltips not being correct location * added error message for multiple modulse (doesn't currently support) * fixed spacing for boxBody divs * removed submodules logic because it's not supported right now * made assembly marker generation more accurate (added counters to have marker names be unique) * got 3 columns resizing mostly working (just a little glitchy, good enough for now) * got assembly button to populate assembly, kind of working * added assemblyInfoViz.h to makefile * fixed resize bars for 3 different visualizations * fixed linewrapping issue with codemirror * updated spacing for IRVisualization buttons in header * (1) fixed functionBox button sizing (2) dense vector load -> [Dense, Vector] load * fixed informationBar spacing * updated InformationBar content * removed current_loop_depth from consideration of cost * changed cost table tooltip: inclusive: show %, exclusive: show raw cost * simplified cost model * (1) updated InformationBar w/ info for assembly (2) added assembly by default to third col * added logic to collapseVizAssembly if curson passes resizeBar * moved all color range + tooltip logic into IRVisualization * fixed get_combined_color_range() error * reordered js/css strings * changed format and slighly changed content of cost tooltips * refactor: .h and .cpp files have same order * refactor: added comments * refactor: updated internal_error messages * fixed small import / #ifndef typo * updated get_loop_iterator to include more binary ops for extent * reverting some changes I made to get ready for PR * adding CMake build * refactoring namespace scoping * refactoring "endl" * refactoring header guards and includes * const vector reference * ostringstream all the things! * having a symbol for "canIgnoreVariableName" * string -> char*, with raw string literal * internal_error -> internal_assert() * clang-format * if-else chain to switch-case block * Upgrade wabt to 1.0.30 (#7058) * Add support for float16 buffer in python extension (#7060) * run clang-tidy and clang-format * run clang-tidy & clang-format * run clang-tidy and clang-format, again * run clang-tidy and clang-format, Phaze III * Minor PR Revision - If `stmt_viz` flag is used without the `assembly` flag, the compiler throws an error. - GetAssemblyInfoViz.cpp: replace regex with replace_all - GetSttmtHierarchy.cpp: Bug fix (line 721). Use raw strings for large literals - Restricted scope of default statement values - Added enum type for StmtCostModel. A single cost model config value is specified, instead of multiple booleans. * reminder for later --------- Co-authored-by: Darya Verzhbinsky Co-authored-by: Maaz Ahmad Co-authored-by: Marcos Slomp Co-authored-by: Steven Johnson Co-authored-by: Steve Suzuki Co-authored-by: Marcos Slomp --- Makefile | 10 + src/CMakeLists.txt | 10 + src/CodeGen_LLVM.cpp | 14 +- src/CodeGen_LLVM.h | 4 + src/FindStmtCost.cpp | 1036 ++++++++++++++++ src/FindStmtCost.h | 155 +++ src/Generator.cpp | 9 +- src/GetAssemblyInfoViz.cpp | 199 +++ src/GetAssemblyInfoViz.h | 91 ++ src/GetStmtHierarchy.cpp | 728 +++++++++++ src/GetStmtHierarchy.h | 130 ++ src/IRVisualization.cpp | 1613 ++++++++++++++++++++++++ src/IRVisualization.h | 186 +++ src/Module.cpp | 8 + src/Module.h | 1 + src/StmtToViz.cpp | 2377 ++++++++++++++++++++++++++++++++++++ src/StmtToViz.h | 31 + 17 files changed, 6597 insertions(+), 5 deletions(-) create mode 100644 src/FindStmtCost.cpp create mode 100644 src/FindStmtCost.h create mode 100644 src/GetAssemblyInfoViz.cpp create mode 100644 src/GetAssemblyInfoViz.h create mode 100644 src/GetStmtHierarchy.cpp create mode 100644 src/GetStmtHierarchy.h create mode 100644 src/IRVisualization.cpp create mode 100644 src/IRVisualization.h create mode 100644 src/StmtToViz.cpp create mode 100644 src/StmtToViz.h diff --git a/Makefile b/Makefile index 2796387233dc..ffe326d8592d 100644 --- a/Makefile +++ b/Makefile @@ -467,6 +467,7 @@ SOURCE_FILES = \ FastIntegerDivide.cpp \ FindCalls.cpp \ FindIntrinsics.cpp \ + FindStmtCost.cpp \ FlattenNestedRamps.cpp \ Float16.cpp \ Func.cpp \ @@ -474,6 +475,8 @@ SOURCE_FILES = \ FuseGPUThreadLoops.cpp \ FuzzFloatStores.cpp \ Generator.cpp \ + GetAssemblyInfoViz.cpp \ + GetStmtHierarchy.cpp \ HexagonOffload.cpp \ HexagonOptimize.cpp \ ImageParam.cpp \ @@ -491,6 +494,7 @@ SOURCE_FILES = \ IROperator.cpp \ IRPrinter.cpp \ IRVisitor.cpp \ + IRVisualization.cpp \ JITModule.cpp \ Lambda.cpp \ Lerp.cpp \ @@ -561,6 +565,7 @@ SOURCE_FILES = \ SpirvIR.cpp \ SplitTuples.cpp \ StmtToHtml.cpp \ + StmtToViz.cpp \ StorageFlattening.cpp \ StorageFolding.cpp \ StrictifyFloat.cpp \ @@ -647,6 +652,7 @@ HEADER_FILES = \ FastIntegerDivide.h \ FindCalls.h \ FindIntrinsics.h \ + FindStmtCost.h \ FlattenNestedRamps.h \ Float16.h \ Func.h \ @@ -655,6 +661,8 @@ HEADER_FILES = \ FuseGPUThreadLoops.h \ FuzzFloatStores.h \ Generator.h \ + GetAssemblyInfoViz.h \ + GetStmtHierarchy.h \ HexagonOffload.h \ HexagonOptimize.h \ ImageParam.h \ @@ -673,6 +681,7 @@ HEADER_FILES = \ IROperator.h \ IRPrinter.h \ IRVisitor.h \ + IRVisualization.h \ WasmExecutor.h \ JITModule.h \ Lambda.h \ @@ -727,6 +736,7 @@ HEADER_FILES = \ Solve.h \ SplitTuples.h \ StmtToHtml.h \ + StmtToViz.h \ StorageFlattening.h \ StorageFolding.h \ StrictifyFloat.h \ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index cc9f6805ba4a..07fd8428cdba 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -67,6 +67,7 @@ set(HEADER_FILES FastIntegerDivide.h FindCalls.h FindIntrinsics.h + FindStmtCost.h FlattenNestedRamps.h Float16.h Func.h @@ -75,6 +76,8 @@ set(HEADER_FILES FuseGPUThreadLoops.h FuzzFloatStores.h Generator.h + GetAssemblyInfoViz.h + GetStmtHierarchy.h HexagonOffload.h HexagonOptimize.h ImageParam.h @@ -93,6 +96,7 @@ set(HEADER_FILES IROperator.h IRPrinter.h IRVisitor.h + IRVisualization.h JITModule.h Lambda.h Lerp.h @@ -146,6 +150,7 @@ set(HEADER_FILES Solve.h SplitTuples.h StmtToHtml.h + StmtToViz.h StorageFlattening.h StorageFolding.h StrictifyFloat.h @@ -231,6 +236,7 @@ set(SOURCE_FILES FastIntegerDivide.cpp FindCalls.cpp FindIntrinsics.cpp + FindStmtCost.cpp FlattenNestedRamps.cpp Float16.cpp Func.cpp @@ -238,6 +244,8 @@ set(SOURCE_FILES FuseGPUThreadLoops.cpp FuzzFloatStores.cpp Generator.cpp + GetAssemblyInfoViz.cpp + GetStmtHierarchy.cpp HexagonOffload.cpp HexagonOptimize.cpp ImageParam.cpp @@ -255,6 +263,7 @@ set(SOURCE_FILES IROperator.cpp IRPrinter.cpp IRVisitor.cpp + IRVisualization.cpp JITModule.cpp Lambda.cpp Lerp.cpp @@ -325,6 +334,7 @@ set(SOURCE_FILES SpirvIR.cpp SplitTuples.cpp StmtToHtml.cpp + StmtToViz.cpp StorageFlattening.cpp StorageFolding.cpp StrictifyFloat.cpp diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index d09200591f85..2786ef2e7d9c 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -3666,11 +3666,13 @@ void CodeGen_LLVM::return_with_error_code(llvm::Value *error_code) { } void CodeGen_LLVM::visit(const ProducerConsumer *op) { + producer_consumer_count++; + string name; if (op->is_producer) { - name = std::string("produce ") + op->name; + name = std::to_string(producer_consumer_count) + std::string("_produce ") + op->name; } else { - name = std::string("consume ") + op->name; + name = std::to_string(producer_consumer_count) + std::string("_consume ") + op->name; } BasicBlock *produce = BasicBlock::Create(*context, name, function); builder->CreateBr(produce); @@ -3679,6 +3681,8 @@ void CodeGen_LLVM::visit(const ProducerConsumer *op) { } void CodeGen_LLVM::visit(const For *op) { + for_loop_count++; + Value *min = codegen(op->min); Value *extent = codegen(op->extent); const Acquire *acquire = op->body.as(); @@ -3696,9 +3700,11 @@ void CodeGen_LLVM::visit(const For *op) { BasicBlock *preheader_bb = builder->GetInsertBlock(); // Make a new basic block for the loop - BasicBlock *loop_bb = BasicBlock::Create(*context, std::string("for ") + op->name, function); + BasicBlock *loop_bb = BasicBlock::Create( + *context, std::to_string(for_loop_count) + std::string("_for ") + op->name, function); // Create the block that comes after the loop - BasicBlock *after_bb = BasicBlock::Create(*context, std::string("end for ") + op->name, function); + BasicBlock *after_bb = BasicBlock::Create( + *context, std::to_string(for_loop_count) + std::string("_end for ") + op->name, function); // If min < max, fall through to the loop bb Value *enter_condition = builder->CreateICmpSLT(min, max); diff --git a/src/CodeGen_LLVM.h b/src/CodeGen_LLVM.h index 79797950cb7b..7720047970af 100644 --- a/src/CodeGen_LLVM.h +++ b/src/CodeGen_LLVM.h @@ -627,6 +627,10 @@ class CodeGen_LLVM : public IRVisitor { // @} private: + // used for mapping IR nodes to llvm markers in StmtToViz.cpp + int producer_consumer_count = 0; + int for_loop_count = 0; + /** All the values in scope at the current code location during * codegen. Use sym_push and sym_pop to access. */ Scope symbol_table; diff --git a/src/FindStmtCost.cpp b/src/FindStmtCost.cpp new file mode 100644 index 000000000000..90a4a58d757b --- /dev/null +++ b/src/FindStmtCost.cpp @@ -0,0 +1,1036 @@ +#include "FindStmtCost.h" +#include "StmtToViz.h" + +using namespace Halide; +using namespace Internal; + +namespace Halide { +namespace Internal { + +using std::ostringstream; +using std::string; +using std::vector; + +/* + * FindStmtCost class + */ +void FindStmtCost::generate_costs(const Module &m) { + traverse(m); + set_max_costs(m); +} + +int FindStmtCost::get_cost(const IRNode *node, StmtCostModel cost_model) const { + if (node->node_type == IRNodeType::IfThenElse) { + return get_if_node_cost(static_cast(node), cost_model); + } else { + switch (cost_model) { + case StmtCostModel::Compute: + return get_computation_cost(node, false); + case StmtCostModel::ComputeInclusive: + return get_computation_cost(node, true); + case StmtCostModel::DataMovement: + return get_data_movement_cost(node, false); + case StmtCostModel::DataMovementInclusive: + return get_data_movement_cost(node, true); + default: + internal_assert(false) << "\n" + << "FindStmtCost::get_cost doest not recognize the cost model:" + << cost_model + << "\n\n"; + return -1; + } + } +} + +int FindStmtCost::get_depth(const IRNode *node) const { + internal_assert(node != nullptr) << "\n" + << "FindStmtCost::get_depth: node is nullptr" + << "\n\n"; + + auto it = stmt_cost.find(node); + if (it == stmt_cost.end()) { + + // TODO(marcos): on the next revision, make sure to further elaborate on + // the comment below, with full sentences. + // sometimes, these constant values are created on the whim in + // StmtToViz.cpp - return 1 to avoid crashing + IRNodeType type = node->node_type; + if (type == IRNodeType::IntImm || type == IRNodeType::UIntImm || + type == IRNodeType::FloatImm || type == IRNodeType::StringImm) { + return 1; + } + + // this happens when visualizing cost of else_case in StmtToViz.cpp + else if (type == IRNodeType::IfThenElse) { + Stmt then_case = ((const IfThenElse *)node)->then_case; + return get_depth(then_case.get()); + } + + else { + internal_assert(false) << "\n" + << "FindStmtCost::get_depth: " << print_node(node) + << "node not found in stmt_cost map" + << "\n\n"; + return 0; + } + } + + return it->second.depth; +} + +int FindStmtCost::get_max_cost(StmtCostModel cost_model) const { + switch (cost_model) { + case StmtCostModel::Compute: + return max_computation_cost_exclusive; + case StmtCostModel::ComputeInclusive: + return max_computation_cost_inclusive; + case StmtCostModel::DataMovement: + return max_data_movement_cost_inclusive; + case StmtCostModel::DataMovementInclusive: + return max_data_movement_cost_exclusive; + default: + internal_assert(false) << "\n" + << "FindStmtCost::get_max_cost doest not recognize the cost model:" + << cost_model + << "\n\n"; + return -1; + } +} + +void FindStmtCost::traverse(const Module &m) { + + // traverse all functions + for (const auto &f : m.functions()) { + f.body.accept(this); + } +} + +int FindStmtCost::get_computation_cost(const IRNode *node, bool inclusive) const { + internal_assert(node != nullptr) << "\n" + << "FindStmtCost::get_computation_cost: node is nullptr" + << "\n\n"; + + auto it = stmt_cost.find(node); + IRNodeType type = node->node_type; + int cost = -1; + + if (it == stmt_cost.end()) { + // TODO(marcos): on the next revision, make sure to further elaborate on + // the comment below, with full sentences. + // sometimes, these constant values are created on the whim in + // StmtToViz.cpp - set cost_node to be fresh StmtCost to avoid crashing + if (type == IRNodeType::IntImm || type == IRNodeType::UIntImm || + type == IRNodeType::FloatImm || type == IRNodeType::StringImm) { + cost = StmtCost::NormalNodeCC; + } + + // this happens when visualizing cost of else_case in StmtToViz.cpp + else if (type == IRNodeType::Variable) { + const Variable *var = (const Variable *)node; + if (var->name == StmtToViz_canIgnoreVariableName_string) { + cost = StmtCost::NormalNodeCC; + } + } + + else { + internal_assert(false) << "\n" + << "FindStmtCost::get_computation_cost: " << print_node(node) + << "node not found in stmt_cost map" + << "\n\n"; + return 0; + } + } else { + if (inclusive) { + cost = it->second.computation_cost_inclusive; + } else { + cost = it->second.computation_cost_exclusive; + } + } + + internal_assert(cost >= 0) << "\n" + << "FindStmtCost::get_computation_cost: " << print_node(node) + << "computation_cost_exclusive not set (cost is: " << cost << ")" + << "\n\n"; + + return cost; +} +int FindStmtCost::get_data_movement_cost(const IRNode *node, bool inclusive) const { + internal_assert(node != nullptr) << "\n" + << "FindStmtCost::get_data_movement_cost: node is nullptr" + << "\n\n"; + + auto it = stmt_cost.find(node); + IRNodeType type = node->node_type; + int cost = -1; + + if (it == stmt_cost.end()) { + // sometimes, these constant values are created on the whim in + // StmtToViz.cpp - set cost_node to be fresh StmtCost to avoid crashing + if (type == IRNodeType::IntImm || type == IRNodeType::UIntImm || + type == IRNodeType::FloatImm || type == IRNodeType::StringImm) { + cost = StmtCost::NormalNodeDMC; + } + + // this happens when visualizing cost of else_case in StmtToViz.cpp + else if (type == IRNodeType::Variable) { + const Variable *var = (const Variable *)node; + if (var->name == StmtToViz_canIgnoreVariableName_string) { + cost = StmtCost::NormalNodeDMC; + } + } else { + internal_assert(false) << "\n" + << "FindStmtCost::get_data_movement_cost: " << print_node(node) + << "node not found in stmt_cost map" + << "\n\n"; + return 0; + } + } else { + if (inclusive) { + cost = it->second.data_movement_cost_inclusive; + } else { + cost = it->second.data_movement_cost_exclusive; + } + } + + internal_assert(cost >= 0) << "\n" + << "FindStmtCost::get_data_movement_cost: " << print_node(node) + << "data_movement_cost_exclusive not set (cost is: " << cost << ")" + << "\n\n"; + + return cost; +} + +int FindStmtCost::get_if_node_cost(const IfThenElse *if_then_else, StmtCostModel cost_model) const { + switch (cost_model) { + case StmtCostModel::Compute: + return StmtCost::NormalNodeCC; + case StmtCostModel::ComputeInclusive: + return get_computation_cost(if_then_else->condition.get(), true) + + get_computation_cost(if_then_else->then_case.get(), true); + case StmtCostModel::DataMovement: + return StmtCost::NormalNodeDMC; + case StmtCostModel::DataMovementInclusive: + return get_data_movement_cost(if_then_else->condition.get(), true) + + get_data_movement_cost(if_then_else->then_case.get(), true); + default: + internal_assert(false) << "\n" + << "FindStmtCost::get_if_node_cost doest not recognize the cost model:" + << cost_model + << "\n\n"; + return -1; + } +} + +vector FindStmtCost::get_costs_children(const IRNode *parent, const vector &children, + bool inclusive) const { + int children_cc = 0; + int children_dmc = 0; + + for (const IRNode *child : children) { + children_cc += get_computation_cost(child, inclusive); + children_dmc += get_data_movement_cost(child, inclusive); + } + + vector costs_children = {children_cc, children_dmc}; + + return costs_children; +} + +void FindStmtCost::set_costs( + bool inclusive, const IRNode *node, const vector &children, + const std::function &calculate_cc = [](int x) { return StmtCost::NormalNodeCC + x; }, + const std::function &calculate_dmc = [](int x) { return StmtCost::NormalNodeDMC + x; }) { + + vector costs_children = get_costs_children(node, children, inclusive); + + int computation_cost; + int data_movement_cost; + computation_cost = calculate_cc(costs_children[0]); + data_movement_cost = calculate_dmc(costs_children[1]); + + auto it = stmt_cost.find(node); + if (it == stmt_cost.end()) { + if (inclusive) { + stmt_cost.emplace( + node, StmtCost{current_loop_depth, computation_cost, data_movement_cost, -1, -1}); + } else { + stmt_cost.emplace( + node, StmtCost{current_loop_depth, -1, -1, computation_cost, data_movement_cost}); + } + } else { + if (inclusive) { + it->second.computation_cost_inclusive = computation_cost; + it->second.data_movement_cost_inclusive = data_movement_cost; + } else { + it->second.computation_cost_exclusive = computation_cost; + it->second.data_movement_cost_exclusive = data_movement_cost; + } + } +} + +void FindStmtCost::set_max_costs(const Module &m) { + + // inclusive costs (sum up all costs of bodies of functions in module) + int body_computation_cost = 0; + int body_data_movement_cost = 0; + for (const auto &f : m.functions()) { + body_computation_cost += get_computation_cost(f.body.get(), true); + body_data_movement_cost += get_data_movement_cost(f.body.get(), true); + } + + max_computation_cost_inclusive = body_computation_cost; + max_data_movement_cost_inclusive = body_data_movement_cost; + + // max_computation_cost_exclusive + int max_cost = 0; + for (auto const &pair : stmt_cost) { + int cost = pair.second.computation_cost_exclusive; + if (cost > max_cost) { + max_cost = cost; + } + } + max_computation_cost_exclusive = max_cost; + + // max_data_movement_cost_exclusive + max_cost = 0; + for (auto const &pair : stmt_cost) { + int cost = pair.second.data_movement_cost_exclusive; + if (cost > max_cost) { + max_cost = cost; + } + } + max_data_movement_cost_exclusive = max_cost; +} + +int FindStmtCost::get_scaling_factor(uint8_t bits, uint16_t lanes) const { + int bits_factor = bits / 8; + int lanes_factor = lanes / 8; + + if (bits_factor == 0) { + bits_factor = 1; + } + if (lanes_factor == 0) { + lanes_factor = 1; + } + return bits_factor * lanes_factor; +} + +void FindStmtCost::visit_binary_op(const IRNode *op, const Expr &a, const Expr &b) { + a.accept(this); + b.accept(this); + + // inclusive and exclusive costs are the same + set_costs(true, op, {a.get(), b.get()}); + set_costs(false, op, {a.get(), b.get()}); +} + +void FindStmtCost::visit(const IntImm *op) { + set_costs(true, op, {}); + set_costs(false, op, {}); +} + +void FindStmtCost::visit(const UIntImm *op) { + set_costs(true, op, {}); + set_costs(false, op, {}); +} + +void FindStmtCost::visit(const FloatImm *op) { + set_costs(true, op, {}); + set_costs(false, op, {}); +} + +void FindStmtCost::visit(const StringImm *op) { + set_costs(true, op, {}); + set_costs(false, op, {}); +} + +void FindStmtCost::visit(const Cast *op) { + op->value.accept(this); + + // inclusive and exclusive costs are the same + set_costs(true, op, {op->value.get()}); + set_costs(false, op, {op->value.get()}); +} + +void FindStmtCost::visit(const Reinterpret *op) { + op->value.accept(this); + + // inclusive and exclusive costs are the same + set_costs(true, op, {op->value.get()}); + set_costs(false, op, {op->value.get()}); +} +void FindStmtCost::visit(const Variable *op) { + set_costs(true, op, {}); + set_costs(false, op, {}); +} + +void FindStmtCost::visit(const Add *op) { + visit_binary_op(op, op->a, op->b); +} + +void FindStmtCost::visit(const Sub *op) { + visit_binary_op(op, op->a, op->b); +} + +void FindStmtCost::visit(const Mul *op) { + visit_binary_op(op, op->a, op->b); +} + +void FindStmtCost::visit(const Div *op) { + visit_binary_op(op, op->a, op->b); +} + +void FindStmtCost::visit(const Mod *op) { + visit_binary_op(op, op->a, op->b); +} + +void FindStmtCost::visit(const Min *op) { + visit_binary_op(op, op->a, op->b); +} + +void FindStmtCost::visit(const Max *op) { + visit_binary_op(op, op->a, op->b); +} + +void FindStmtCost::visit(const EQ *op) { + visit_binary_op(op, op->a, op->b); +} + +void FindStmtCost::visit(const NE *op) { + visit_binary_op(op, op->a, op->b); +} + +void FindStmtCost::visit(const LT *op) { + visit_binary_op(op, op->a, op->b); +} + +void FindStmtCost::visit(const LE *op) { + visit_binary_op(op, op->a, op->b); +} + +void FindStmtCost::visit(const GT *op) { + visit_binary_op(op, op->a, op->b); +} + +void FindStmtCost::visit(const GE *op) { + visit_binary_op(op, op->a, op->b); +} + +void FindStmtCost::visit(const And *op) { + visit_binary_op(op, op->a, op->b); +} + +void FindStmtCost::visit(const Or *op) { + visit_binary_op(op, op->a, op->b); +} + +void FindStmtCost::visit(const Not *op) { + op->a.accept(this); + + // inclusive and exclusive costs are the same + set_costs(true, op, {op->a.get()}); + set_costs(false, op, {op->a.get()}); +} + +void FindStmtCost::visit(const Select *op) { + op->condition.accept(this); + op->true_value.accept(this); + op->false_value.accept(this); + + // inclusive and exclusive costs are the same + set_costs(true, op, {op->condition.get(), op->true_value.get(), op->false_value.get()}); + set_costs(false, op, {op->condition.get(), op->true_value.get(), op->false_value.get()}); +} + +void FindStmtCost::visit(const Load *op) { + op->predicate.accept(this); + op->index.accept(this); + + uint8_t bits = op->type.bits(); + uint16_t lanes = op->type.lanes(); + int scaling_factor = get_scaling_factor(bits, lanes); + + std::function calculate_cc = [scaling_factor](int children_cost) { + return scaling_factor * (StmtCost::NormalNodeCC + children_cost); + }; + + std::function calculate_dmc = [scaling_factor](int children_cost) { + return scaling_factor * (StmtCost::LoadDMC + children_cost); + }; + + // inclusive and exclusive costs are the same + set_costs(true, op, {op->predicate.get(), op->index.get()}, calculate_cc, calculate_dmc); + set_costs(false, op, {op->predicate.get(), op->index.get()}, calculate_cc, calculate_dmc); +} + +void FindStmtCost::visit(const Ramp *op) { + op->base.accept(this); + op->stride.accept(this); + + // inclusive and exclusive costs are the same + set_costs(true, op, {op->base.get(), op->stride.get()}); + set_costs(false, op, {op->base.get(), op->stride.get()}); +} + +void FindStmtCost::visit(const Broadcast *op) { + op->value.accept(this); + + // inclusive and exclusive costs are the same + set_costs(true, op, {op->value.get()}); + set_costs(false, op, {op->value.get()}); +} + +void FindStmtCost::visit(const Call *op) { + vector children; + + for (const auto &arg : op->args) { + arg.accept(this); + children.push_back(arg.get()); + } + + // Consider extern call args + if (op->func.defined()) { + Function f(op->func); + if (op->call_type == Call::Halide && f.has_extern_definition()) { + for (const auto &arg : f.extern_arguments()) { + if (arg.is_expr()) { + arg.expr.accept(this); + children.push_back(arg.expr.get()); + } + } + } + } + + // inclusive and exclusive costs are the same + set_costs(true, op, children); + set_costs(false, op, children); +} + +void FindStmtCost::visit(const Let *op) { + + op->value.accept(this); + op->body.accept(this); + + // inclusive and exclusive costs are the same (keep body in both since it's all inlined) + set_costs(true, op, {op->value.get(), op->body.get()}); + set_costs(false, op, {op->value.get(), op->body.get()}); +} + +void FindStmtCost::visit(const Shuffle *op) { + vector children; + for (const Expr &i : op->vectors) { + i.accept(this); + children.push_back(i.get()); + } + + // inclusive and exclusive costs are the same + set_costs(true, op, children); + set_costs(false, op, children); +} + +void FindStmtCost::visit(const VectorReduce *op) { + op->value.accept(this); + + // represents the number of times the op->op is applied to the vector + int count_cost = op->value.type().lanes() - 1; + + std::function calculate_cc = [count_cost](int children_cost) { + return count_cost * (StmtCost::NormalNodeCC + children_cost); + }; + + std::function calculate_dmc = [count_cost](int children_cost) { + return count_cost * (StmtCost::NormalNodeDMC + children_cost); + }; + + // inclusive and exclusive costs are the same + set_costs(true, op, {op->value.get()}, calculate_cc, calculate_dmc); + set_costs(false, op, {op->value.get()}, calculate_cc, calculate_dmc); +} + +void FindStmtCost::visit(const LetStmt *op) { + op->value.accept(this); + op->body.accept(this); + + set_costs(true, op, {op->value.get(), op->body.get()}); + set_costs(false, op, {op->value.get()}); +} + +void FindStmtCost::visit(const AssertStmt *op) { + op->condition.accept(this); + op->message.accept(this); + + // inclusive and exclusive costs are the same + set_costs(true, op, {op->condition.get(), op->message.get()}); + set_costs(false, op, {op->condition.get(), op->message.get()}); +} + +void FindStmtCost::visit(const ProducerConsumer *op) { + op->body.accept(this); + + set_costs(true, op, {op->body.get()}); + set_costs(false, op, {}); +} + +void FindStmtCost::visit(const For *op) { + current_loop_depth += 1; + + op->min.accept(this); + op->extent.accept(this); + op->body.accept(this); + + current_loop_depth -= 1; + + set_costs(true, op, {op->min.get(), op->extent.get(), op->body.get()}); + set_costs(false, op, {op->min.get(), op->extent.get()}); + + // TODO: complete implementation of different loop types + if (op->for_type == ForType::Parallel) { + internal_assert(false) << "\n" + << "FindStmtCost::visit: Parallel for loops are not supported yet" + << "\n\n"; + } + if (op->for_type == ForType::Unrolled) { + internal_assert(false) << "\n" + << "FindStmtCost::visit: Unrolled for loops are not supported yet" + << "\n\n"; + } + if (op->for_type == ForType::Vectorized) { + internal_assert(false) << "\n" + << "FindStmtCost::visit: Vectorized for loops are not supported yet" + << "\n\n"; + } +} + +void FindStmtCost::visit(const Acquire *op) { + ostringstream name; + name << op->semaphore; + + op->semaphore.accept(this); + op->count.accept(this); + op->body.accept(this); + + set_costs(true, op, {op->semaphore.get(), op->count.get(), op->body.get()}); + set_costs(false, op, {op->semaphore.get(), op->count.get()}); +} + +void FindStmtCost::visit(const Store *op) { + + op->predicate.accept(this); + op->index.accept(this); + op->value.accept(this); + + std::function calculate_cc = [](int children_cost) { + return StmtCost::NormalNodeCC + children_cost; + }; + + std::function calculate_dmc = [](int children_cost) { + return StmtCost::StoreDMC + children_cost; + }; + + // inclusive and exclusive costs are the same + set_costs(true, op, {op->predicate.get(), op->index.get(), op->value.get()}, calculate_cc, + calculate_dmc); + set_costs(false, op, {op->predicate.get(), op->index.get(), op->value.get()}, calculate_cc, + calculate_dmc); +} + +void FindStmtCost::visit(const Provide *op) { + op->predicate.accept(this); + + vector children; + children.push_back(op->predicate.get()); + + for (const auto &value : op->values) { + value.accept(this); + children.push_back(value.get()); + } + for (const auto &arg : op->args) { + arg.accept(this); + children.push_back(arg.get()); + } + + // inclusive and exclusive costs are the same + set_costs(true, op, children); + set_costs(false, op, children); +} + +void FindStmtCost::visit(const Allocate *op) { + vector children; + + for (const auto &extent : op->extents) { + extent.accept(this); + children.push_back(extent.get()); + } + + op->condition.accept(this); + children.push_back(op->condition.get()); + + if (op->new_expr.defined()) { + op->new_expr.accept(this); + children.push_back(op->new_expr.get()); + } + + set_costs(false, op, children); + + op->body.accept(this); + children.push_back(op->body.get()); + + set_costs(true, op, children); +} + +void FindStmtCost::visit(const Free *op) { + set_costs(true, op, {}); + set_costs(false, op, {}); +} + +void FindStmtCost::visit(const Realize *op) { + vector children; + + for (const auto &bound : op->bounds) { + bound.min.accept(this); + bound.extent.accept(this); + children.push_back(bound.min.get()); + children.push_back(bound.extent.get()); + } + + op->condition.accept(this); + children.push_back(op->condition.get()); + + set_costs(false, op, children); + + op->body.accept(this); + children.push_back(op->body.get()); + + set_costs(true, op, children); +} + +void FindStmtCost::visit(const Prefetch *op) { + vector children; + + for (const auto &bound : op->bounds) { + bound.min.accept(this); + bound.extent.accept(this); + + children.push_back(bound.min.get()); + children.push_back(bound.extent.get()); + } + + op->condition.accept(this); + children.push_back(op->condition.get()); + + set_costs(false, op, children); + + op->body.accept(this); + children.push_back(op->body.get()); + + set_costs(true, op, children); +} + +void FindStmtCost::visit(const Block *op) { + vector children; + + op->first.accept(this); + children.push_back(op->first.get()); + + if (op->rest.defined()) { + op->rest.accept(this); + children.push_back(op->rest.get()); + } + + set_costs(true, op, children); + + // there is no exclusive computation or data movement for Block + set_costs(false, op, {}); +} + +void FindStmtCost::visit(const Fork *op) { + op->first.accept(this); + + vector children; + children.push_back(op->first.get()); + + if (op->rest.defined()) { + op->rest.accept(this); + children.push_back(op->rest.get()); + } + + set_costs(true, op, children); + set_costs(false, op, children); +} + +void FindStmtCost::visit(const IfThenElse *op) { + vector main_node_children; + + const IfThenElse *original_op = op; + + while (true) { + op->condition.accept(this); + op->then_case.accept(this); + + main_node_children.push_back(op->condition.get()); + main_node_children.push_back(op->then_case.get()); + + // inclusive and exclusive costs are the same + set_costs(false, op, {op->condition.get(), op->then_case.get()}); + set_costs(true, op, {op->condition.get(), op->then_case.get()}); + + // if there is no else case, we are done + if (!op->else_case.defined()) { + break; + } + + // if else case is another ifthenelse, recurse and reset op to else case + if (const IfThenElse *nested_if = op->else_case.as()) { + op = nested_if; + } + + // if else case is not another ifthenelse + else { + op->else_case.accept(this); + main_node_children.push_back(op->else_case.get()); + break; + } + } + + // set op costs - for entire if-statement, inclusive and exclusive costs are the same + set_costs(false, original_op, main_node_children); + set_costs(true, original_op, main_node_children); +} + +void FindStmtCost::visit(const Evaluate *op) { + op->value.accept(this); + + vector costs_children = get_costs_children(op, {op->value.get()}, true); + + // inclusive and exclusive costs are the same + set_costs(true, op, {op->value.get()}); + set_costs(false, op, {op->value.get()}); +} + +void FindStmtCost::visit(const Atomic *op) { + op->body.accept(this); + + ostringstream name; + name << op->producer_name; + + set_costs(true, op, {op->body.get()}); + set_costs(false, op, {}); +} + +string FindStmtCost::print_node(const IRNode *node) const { + ostringstream s; + s << "Node in question has type: "; + IRNodeType type = node->node_type; + switch (type) { + case IRNodeType::IntImm: { + s << "IntImm type"; + const auto *node1 = dynamic_cast(node); + s << "value: " << node1->value; + break; + } + case IRNodeType::UIntImm: { + s << "UIntImm type"; + break; + } + case IRNodeType::FloatImm: { + s << "FloatImm type"; + break; + } + case IRNodeType::StringImm: { + s << "StringImm type"; + break; + } + case IRNodeType::Broadcast: { + s << "Broadcast type"; + break; + } + case IRNodeType::Cast: { + s << "Cast type"; + break; + } + case IRNodeType::Variable: { + const auto *node1 = dynamic_cast(node); + s << "Variable type - " << node1->name; + break; + } + case IRNodeType::Add: { + s << "Add type"; + break; + } + case IRNodeType::Sub: { + s << "Sub type"; + break; + } + case IRNodeType::Mod: { + s << "Mod type"; + break; + } + case IRNodeType::Mul: { + s << "Mul type"; + break; + } + case IRNodeType::Div: { + s << "Div type"; + break; + } + case IRNodeType::Min: { + s << "Min type"; + break; + } + case IRNodeType::Max: { + s << "Max type"; + break; + } + case IRNodeType::EQ: { + s << "EQ type"; + break; + } + case IRNodeType::NE: { + s << "NE type"; + break; + } + case IRNodeType::LT: { + s << "LT type"; + break; + } + case IRNodeType::LE: { + s << "LE type"; + break; + } + case IRNodeType::GT: { + s << "GT type"; + break; + } + case IRNodeType::GE: { + s << "GE type"; + break; + } + case IRNodeType::And: { + s << "And type"; + break; + } + case IRNodeType::Or: { + s << "Or type"; + break; + } + case IRNodeType::Not: { + s << "Not type"; + break; + } + case IRNodeType::Select: { + s << "Select type"; + break; + } + case IRNodeType::Load: { + s << "Load type: "; + const auto *node1 = dynamic_cast(node); + s << node1->name << ", index: " << node1->index; + break; + } + case IRNodeType::Ramp: { + s << "Ramp type"; + break; + } + case IRNodeType::Call: { + s << "Call type"; + break; + } + case IRNodeType::Let: { + s << "Let type"; + break; + } + case IRNodeType::Shuffle: { + s << "Shuffle type"; + break; + } + case IRNodeType::VectorReduce: { + s << "VectorReduce type"; + break; + } + case IRNodeType::LetStmt: { + s << "LetStmt type"; + const auto *node1 = dynamic_cast(node); + s << "name: " << node1->name; + s << ", value: " << node1->value; + break; + } + case IRNodeType::AssertStmt: { + s << "AssertStmt type"; + break; + } + case IRNodeType::ProducerConsumer: { + s << "ProducerConsumer type"; + break; + } + case IRNodeType::For: { + s << "For type"; + break; + } + case IRNodeType::Acquire: { + s << "Acquire type"; + break; + } + case IRNodeType::Store: { + s << "Store type: "; + const auto *node1 = dynamic_cast(node); + s << node1->name << ", index: " << node1->index; + s << ", value: " << node1->value; + break; + } + case IRNodeType::Provide: { + s << "Provide type"; + break; + } + case IRNodeType::Allocate: { + s << "Allocate type"; + break; + } + case IRNodeType::Free: { + s << "Free type"; + break; + } + case IRNodeType::Realize: { + s << "Realize type"; + break; + } + case IRNodeType::Block: { + s << "Block type"; + break; + } + case IRNodeType::Fork: { + s << "Fork type"; + break; + } + case IRNodeType::IfThenElse: { + const auto *node1 = dynamic_cast(node); + s << "IfThenElse type - cond: " << node1->condition; + break; + } + case IRNodeType::Evaluate: { + s << "Evaluate type"; + break; + } + case IRNodeType::Prefetch: { + s << "Prefetch type"; + break; + } + case IRNodeType::Atomic: { + s << "Atomic type"; + break; + } + case IRNodeType::Reinterpret: { + s << "Reinterpret type"; + break; + } + default: { + s << "Unknown type"; + break; + } + } + + s << "\n"; + return s.str(); +} + +} // namespace Internal +} // namespace Halide diff --git a/src/FindStmtCost.h b/src/FindStmtCost.h new file mode 100644 index 000000000000..87d3ba6dc872 --- /dev/null +++ b/src/FindStmtCost.h @@ -0,0 +1,155 @@ +#ifndef HALIDE_FIND_STMT_COST_H +#define HALIDE_FIND_STMT_COST_H + +#include "Error.h" +#include "ExternFuncArgument.h" +#include "Function.h" +#include "IRVisitor.h" +#include "Module.h" + +#include + +namespace Halide { +namespace Internal { + +// Different classes of costs +enum StmtCostModel { Compute, + DataMovement, + ComputeInclusive, + DataMovementInclusive }; + +/* + * StmtCost struct + */ +struct StmtCost { + // DMC == Data Movement Cost, CC == Compute Cost + static constexpr int NormalNodeCC = 1; + static constexpr int NormalNodeDMC = 0; + static constexpr int LoadDMC = 2; + static constexpr int StoreDMC = 3; + + int depth; // per nested loop + int computation_cost_inclusive; // per entire node (includes cost of body) + int data_movement_cost_inclusive; // per entire node (includes cost of body) + int computation_cost_exclusive; // per line + int data_movement_cost_exclusive; // per line +}; + +/* + * FindStmtCost class + */ +class FindStmtCost : public IRVisitor { + +public: + FindStmtCost() + : current_loop_depth(0), max_computation_cost_inclusive(0), + max_data_movement_cost_inclusive(0), max_computation_cost_exclusive(0), + max_data_movement_cost_exclusive(0) { + } + + // starts the traversal of the given node + void generate_costs(const Module &m); + + // checks if node is IfThenElse or something else - will call get_if_node_cost if it is, + // get_computation_cost/get_data_movement_cost otherwise + int get_cost(const IRNode *node, StmtCostModel cost_model) const; + + // gets the depth of the node + int get_depth(const IRNode *node) const; + + // gets max costs + int get_max_cost(StmtCostModel cost_model) const; + + // prints node type + std::string print_node(const IRNode *node) const; + +private: + std::unordered_map stmt_cost; // key: node, value: cost + int current_loop_depth; // stores current loop depth level + + // these are used for determining the range of the cost + int max_computation_cost_inclusive; + int max_data_movement_cost_inclusive; + int max_computation_cost_exclusive; + int max_data_movement_cost_exclusive; + + // starts the traversal based on Module + void traverse(const Module &m); + + // gets the total costs of a node + int get_computation_cost(const IRNode *node, bool inclusive) const; + int get_data_movement_cost(const IRNode *node, bool inclusive) const; + + // treats if nodes differently when visualizing cost - will have cost be: + // cost of condition + cost of then_case (exclude else_case in cost) + int get_if_node_cost(const IfThenElse *op, StmtCostModel cost_model) const; + + // gets costs from `stmt_cost` map of children nodes and sum them up accordingly + std::vector get_costs_children(const IRNode *parent, const std::vector &children, + bool inclusive) const; + + // sets inclusive/exclusive costs + void set_costs(bool inclusive, const IRNode *node, const std::vector &children, + const std::function &calculate_cc, const std::function &calculate_dmc); + + // sets max computation cost and max data movement cost (inclusive and exclusive) + void set_max_costs(const Module &m); + + // gets scaling factor for Load/Store based on lanes and bits + int get_scaling_factor(uint8_t bits, uint16_t lanes) const; + + void visit_binary_op(const IRNode *op, const Expr &a, const Expr &b); + + void visit(const IntImm *op) override; + void visit(const UIntImm *op) override; + void visit(const FloatImm *op) override; + void visit(const StringImm *op) override; + void visit(const Cast *op) override; + void visit(const Reinterpret *op) override; + void visit(const Variable *op) override; + void visit(const Add *op) override; + void visit(const Sub *op) override; + void visit(const Mul *op) override; + void visit(const Div *op) override; + void visit(const Mod *op) override; + void visit(const Min *op) override; + void visit(const Max *op) override; + void visit(const EQ *op) override; + void visit(const NE *op) override; + void visit(const LT *op) override; + void visit(const LE *op) override; + void visit(const GT *op) override; + void visit(const GE *op) override; + void visit(const And *op) override; + void visit(const Or *op) override; + void visit(const Not *op) override; + void visit(const Select *op) override; + void visit(const Load *op) override; + void visit(const Ramp *op) override; + void visit(const Broadcast *op) override; + void visit(const Call *op) override; + void visit(const Let *op) override; + void visit(const Shuffle *op) override; + void visit(const VectorReduce *op) override; + void visit(const LetStmt *op) override; + void visit(const AssertStmt *op) override; + void visit(const ProducerConsumer *op) override; + void visit(const For *op) override; + void visit(const Acquire *op) override; + void visit(const Store *op) override; + void visit(const Provide *op) override; + void visit(const Allocate *op) override; + void visit(const Free *op) override; + void visit(const Realize *op) override; + void visit(const Prefetch *op) override; + void visit(const Block *op) override; + void visit(const Fork *op) override; + void visit(const IfThenElse *op) override; + void visit(const Evaluate *op) override; + void visit(const Atomic *op) override; +}; + +} // namespace Internal +} // namespace Halide + +#endif // FINDSTMTCOST_H diff --git a/src/Generator.cpp b/src/Generator.cpp index 5a315fe351e6..af5931a624b2 100644 --- a/src/Generator.cpp +++ b/src/Generator.cpp @@ -652,7 +652,7 @@ gengen -e A comma separated list of files to emit. Accepted values are: [assembly, bitcode, c_header, c_source, cpp_stub, featurization, llvm_assembly, object, python_extension, pytorch_wrapper, registration, - schedule, static_library, stmt, stmt_html, compiler_log]. + schedule, static_library, stmt, stmt_html, stmt_viz, compiler_log]. If omitted, default value is [c_header, static_library, registration]. -p A comma-separated list of shared libraries that will be loaded before the @@ -785,6 +785,13 @@ gengen output_types.insert(OutputFileType::registration); output_types.insert(OutputFileType::static_library); } else { + // if emit_flags contains "stmt_viz" but not "assembly", throw an error + bool has_stmt_viz = std::find(emit_flags.begin(), emit_flags.end(), "stmt_viz") != emit_flags.end(); + bool has_assembly = std::find(emit_flags.begin(), emit_flags.end(), "assembly") != emit_flags.end(); + + user_assert(!has_stmt_viz || has_assembly) + << "Output flag `stmt_viz` requires the `assembly` flag to also be set."; + // Build a reverse lookup table. Allow some legacy aliases on the command line, // to allow legacy build systems to work more easily. std::map output_name_to_enum = { diff --git a/src/GetAssemblyInfoViz.cpp b/src/GetAssemblyInfoViz.cpp new file mode 100644 index 000000000000..b2845a06f675 --- /dev/null +++ b/src/GetAssemblyInfoViz.cpp @@ -0,0 +1,199 @@ +#include "GetAssemblyInfoViz.h" + +#include + +namespace Halide { +namespace Internal { + +using std::string; + +void GetAssemblyInfoViz::generate_assembly_information(const Module &m, + const string &assembly_filename) { + // traverse the module to get the assembly markers + traverse(m); + + generate_assembly_html_and_line_numbers(assembly_filename); +} + +string GetAssemblyInfoViz::get_assembly_html() { + return assembly_HTML.str(); +} + +int GetAssemblyInfoViz::get_line_number_prod_cons(const IRNode *op) { + auto it = node_to_line_number_prod_cons.find(op); + if (it != node_to_line_number_prod_cons.end()) { + return it->second; + } else { + return -1; + } +} + +ForLoopLineNumber GetAssemblyInfoViz::get_line_numbers_for_loops(const IRNode *op) { + auto it = node_to_line_numbers_for_loops.find(op); + if (it != node_to_line_numbers_for_loops.end()) { + return it->second; + } else { + return {-1, -1}; + } +} + +void GetAssemblyInfoViz::traverse(const Module &m) { + + // traverse all functions + for (const auto &f : m.functions()) { + f.body.accept(this); + } +} + +void GetAssemblyInfoViz::generate_assembly_html_and_line_numbers(const string &filename) { + assembly_HTML << "\n"; +} + +string GetAssemblyInfoViz::get_assembly_filename(const string &filename) { + string assembly_filename = filename; + assembly_filename.replace(assembly_filename.find(".stmt.viz.html"), 15, ".s"); + return assembly_filename; +} + +void GetAssemblyInfoViz::add_line_number(string &assembly_line, int line_number) { + for (auto &marker : for_loop_markers) { + add_line_number_for_loop(assembly_line, marker, line_number); + } + for (auto &marker : producer_consumer_markers) { + add_line_number_prod_cons(assembly_line, marker, line_number); + } +} + +void GetAssemblyInfoViz::add_line_number_for_loop(string &assembly_line, + AssemblyInfoForLoop &marker, int line_number) { + // start of for loop + if (std::regex_search(assembly_line, marker.regex_start)) { + + // check if marker is already present + auto it = node_to_line_numbers_for_loops.find(marker.node); + if (it == node_to_line_numbers_for_loops.end()) { + ForLoopLineNumber for_loop_line_number; + for_loop_line_number.start_line = line_number; + for_loop_line_number.end_line = -1; + node_to_line_numbers_for_loops[marker.node] = for_loop_line_number; + } else { + it->second.start_line = line_number; + } + } + + // end of for loop + if (std::regex_search(assembly_line, marker.regex_end)) { + + // check if marker is already present + auto it = node_to_line_numbers_for_loops.find(marker.node); + if (it == node_to_line_numbers_for_loops.end()) { + ForLoopLineNumber for_loop_line_number; + for_loop_line_number.end_line = line_number; + for_loop_line_number.start_line = -1; + node_to_line_numbers_for_loops[marker.node] = for_loop_line_number; + } else { + it->second.end_line = line_number; + } + } +} +void GetAssemblyInfoViz::add_line_number_prod_cons(string &assembly_line, + AssemblyInfoProdCons &marker, int line_number) { + if (std::regex_search(assembly_line, marker.regex)) { + node_to_line_number_prod_cons[marker.node] = line_number; + } +} + +void GetAssemblyInfoViz::visit(const ProducerConsumer *op) { + producer_consumer_count++; + + string assembly_marker = "%\""; + assembly_marker += std::to_string(producer_consumer_count); + assembly_marker += op->is_producer ? "_produce " : "_consume "; + assembly_marker += op->name; + + // replace all $ with \$ + assembly_marker = replace_all(assembly_marker, "$", "\\$"); + + std::regex regex(assembly_marker); + + AssemblyInfoProdCons info; + info.regex = regex; + info.node = op; + + producer_consumer_markers.push_back(info); + + op->body.accept(this); +} +void GetAssemblyInfoViz::visit(const For *op) { + for_loop_count++; + + // start of for loop + string assembly_marker_start = "%\""; + assembly_marker_start += std::to_string(for_loop_count); + assembly_marker_start += "_for " + op->name; + + // replace all $ with \$ + std::regex dollar("\\$"); + assembly_marker_start = std::regex_replace(assembly_marker_start, dollar, "\\$"); + + std::regex regex_start(assembly_marker_start); + + // end of for loop + string assembly_marker_end = "%\""; + assembly_marker_end += std::to_string(for_loop_count); + assembly_marker_end += "_end for " + op->name; + + // replace all $ with \$ + assembly_marker_end = std::regex_replace(assembly_marker_end, dollar, "\\$"); + + std::regex regex_end(assembly_marker_end); + + AssemblyInfoForLoop info; + info.regex_start = regex_start; + info.regex_end = regex_end; + info.node = op; + + for_loop_markers.push_back(info); + + op->body.accept(this); +} + +string GetAssemblyInfoViz::print_node(const IRNode *node) const { + std::stringstream s; + IRNodeType type = node->node_type; + if (type == IRNodeType::ProducerConsumer) { + s << "ProducerConsumer"; + const auto *node1 = dynamic_cast(node); + s << " " << node1->name; + } else if (type == IRNodeType::For) { + s << "For"; + const auto *node1 = dynamic_cast(node); + s << " " << node1->name; + } else { + s << "Unknown type "; + } + + return s.str(); +} + +} // namespace Internal +} // namespace Halide diff --git a/src/GetAssemblyInfoViz.h b/src/GetAssemblyInfoViz.h new file mode 100644 index 000000000000..aa7544113bbb --- /dev/null +++ b/src/GetAssemblyInfoViz.h @@ -0,0 +1,91 @@ +#ifndef HALIDE_GET_ASSEMBLY_INFO_VIZ_H +#define HALIDE_GET_ASSEMBLY_INFO_VIZ_H + +#include +#include + +#include "IROperator.h" +#include "IRVisitor.h" +#include "Module.h" + +namespace Halide { + +class Module; + +namespace Internal { + +struct AssemblyInfoForLoop { + std::regex regex_start; // regex to match the starting marker with + std::regex regex_end; // regex to match the ending marker + const IRNode *node; // node that the marker is associated with +}; + +struct AssemblyInfoProdCons { + std::regex regex; // regex to match the marker with + const IRNode *node; // node that the marker is associated with +}; + +struct ForLoopLineNumber { + int start_line; // line number of the start of the for loop + int end_line; // line number of the end of the for loop +}; + +class GetAssemblyInfoViz : public IRVisitor { + +public: + // generates the assembly info for the module + void generate_assembly_information(const Module &m, const std::string &assembly_filename); + + // returns html content that contains the assembly code + std::string get_assembly_html(); + + // gets line numbers for producers/consumers + for loops + int get_line_number_prod_cons(const IRNode *op); + ForLoopLineNumber get_line_numbers_for_loops(const IRNode *op); + +private: + using IRVisitor::visit; + + // main html content + std::ostringstream assembly_HTML; + + // stores mapping of node to line number + std::unordered_map node_to_line_number_prod_cons; + std::unordered_map node_to_line_numbers_for_loops; + + // stores the markers + std::vector for_loop_markers; + std::vector producer_consumer_markers; + + // for maping each node to unique marker in assembly + int for_loop_count = 0; + int producer_consumer_count = 0; + + // traverses the module to generate the assembly markers + void traverse(const Module &m); + + // generates the assembly html and line numbers from the loaded assembly file + // and generated markers + void generate_assembly_html_and_line_numbers(const std::string &filename); + + // gets assembly file from stmt.viz.html file + std::string get_assembly_filename(const std::string &filename); + + // checks if there is a marker that matches the assembly line, and if so, adds the line + // number and node to map, signifying a match + void add_line_number(std::string &assembly_line, int line_number); + void add_line_number_for_loop(std::string &assembly_line, AssemblyInfoForLoop &marker, + int line_number); + void add_line_number_prod_cons(std::string &assembly_line, AssemblyInfoProdCons &marker, + int line_number); + + void visit(const ProducerConsumer *op) override; + void visit(const For *op) override; + + std::string print_node(const IRNode *node) const; +}; + +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/GetStmtHierarchy.cpp b/src/GetStmtHierarchy.cpp new file mode 100644 index 000000000000..02a181322e11 --- /dev/null +++ b/src/GetStmtHierarchy.cpp @@ -0,0 +1,728 @@ +#include "GetStmtHierarchy.h" + +namespace Halide { +namespace Internal { + +using std::ostringstream; +using std::string; + +StmtHierarchyInfo GetStmtHierarchy::get_hierarchy_html(const Expr &node) { + reset_variables(); + + int start_node = curr_node_ID; + html << start_tree(); + node.accept(this); + html << end_tree(); + int end_node = num_nodes; + + StmtHierarchyInfo info; + info.html = html.str(); + info.viz_num = viz_counter; + info.start_node = start_node; + info.end_node = end_node; + + return info; +} +StmtHierarchyInfo GetStmtHierarchy::get_hierarchy_html(const Stmt &node) { + reset_variables(); + + int start_node = curr_node_ID; + html << start_tree(); + node.accept(this); + html << end_tree(); + int end_node = num_nodes; + + StmtHierarchyInfo info; + info.html = html.str(); + info.viz_num = viz_counter; + info.start_node = start_node; + info.end_node = end_node; + + return info; +} + +StmtHierarchyInfo GetStmtHierarchy::get_else_hierarchy_html() { + reset_variables(); + + int start_node = curr_node_ID; + html << start_tree(); + html << node_without_children(nullptr, "else"); + html << end_tree(); + int end_node = num_nodes; + + StmtHierarchyInfo info; + info.html = html.str(); + info.viz_num = viz_counter; + info.start_node = start_node; + info.end_node = end_node; + + return info; +} + +void GetStmtHierarchy::update_num_nodes() { + num_nodes++; + curr_node_ID = num_nodes; +} + +string GetStmtHierarchy::get_node_class_name() { + ostringstream class_name; + if (curr_node_ID == start_node_id) { + class_name << "viz" << viz_counter << " startNode depth" << node_depth; + } else { + class_name << "viz" << viz_counter << " node" << curr_node_ID << "child depth" + << node_depth; + } + return class_name.str(); +} + +void GetStmtHierarchy::reset_variables() { + html.str(""); + num_nodes++; + curr_node_ID = num_nodes; + start_node_id = -1; + node_depth = 0; + start_node_id = num_nodes; + viz_counter++; +} + +string GetStmtHierarchy::start_tree() const { + ostringstream ss; + ss << "
"; + ss << "
"; + ss << "
    "; + return ss.str(); +} +string GetStmtHierarchy::end_tree() const { + ostringstream ss; + ss << "
"; + ss << "
"; + ss << "
"; + return ss.str(); +} + +string GetStmtHierarchy::generate_computation_cost_div(const IRNode *op) { + stmt_hierarchy_tooltip_count++; + + ostringstream ss; + string tooltip_text = ir_viz.generate_computation_cost_tooltip(op, ""); + + // tooltip span + ss << "" << tooltip_text + << ""; + + // color div + int computation_range = ir_viz.get_color_range(op, StmtCostModel::Compute); + string class_name = "computation-cost-div CostColor" + std::to_string(computation_range); + ss << "
"; + ss << "
"; + + return ss.str(); +} +string GetStmtHierarchy::generate_memory_cost_div(const IRNode *op) { + stmt_hierarchy_tooltip_count++; + + ostringstream ss; + string tooltip_text = ir_viz.generate_data_movement_cost_tooltip(op, ""); + + // tooltip span + ss << "" << tooltip_text + << ""; + + // color div + int data_movement_range = ir_viz.get_color_range(op, StmtCostModel::DataMovement); + string class_name = "memory-cost-div CostColor" + std::to_string(data_movement_range); + ss << "
" + << "
"; + + return ss.str(); +} + +string GetStmtHierarchy::node_without_children(const IRNode *op, const string &name) { + ostringstream ss; + + string class_name = get_node_class_name(); + ss << "
  • " + << ""; + + ss << "
    "; + ss << generate_computation_cost_div(op); + ss << generate_memory_cost_div(op); + + ss << "
    " << name << "
    " + << "
    " + << "
    " + << "
  • "; + + return ss.str(); +} +string GetStmtHierarchy::open_node(const IRNode *op, const string &name) { + ostringstream ss; + string class_name = get_node_class_name() + " children-node"; + + update_num_nodes(); + + ss << "
  • "; + ss << ""; + + ss << "
    "; + ss << generate_computation_cost_div(op); + ss << generate_memory_cost_div(op); + + ss << "
    " << name + << "" + << "
    " + << "
    " + << "
    "; + + node_depth++; + ss << "
      "; + + return ss.str(); +} +string GetStmtHierarchy::close_node() { + node_depth--; + ostringstream ss; + ss << "
    "; + ss << "
  • "; + return ss.str(); +} + +void GetStmtHierarchy::visit(const IntImm *op) { + html << node_without_children(op, std::to_string(op->value)); +} +void GetStmtHierarchy::visit(const UIntImm *op) { + html << node_without_children(op, std::to_string(op->value)); +} +void GetStmtHierarchy::visit(const FloatImm *op) { + html << node_without_children(op, std::to_string(op->value)); +} +void GetStmtHierarchy::visit(const StringImm *op) { + html << node_without_children(op, op->value); +} +void GetStmtHierarchy::visit(const Cast *op) { + ostringstream name; + name << op->type; + html << open_node(op, name.str()); + op->value.accept(this); + html << close_node(); +} +void GetStmtHierarchy::visit(const Reinterpret *op) { + ostringstream name; + name << "reinterpret "; + name << op->type; + html << open_node(op, name.str()); + op->value.accept(this); + html << close_node(); +} +void GetStmtHierarchy::visit(const Variable *op) { + html << node_without_children(op, op->name); +} + +void GetStmtHierarchy::visit_binary_op(const IRNode *op, const Expr &a, const Expr &b, + const string &name) { + html << open_node(op, name); + + int curr_node = curr_node_ID; + a.accept(this); + + curr_node_ID = curr_node; + b.accept(this); + + html << close_node(); +} + +void GetStmtHierarchy::visit(const Add *op) { + visit_binary_op(op, op->a, op->b, "+"); +} +void GetStmtHierarchy::visit(const Sub *op) { + visit_binary_op(op, op->a, op->b, "-"); +} +void GetStmtHierarchy::visit(const Mul *op) { + visit_binary_op(op, op->a, op->b, "*"); +} +void GetStmtHierarchy::visit(const Div *op) { + visit_binary_op(op, op->a, op->b, "/"); +} +void GetStmtHierarchy::visit(const Mod *op) { + visit_binary_op(op, op->a, op->b, "%"); +} +void GetStmtHierarchy::visit(const EQ *op) { + visit_binary_op(op, op->a, op->b, "=="); +} +void GetStmtHierarchy::visit(const NE *op) { + visit_binary_op(op, op->a, op->b, "!="); +} +void GetStmtHierarchy::visit(const LT *op) { + visit_binary_op(op, op->a, op->b, "<"); +} +void GetStmtHierarchy::visit(const LE *op) { + visit_binary_op(op, op->a, op->b, "<="); +} +void GetStmtHierarchy::visit(const GT *op) { + visit_binary_op(op, op->a, op->b, ">"); +} +void GetStmtHierarchy::visit(const GE *op) { + visit_binary_op(op, op->a, op->b, ">="); +} +void GetStmtHierarchy::visit(const And *op) { + visit_binary_op(op, op->a, op->b, "&&"); +} +void GetStmtHierarchy::visit(const Or *op) { + visit_binary_op(op, op->a, op->b, "||"); +} +void GetStmtHierarchy::visit(const Min *op) { + visit_binary_op(op, op->a, op->b, "min"); +} +void GetStmtHierarchy::visit(const Max *op) { + visit_binary_op(op, op->a, op->b, "max"); +} + +void GetStmtHierarchy::visit(const Not *op) { + html << open_node(op, "!"); + op->a.accept(this); + html << close_node(); +} +void GetStmtHierarchy::visit(const Select *op) { + html << open_node(op, "Select"); + + int curr_node = curr_node_ID; + op->condition.accept(this); + + curr_node_ID = curr_node; + op->true_value.accept(this); + + curr_node_ID = curr_node; + op->false_value.accept(this); + + html << close_node(); +} +void GetStmtHierarchy::visit(const Load *op) { + ostringstream index; + index << op->index; + html << node_without_children(op, op->name + "[" + index.str() + "]"); +} +void GetStmtHierarchy::visit(const Ramp *op) { + html << open_node(op, "Ramp"); + + int curr_node = curr_node_ID; + op->base.accept(this); + + curr_node_ID = curr_node; + op->stride.accept(this); + + curr_node_ID = curr_node; + Expr(op->lanes).accept(this); + + html << close_node(); +} +void GetStmtHierarchy::visit(const Broadcast *op) { + html << open_node(op, "x" + std::to_string(op->lanes)); + op->value.accept(this); + html << close_node(); +} +void GetStmtHierarchy::visit(const Call *op) { + html << open_node(op, op->name); + + int curr_node = curr_node_ID; + for (const auto &arg : op->args) { + curr_node_ID = curr_node; + arg.accept(this); + } + + html << close_node(); +} +void GetStmtHierarchy::visit(const Let *op) { + html << open_node(op, "Let"); + int curr_node = curr_node_ID; + + html << open_node(op->value.get(), "Let"); + html << node_without_children(nullptr, op->name); + op->value.accept(this); + html << close_node(); + + // "body" node + curr_node_ID = curr_node; + html << open_node(op->body.get(), "body"); + op->body.accept(this); + html << close_node(); + + html << close_node(); +} +void GetStmtHierarchy::visit(const LetStmt *op) { + html << open_node(op, "Let"); + + int curr_node = curr_node_ID; + html << node_without_children(nullptr, op->name); + + curr_node_ID = curr_node; + op->value.accept(this); + + html << close_node(); +} +void GetStmtHierarchy::visit(const AssertStmt *op) { + html << open_node(op, "Assert"); + + int curr_node = curr_node_ID; + op->condition.accept(this); + + curr_node_ID = curr_node; + op->message.accept(this); + html << close_node(); +} +void GetStmtHierarchy::visit(const ProducerConsumer *op) { + string node_name = op->is_producer ? "Produce" : "Consume"; + node_name += " " + op->name; + html << node_without_children(op, node_name); +} +void GetStmtHierarchy::visit(const For *op) { + html << open_node(op, "For"); + + int curr_node = curr_node_ID; + html << open_node(nullptr, "loop var"); + html << node_without_children(nullptr, op->name); + html << close_node(); + + curr_node_ID = curr_node; + html << open_node(op->min.get(), "min"); + op->min.accept(this); + html << close_node(); + + curr_node_ID = curr_node; + html << open_node(op->extent.get(), "extent"); + op->extent.accept(this); + html << close_node(); + + html << close_node(); +} +void GetStmtHierarchy::visit(const Store *op) { + html << open_node(op, "Store"); + + ostringstream index; + index << op->index; + html << node_without_children(op->index.get(), op->name + "[" + index.str() + "]"); + + op->value.accept(this); + html << close_node(); +} +void GetStmtHierarchy::visit(const Provide *op) { + html << open_node(op, "Provide"); + int curr_node0 = curr_node_ID; + + html << open_node(op, op->name); + int curr_node1 = curr_node_ID; + for (const auto &arg : op->args) { + curr_node_ID = curr_node1; + arg.accept(this); + } + html << close_node(); + + for (const auto &val : op->values) { + curr_node_ID = curr_node0; + val.accept(this); + } + html << close_node(); +} +void GetStmtHierarchy::visit(const Allocate *op) { + html << open_node(op, "allocate"); + + ostringstream index; + index << op->type; + + for (const auto &extent : op->extents) { + index << " * "; + index << extent; + } + + html << node_without_children(op, op->name + "[" + index.str() + "]"); + + ostringstream name; + if (!is_const_one(op->condition)) { + name << " if " << op->condition; + } + + if (op->new_expr.defined()) { + internal_assert(false) << "\n" + << "GetStmtHierarchy: Allocate " << op->name + << " `op->new_expr.defined()` is not supported yet.\n\n"; + } + if (!op->free_function.empty()) { + name << "custom_delete {" << op->free_function << "}"; + } + + if (!name.str().empty()) { + html << node_without_children(op, name.str()); + } + html << close_node(); +} +void GetStmtHierarchy::visit(const Free *op) { + html << open_node(op, "Free"); + html << node_without_children(op, op->name); + html << close_node(); +} +void GetStmtHierarchy::visit(const Realize *op) { + internal_assert(false) << "\n" + << "GetStmtHierarchy: Realize is not supported yet \n\n"; +} +void GetStmtHierarchy::visit(const Block *op) { + internal_assert(false) << "\n" + << "GetStmtHierarchy: Block is not supported and should never be visualized. \n\n"; +} +void GetStmtHierarchy::visit(const IfThenElse *op) { + html << open_node(op, "If"); + + html << open_node(op->condition.get(), "condition"); + op->condition.accept(this); + html << close_node(); + + // don't visualize else case because that will be visualized later as another IfThenElse block + // in StmtToViz.cpp + + html << close_node(); +} +void GetStmtHierarchy::visit(const Evaluate *op) { + op->value.accept(this); +} +void GetStmtHierarchy::visit(const Shuffle *op) { + if (op->is_concat()) { + html << open_node(op, "concat_vectors"); + + int curr_node = curr_node_ID; + for (const auto &e : op->vectors) { + curr_node_ID = curr_node; + e.accept(this); + } + html << close_node(); + } + + else if (op->is_interleave()) { + html << open_node(op, "interleave_vectors"); + + int curr_node = curr_node_ID; + for (const auto &e : op->vectors) { + curr_node_ID = curr_node; + e.accept(this); + } + html << close_node(); + } + + else if (op->is_extract_element()) { + std::vector args = op->vectors; + args.emplace_back(op->slice_begin()); + html << open_node(op, "extract_element"); + + int curr_node = curr_node_ID; + for (auto &e : args) { + curr_node_ID = curr_node; + e.accept(this); + } + html << close_node(); + } + + else if (op->is_slice()) { + std::vector args = op->vectors; + args.emplace_back(op->slice_begin()); + args.emplace_back(op->slice_stride()); + args.emplace_back(static_cast(op->indices.size())); + html << open_node(op, "slice_vectors"); + + int curr_node = curr_node_ID; + for (auto &e : args) { + curr_node_ID = curr_node; + e.accept(this); + } + html << close_node(); + } + + else { + std::vector args = op->vectors; + for (int i : op->indices) { + args.emplace_back(i); + } + html << open_node(op, "Shuffle"); + + int curr_node = curr_node_ID; + for (auto &e : args) { + curr_node_ID = curr_node; + e.accept(this); + } + html << close_node(); + } +} +void GetStmtHierarchy::visit(const VectorReduce *op) { + html << open_node(op, "vector_reduce"); + + int curr_node = curr_node_ID; + ostringstream op_op; + op_op << op->op; + html << node_without_children(nullptr, op_op.str()); + + curr_node_ID = curr_node; + op->value.accept(this); + + html << close_node(); +} +void GetStmtHierarchy::visit(const Prefetch *op) { + internal_assert(false) << "\n" + << "GetStmtHierarchy: Prefetch is not supported yet. \n\n"; +} +void GetStmtHierarchy::visit(const Fork *op) { + internal_assert(false) << "\n" + << "GetStmtHierarchy: Fork is not supported yet. \n\n"; +} +void GetStmtHierarchy::visit(const Acquire *op) { + html << open_node(op, "acquire"); + + int curr_node = curr_node_ID; + op->semaphore.accept(this); + + curr_node_ID = curr_node; + op->count.accept(this); + + html << close_node(); +} +void GetStmtHierarchy::visit(const Atomic *op) { + if (op->mutex_name.empty()) { + html << node_without_children(op, "atomic"); + } else { + html << open_node(op, "atomic"); + html << node_without_children(nullptr, op->mutex_name); + html << close_node(); + } +} + +string GetStmtHierarchy::generate_stmt_hierarchy_js() { + ostringstream stmt_hierarchy_js; + + stmt_hierarchy_js << R"( +// stmtHierarchy JS +for (let i = 1; i <= )" << stmt_hierarchy_tooltip_count << R"(; i++) { + const button = document.getElementById('stmtHierarchyButtonTooltip' + i); + const tooltip = document.getElementById('stmtHierarchyTooltip' + i); + button.addEventListener('mouseenter', () => { + showTooltip(button, tooltip); + }); + button.addEventListener('mouseleave', () => { + hideTooltip(tooltip); + }); + tooltip.addEventListener('focus', () => { + showTooltip(button, tooltip); + }); + tooltip.addEventListener('blur', () => { + hideTooltip(tooltip); + }); +})"; + + return stmt_hierarchy_js.str(); +} + +const char *GetStmtHierarchy::stmt_hierarchy_css = + R"( +/* StmtHierarchy CSS */ +.arrow { border: solid rgb(125,125,125); border-width: 0 2px 2px 0; display: +inline-block; padding: 3px; } +.down { transform: rotate(45deg); -webkit-transform: rotate(45deg); } +.up { transform: rotate(-135deg); -webkit-transform: rotate(-135deg); } +.stmtHierarchyButton {padding: 3px;} +.tf-custom-stmtHierarchy .tf-nc { border-radius: 5px; border: 1px solid; font-size: 12px; border-color: rgb(200, 200, 200);} +.tf-custom-stmtHierarchy .end-node { border-style: dashed; font-size: 12px; } +.tf-custom-stmtHierarchy .tf-nc:before, .tf-custom-stmtHierarchy .tf-nc:after { border-left-width: 1px; border-color: rgb(200, 200, 200);} +.tf-custom-stmtHierarchy li li:before { border-top-width: 1px; border-color: rgb(200, 200, 200);} +.tf-custom-stmtHierarchy { font-size: 12px; } +div.nodeContent { display: flex; } +div.nodeName { padding-left: 5px; } +)"; + +const char *GetStmtHierarchy::stmt_hierarchy_collapse_expand_JS = R"( +// collapse/expand js (stmt hierarchy) +var nodeExpanded = new Map(); +function collapseAllNodes(startNode, endNode) { + for (let i = startNode; i <= endNode; i++) { + collapseNodeChildren(i); + nodeExpanded.set(i, false); + if (document.getElementById('stmtHierarchyButton' + i) != null) { + document.getElementById('stmtHierarchyButton' + i).className = 'arrow down'; + } + } +} +function expandNodesUpToDepth(depth, vizNum) { + for (let i = 0; i < depth; i++) { + const depthChildren = document.getElementsByClassName('viz' + vizNum + ' depth' + i); + for (const child of depthChildren) { + child.style.display = ''; + if (child.className.includes('start')) { + continue; + } + let parentNodeID = child.className.split()[0]; + parentNodeID = parentNodeID.split('node')[1]; + parentNodeID = parentNodeID.split('child')[0]; + const parentNode = parseInt(parentNodeID); + nodeExpanded.set(parentNode, true); + if (document.getElementById('stmtHierarchyButton' + parentNodeID) != null) { + document.getElementById('stmtHierarchyButton' + parentNodeID).className = 'arrow up'; + } + const dotdotdot = document.getElementById('node' + parentNodeID + 'dotdotdot'); + if (dotdotdot != null) { + dotdotdot.remove(); + } + } + } +} +function handleClick(nodeNum) { + if (nodeExpanded.get(nodeNum)) { + collapseNodeChildren(nodeNum); + nodeExpanded.set(nodeNum, false); + } else { + expandNodeChildren(nodeNum); + nodeExpanded.set(nodeNum, true); + } +} +function collapseNodeChildren(nodeNum) { + const children = document.getElementsByClassName('node' + nodeNum + 'child'); + if (document.getElementById('stmtHierarchyButton' + nodeNum) != null) { + document.getElementById('stmtHierarchyButton' + nodeNum).className = 'arrow down'; + } + for (const child of children) { + child.style.display = 'none'; + } + const list = document.getElementById('list' + nodeNum); + const parentNode = document.getElementById('node' + nodeNum); + if (list != null && parentNode != null) { + const span = parentNode.children[0]; + list.appendChild(addDotDotDotChild(nodeNum)); + } +} +function expandNodeChildren(nodeNum) { + const children = document.getElementsByClassName('node' + nodeNum + 'child'); + if (document.getElementById('stmtHierarchyButton' + nodeNum) != null) { + document.getElementById('stmtHierarchyButton' + nodeNum).className = 'arrow up'; + } + for (const child of children) { + child.style.display = ''; + } + const dotdotdot = document.getElementById('node' + nodeNum + 'dotdotdot'); + if (dotdotdot != null) { + dotdotdot.remove(); + } +} +function addDotDotDotChild(nodeNum, colorCost) { + var liDotDotDot = document.createElement('li'); + liDotDotDot.id = 'node' + nodeNum + 'dotdotdot'; + const span ="..."; + liDotDotDot.innerHTML = span; + return liDotDotDot; +} +)"; + +} // namespace Internal +} // namespace Halide diff --git a/src/GetStmtHierarchy.h b/src/GetStmtHierarchy.h new file mode 100644 index 000000000000..1c2ea235f089 --- /dev/null +++ b/src/GetStmtHierarchy.h @@ -0,0 +1,130 @@ +#ifndef HALIDE_GET_STMT_HIERARCHY_H +#define HALIDE_GET_STMT_HIERARCHY_H + +#include "FindStmtCost.h" +#include "IROperator.h" +#include "IRVisitor.h" +#include "IRVisualization.h" + +namespace Halide { +namespace Internal { + +struct StmtHierarchyInfo { + std::string html; // html code for the node + int viz_num; // id for that visualization + int start_node; // start node for the visualization + int end_node; // end node for the visualization +}; + +class GetStmtHierarchy : public IRVisitor { + +public: + static const char *stmt_hierarchy_css; + static const char *stmt_hierarchy_collapse_expand_JS; + + GetStmtHierarchy(const FindStmtCost &find_stmt_cost_populated) + : find_stmt_cost(find_stmt_cost_populated), ir_viz(find_stmt_cost_populated), + curr_node_ID(0), num_nodes(0), viz_counter(0), stmt_hierarchy_tooltip_count(0) { + } + + // returns the generated hierarchy's html + StmtHierarchyInfo get_hierarchy_html(const Expr &node); + StmtHierarchyInfo get_hierarchy_html(const Stmt &node); + + // special case for else case (node with just "else") + StmtHierarchyInfo get_else_hierarchy_html(); + + // generates the JS that is needed to add the tooltips + std::string generate_stmt_hierarchy_js(); + +private: + std::ostringstream html; // html string + FindStmtCost find_stmt_cost; // used as input to IRVisualization + IRVisualization ir_viz; // used to generate the tooltip information and cost colors + + // for expanding/collapsing nodes + int curr_node_ID; // ID of the current node in traversal + int num_nodes; // total number of nodes (across all generated trees in the IR) + int start_node_id; // ID of the start node of the current tree + int node_depth; // depth of the current node in the tree + int viz_counter; // counter for the number of visualizations + int stmt_hierarchy_tooltip_count; // tooltip count + + // updates the curr_node_ID to be the next available node ID (num_nodes) + // and increases num_nodes by 1 + void update_num_nodes(); + + // returns the class name in format "node[parentID]child depth[depth]" + std::string get_node_class_name(); + + // resets all the variables to start a new tree + void reset_variables(); + + // starts and ends a tree within the html file + std::string start_tree() const; + std::string end_tree() const; + + // creating color divs with tooltips + std::string generate_computation_cost_div(const IRNode *op); + std::string generate_memory_cost_div(const IRNode *op); + + // opens and closes nodes, depending on number of children + std::string node_without_children(const IRNode *op, const std::string &name); + std::string open_node(const IRNode *op, const std::string &name); + std::string close_node(); + + void visit_binary_op(const IRNode *op, const Expr &a, const Expr &b, const std::string &name); + + void visit(const IntImm *op) override; + void visit(const UIntImm *op) override; + void visit(const FloatImm *op) override; + void visit(const StringImm *op) override; + void visit(const Cast *op) override; + void visit(const Reinterpret *) override; + void visit(const Variable *op) override; + void visit(const Add *op) override; + void visit(const Sub *op) override; + void visit(const Mul *op) override; + void visit(const Div *op) override; + void visit(const Mod *op) override; + void visit(const Min *op) override; + void visit(const Max *op) override; + void visit(const EQ *op) override; + void visit(const NE *op) override; + void visit(const LT *op) override; + void visit(const LE *op) override; + void visit(const GT *op) override; + void visit(const GE *op) override; + void visit(const And *op) override; + void visit(const Or *op) override; + void visit(const Not *op) override; + void visit(const Select *op) override; + void visit(const Load *op) override; + void visit(const Ramp *op) override; + void visit(const Broadcast *op) override; + void visit(const Call *op) override; + void visit(const Let *op) override; + void visit(const Shuffle *op) override; + void visit(const VectorReduce *op) override; + void visit(const LetStmt *op) override; + void visit(const AssertStmt *op) override; + void visit(const ProducerConsumer *op) override; + void visit(const For *op) override; + void visit(const Acquire *op) override; + void visit(const Store *op) override; + void visit(const Provide *op) override; + void visit(const Allocate *op) override; + void visit(const Free *op) override; + void visit(const Realize *op) override; + void visit(const Prefetch *op) override; + void visit(const Block *op) override; + void visit(const Fork *op) override; + void visit(const IfThenElse *op) override; + void visit(const Evaluate *op) override; + void visit(const Atomic *op) override; +}; + +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/IRVisualization.cpp b/src/IRVisualization.cpp new file mode 100644 index 000000000000..dbc5756f2765 --- /dev/null +++ b/src/IRVisualization.cpp @@ -0,0 +1,1613 @@ +#include "IRVisualization.h" + +#include "IROperator.h" +#include "Module.h" +#include + +namespace Halide { +namespace Internal { + +using std::ostringstream; +using std::pair; +using std::string; +using std::vector; + +constexpr int NUMBER_COST_COLORS = 20; + +/* + * GetReadWrite class + */ + +void GetReadWrite::generate_sizes(const Module &m) { + traverse(m); +} + +StmtSize GetReadWrite::get_size(const IRNode *node) const { + auto it = stmt_sizes.find(node); + + // errors if node is not found + internal_assert(it != stmt_sizes.end()) << "\n\nGetReadWrite::get_size - Node not found in stmt_sizes: " + << print_node(node) << "\n\n"; + + return (it != stmt_sizes.end()) ? it->second : StmtSize(); +} + +string GetReadWrite::string_span(const string &var_name) const { + return "" + var_name + ""; +} +string GetReadWrite::int_span(int64_t int_val) const { + return "" + std::to_string(int_val) + ""; +} + +void GetReadWrite::traverse(const Module &m) { + + // traverse all functions + for (const auto &f : m.functions()) { + function_names.push_back(f.name); + f.body.accept(this); + } +} + +string GetReadWrite::get_simplified_string(const string &a, const string &b, const string &op) { + if (op == "+") { + return a + " + " + b; + } + + else if (op == "*") { + // check if b contains "+" + if (b.find('+') != string::npos) { + return a + "*(" + b + ")"; + } else { + return a + "*" + b; + } + } + + else { + internal_assert(false) << "\n" + << "GetReadWrite::get_simplified_string - Unsupported operator: " << op + << "\n"; + return ""; + } +} + +void GetReadWrite::set_write_size(const IRNode *node, const string &write_var, string write_size) { + auto it = stmt_sizes.find(node); + if (it == stmt_sizes.end()) { + stmt_sizes[node] = StmtSize(); + } + stmt_sizes[node].writes[write_var] = std::move(write_size); +} +void GetReadWrite::set_read_size(const IRNode *node, const string &read_var, string read_size) { + auto it = stmt_sizes.find(node); + if (it == stmt_sizes.end()) { + stmt_sizes[node] = StmtSize(); + } + stmt_sizes[node].reads[read_var] = std::move(read_size); +} + +void GetReadWrite::visit(const Store *op) { + + uint16_t lanes = op->index.type().lanes(); + + set_write_size(op, op->name, int_span(lanes)); + + // empty curr_load_values + curr_load_values.clear(); + op->value.accept(this); + + // set consume (for now, read values) + for (const auto &load_var : curr_load_values) { + set_read_size(op, load_var.first, int_span(load_var.second)); + } +} +void GetReadWrite::add_load_value(const string &name, const int lanes) { + auto it = curr_load_values.find(name); + if (it == curr_load_values.end()) { + curr_load_values[name] = lanes; + } else { + curr_load_values[name] += lanes; + } +} +void GetReadWrite::visit(const Load *op) { + + int lanes = int(op->type.lanes()); + + add_load_value(op->name, lanes); +} + +/* + * IRVisualization class + */ +string IRVisualization::generate_ir_visualization_html(const Module &m) { + get_read_write.generate_sizes(m); + + html.str(""); + num_of_nodes = 0; + start_module_traversal(m); + + return html.str(); +} + +string IRVisualization::generate_computation_cost_tooltip(const IRNode *op, const string &extra_note) { + int depth, computation_cost_exclusive, computation_cost_inclusive; + + if (op == nullptr) { + depth = 0; + computation_cost_exclusive = StmtCost::NormalNodeCC; + } else { + depth = find_stmt_cost.get_depth(op); + } + + computation_cost_exclusive = get_cost_percentage(op, StmtCostModel::Compute); + computation_cost_inclusive = get_cost_percentage(op, StmtCostModel::ComputeInclusive); + + // build up values of the table that will be displayed + vector> table_rows; + table_rows.emplace_back("Loop Depth", std::to_string(depth)); + + if (computation_cost_exclusive == computation_cost_inclusive) { + table_rows.emplace_back("Computation Cost", std::to_string(computation_cost_exclusive) + "%"); + } else { + table_rows.emplace_back("Computation Cost (Exclusive)", std::to_string(computation_cost_exclusive) + "%"); + table_rows.emplace_back("Computation Cost (Inclusive)", std::to_string(computation_cost_inclusive) + "%"); + } + + return tooltip_table(table_rows, extra_note); +} +string IRVisualization::generate_data_movement_cost_tooltip(const IRNode *op, const string &extra_note) { + int depth, data_movement_cost_exclusive, data_movement_cost_inclusive; + + if (op == nullptr) { + depth = 0; + } else { + depth = find_stmt_cost.get_depth(op); + } + + data_movement_cost_exclusive = get_cost_percentage(op, StmtCostModel::DataMovement); + data_movement_cost_inclusive = get_cost_percentage(op, StmtCostModel::DataMovementInclusive); + + // build up values of the table that will be displayed + vector> table_rows; + table_rows.emplace_back("Loop Depth", std::to_string(depth)); + + if (data_movement_cost_exclusive == data_movement_cost_inclusive) { + table_rows.emplace_back("Data Movement Cost", std::to_string(data_movement_cost_exclusive) + "%"); + } else { + table_rows.emplace_back("Data Movement Cost (Exclusive)", std::to_string(data_movement_cost_exclusive) + "%"); + table_rows.emplace_back("Data Movement Cost (Inclusive)", std::to_string(data_movement_cost_inclusive) + "%"); + } + + return tooltip_table(table_rows, extra_note); +} + +int IRVisualization::get_color_range(const IRNode *op, StmtCostModel cost_model) const { + if (op == nullptr) { + return 0; + } + + // divide max cost by NUMBER_COST_COLORS and round up to get range size + int range_size = (find_stmt_cost.get_max_cost(cost_model) / NUMBER_COST_COLORS) + 1; + int cost = find_stmt_cost.get_cost(op, cost_model); + + return cost / range_size; +} + +int IRVisualization::get_combined_color_range(const IRNode *op, bool is_compcost) const { + if (op == nullptr) { + return 0; + } + + // divide max cost by NUMBER_COST_COLORS and round up to get range size + int cost, max_cost; + if (is_compcost) { + cost = find_stmt_cost.get_cost(op, StmtCostModel::Compute); + max_cost = find_stmt_cost.get_max_cost(StmtCostModel::ComputeInclusive); + } else { + cost = find_stmt_cost.get_cost(op, StmtCostModel::DataMovement); + max_cost = find_stmt_cost.get_max_cost(StmtCostModel::DataMovementInclusive); + } + + int range_size = (max_cost / NUMBER_COST_COLORS) + 1; + int range = cost / range_size; + + if (range >= NUMBER_COST_COLORS) { + range = NUMBER_COST_COLORS - 1; + } + + return range; +} + +void IRVisualization::start_module_traversal(const Module &m) { + + // print main function first + for (const auto &f : m.functions()) { + if (f.name == m.name()) { + visit_function(f); + } + } + + // print the rest of the functions + for (const auto &f : m.functions()) { + if (f.name != m.name()) { + visit_function(f); + } + } +} + +string IRVisualization::open_box_div(const string &class_name, const IRNode *op) { + ostringstream ss; + + ss << "
    "; + + if (op != nullptr) { + ss << generate_computation_cost_div(op); + ss << generate_memory_cost_div(op); + } + + ss << open_content_div(); + return ss.str(); +} +string IRVisualization::close_box_div() const { + ostringstream ss; + ss << close_div(); // body div (opened at end of each open_header_...() instance) + ss << close_div(); // content div + ss << close_div(); // main box div + return ss.str(); +} +string IRVisualization::open_function_box_div() const { + return "
    "; +} +string IRVisualization::close_function_box_div() const { + ostringstream ss; + ss << close_div(); // content div + ss << close_div(); // main box div + return ss.str(); +} +string IRVisualization::open_header_div() const { + return "
    "; +} +string IRVisualization::open_box_header_title_div() const { + return "
    "; +} +string IRVisualization::open_box_header_table_div() const { + return "
    "; +} +string IRVisualization::open_store_div() const { + return "
    "; +} +string IRVisualization::open_body_div() const { + ostringstream ss; + ss << "
    "; + return ss.str(); +} +string IRVisualization::close_div() const { + return "
    "; +} + +string IRVisualization::open_header(const string &header, const string &anchor_name, + vector> info_tooltip_table) { + ostringstream ss; + ss << open_header_div(); + + num_of_nodes++; + + // to make buttons next to each other + ss << "
    "; + + // collapse/expand buttons + ss << "
    "; + ss << ""; + ss << ""; + ss << "
    "; + + // see code button + ss << see_code_button_div(anchor_name); + + // info button + if (!info_tooltip_table.empty()) { + ss << "
    "; + ss << info_button_with_tooltip(tooltip_table(info_tooltip_table), + "iconButton dottedIconButton"); + ss << "
    "; + } + + ss << "
    "; // to make buttons next to each other + + ss << open_box_header_title_div(); + + ss << ""; + ss << header; + ss << ""; + + ss << close_div(); + + return ss.str(); +} +string IRVisualization::close_header() const { + return close_div(); +} +string IRVisualization::div_header(const string &header, StmtSize *size, const string &anchor_name, + vector> info_tooltip_table = {}) { + ostringstream ss; + + ss << open_header(header, anchor_name, std::move(info_tooltip_table)); + ss << close_header(); + + // add producer consumer size if size is provided + if (size != nullptr) { + ss << open_box_header_table_div(); + ss << read_write_table(*size); + ss << close_div(); + } + + // open body + ss << open_body_div(); + + return ss.str(); +} +string IRVisualization::function_div_header(const string &function_name, const string &anchor_name) const { + ostringstream ss; + + ss << "
    "; + + ss << ""; + ss << ""; + ss << "

    Func: " << function_name << "

    "; + ss << "
    "; + ss << "
    "; + + // see code button + ss << ""; + + ss << "
    "; + + return ss.str(); +} +vector IRVisualization::get_allocation_sizes(const Allocate *op) const { + vector sizes; + + ostringstream type; + type << "" << op->type << ""; + sizes.push_back(type.str()); + + for (const auto &extent : op->extents) { + ostringstream ss; + if (extent.as()) { + ss << "" << extent << ""; + } else { + ss << "" << extent << ""; + } + + sizes.push_back(ss.str()); + } + + internal_assert(sizes.size() == op->extents.size() + 1); + + return sizes; +} +string IRVisualization::allocate_div_header(const Allocate *op, const string &header, + const string &anchor_name, + vector> &info_tooltip_table) { + ostringstream ss; + + ss << open_header(header, anchor_name, info_tooltip_table); + ss << close_header(); + + vector allocation_sizes = get_allocation_sizes(op); + ss << open_box_header_table_div(); + ss << allocate_table(allocation_sizes); + ss << close_div(); + + // open body + ss << open_body_div(); + + return ss.str(); +} +string IRVisualization::for_loop_div_header(const For *op, const string &header, + const string &anchor_name) { + ostringstream ss; + + ss << open_header(header, anchor_name, {}); + ss << close_header(); + + string loopSize = get_loop_iterator(op); + ss << open_box_header_table_div(); + ss << for_loop_table(loopSize); + ss << close_div(); + + // open body + ss << open_body_div(); + + return ss.str(); +} + +string IRVisualization::if_tree(const IRNode *op, const string &header, const string &anchor_name) { + ostringstream ss; + + ss << "
  • "; + ss << ""; + + ss << open_box_div("IfBox", op); + ss << div_header(header, nullptr, anchor_name); + + return ss.str(); +} +string IRVisualization::close_if_tree() const { + ostringstream ss; + ss << close_box_div(); + ss << ""; + ss << "
  • "; + return ss.str(); +} + +string IRVisualization::read_write_table(StmtSize &size) const { + ostringstream read_write_table_ss; + + // open table + read_write_table_ss << ""; + + // Prod | Cons + read_write_table_ss << ""; + + read_write_table_ss << ""; + + read_write_table_ss << ""; + + read_write_table_ss << ""; + + // produces and consumes are empty + internal_assert(!size.empty()) << "\n\n" + << "IRVisualization::read_write_table - size is empty" + << "\n"; + + // produces and consumes aren't empty + if (!size.empty()) { + vector rows; + + // fill in producer variables + for (const auto &produce_var : size.writes) { + string ss; + ss += ""; + + ss += ""; + + rows.push_back(ss); + } + + // fill in consumer variables + unsigned long row_num = 0; + for (const auto &consume_var : size.reads) { + string ss; + ss += ""; + + ss += ""; + + if (row_num < rows.size()) { + rows[row_num] += ss; + } else { + // pad row with empty cells for produce + string s_empty; + s_empty += ""; + + rows.push_back(s_empty + ss); + } + row_num++; + } + + // pad row with empty calls for consume + row_num = size.reads.size(); + while (row_num < size.writes.size()) { + string s_empty; + s_empty += ""; + s_empty += ""; + + rows[row_num] += s_empty; + row_num++; + } + + // add rows to read_write_table_ss + for (const auto &row : rows) { + read_write_table_ss << ""; + read_write_table_ss << row; + read_write_table_ss << ""; + } + } + + // close table + read_write_table_ss << "
    "; + read_write_table_ss << "Written"; + read_write_table_ss << ""; + read_write_table_ss << "Read"; + read_write_table_ss << "
    "; + ss += produce_var.first + ": "; + ss += ""; + ss += produce_var.second; + ss += ""; + ss += consume_var.first + ": "; + ss += ""; + ss += consume_var.second; + ss += ""; + s_empty += ""; + s_empty += ""; + s_empty += "
    "; + + return read_write_table_ss.str(); +} +string IRVisualization::allocate_table(vector &allocation_sizes) const { + ostringstream allocate_table_ss; + + // open table + allocate_table_ss << ""; + + // open header and data rows + ostringstream header; + ostringstream data; + + header << ""; + data << ""; + + // iterate through all allocation sizes and add them to the header and data rows + for (unsigned long i = 0; i < allocation_sizes.size(); i++) { + if (i == 0) { + header << ""; + + data << ""; + } else { + if (i < allocation_sizes.size() - 1) { + header << ""; + } + } + + // close header and data rows + header << ""; + data << ""; + + // add header and data rows to allocate_table_ss + allocate_table_ss << header.str(); + allocate_table_ss << data.str(); + + // close table + allocate_table_ss << "
    "; + header << "Type"; + header << ""; + data << allocation_sizes[0]; + data << ""; + data << ""; + } else { + header << ""; + data << ""; + } + header << "Dim-" + std::to_string(i); + header << ""; + + data << allocation_sizes[i]; + data << "
    "; + + return allocate_table_ss.str(); +} +string IRVisualization::for_loop_table(const string &loop_size) const { + ostringstream for_loop_table_ss; + + // open table + for_loop_table_ss << ""; + + // Loop Size + for_loop_table_ss << ""; + + for_loop_table_ss << ""; + + for_loop_table_ss << ""; + + for_loop_table_ss << ""; + + // loop size + for_loop_table_ss << ""; + + for_loop_table_ss << ""; + + // close table + for_loop_table_ss << "
    "; + for_loop_table_ss << "Loop Span"; + for_loop_table_ss << "
    "; + for_loop_table_ss << loop_size; + for_loop_table_ss << "
    "; + + return for_loop_table_ss.str(); +} + +string IRVisualization::see_code_button_div(const string &anchor_name, bool put_div) const { + ostringstream ss; + if (put_div) { + ss << "
    "; + } + ss << ""; + if (put_div) { + ss << "
    "; + } + return ss.str(); +} + +string IRVisualization::info_button_with_tooltip(const string &tooltip_text, const string &button_class_name, + const string &tooltip_class_name) { + ostringstream ss; + + // infoButton + ir_viz_tooltip_count++; + ss << ""; + + // tooltip span + ss << ""; + ss << tooltip_text; + ss << ""; + + return ss.str(); +} + +string IRVisualization::generate_computation_cost_div(const IRNode *op) { + ostringstream ss; + + // skip if it's a store + if (op->node_type == IRNodeType::Store) { + return ""; + } + + ir_viz_tooltip_count++; + + string tooltip_text = generate_computation_cost_tooltip(op, ""); + + // tooltip span + ss << ""; + ss << tooltip_text; + ss << ""; + + int computation_range = get_color_range(op, StmtCostModel::ComputeInclusive); + string class_name = "computation-cost-div CostColor" + std::to_string(computation_range); + ss << "
    "; + + ss << close_div(); + + return ss.str(); +} +string IRVisualization::generate_memory_cost_div(const IRNode *op) { + ostringstream ss; + + // skip if it's a store + if (op->node_type == IRNodeType::Store) { + return ""; + } + + ir_viz_tooltip_count++; + + string tooltip_text = generate_data_movement_cost_tooltip(op, ""); + + // tooltip span + ss << ""; + ss << tooltip_text; + ss << ""; + + int data_movement_range = get_color_range(op, StmtCostModel::DataMovementInclusive); + string class_name = "memory-cost-div CostColor" + std::to_string(data_movement_range); + ss << "
    "; + + ss << close_div(); + + return ss.str(); +} +string IRVisualization::open_content_div() const { + return "
    "; +} + +int IRVisualization::get_cost_percentage(const IRNode *node, StmtCostModel cost_model) const { + int cost; + if (node == nullptr) { + cost = StmtCost::NormalNodeCC; + } else { + cost = find_stmt_cost.get_cost(node, cost_model); + } + + int total_cost; + + switch (cost_model) { + case StmtCostModel::Compute: + total_cost = find_stmt_cost.get_max_cost(StmtCostModel::ComputeInclusive); + break; + case StmtCostModel::DataMovement: + total_cost = find_stmt_cost.get_max_cost(StmtCostModel::DataMovementInclusive); + break; + default: + total_cost = find_stmt_cost.get_max_cost(cost_model); + } + + return (int)((float)cost / (float)total_cost * 100); +} + +string IRVisualization::tooltip_table(vector> &table, const string &extra_note) { + ostringstream s; + s << ""; + for (auto &row : table) { + s << ""; + s << ""; + s << ""; + s << ""; + } + s << "
    " << row.first << " " << row.second << "
    "; + + if (!extra_note.empty()) { + s << "" << extra_note << ""; + } + return s.str(); +} + +string IRVisualization::IRVisualization::color_button(int color_range) { + ostringstream ss; + + ir_viz_tooltip_count++; + ss << ""; + + return ss.str(); +} +string IRVisualization::computation_div(const IRNode *op) { + // want exclusive cost (so that the colors match up with exclusive costs) + int computation_range = get_color_range(op, StmtCostModel::Compute); + + ostringstream ss; + ss << color_button(computation_range); + + string tooltip_text = generate_computation_cost_tooltip(op, ""); + + // tooltip span + ss << ""; + ss << tooltip_text; + ss << ""; + + return ss.str(); +} +string IRVisualization::data_movement_div(const IRNode *op) { + // want exclusive cost (so that the colors match up with exclusive costs) + int data_movement_range = get_color_range(op, StmtCostModel::DataMovement); + + ostringstream ss; + ss << color_button(data_movement_range); + + string tooltip_text = generate_data_movement_cost_tooltip(op, ""); + + // tooltip span + ss << ""; + ss << tooltip_text; + ss << ""; + + return ss.str(); +} +string IRVisualization::cost_colors(const IRNode *op) { + ostringstream ss; + ss << computation_div(op); + ss << data_movement_div(op); + return ss.str(); +} + +void IRVisualization::visit_function(const LoweredFunc &func) { + html << open_function_box_div(); + + function_count++; + string anchor_name = "loweredFunc" + std::to_string(function_count); + + html << function_div_header(func.name, anchor_name); + + html << "
    "; + func.body.accept(this); + html << "
    "; + + html << close_function_box_div(); +} +void IRVisualization::visit(const Variable *op) { + // if op->name starts with "::", remove "::" + if (op->name.size() < 2) { + return; + } + string var_name = op->name; + if (var_name[0] == ':' && var_name[1] == ':') { + var_name = var_name.substr(2); + } + + // see if var_name is in get_read_write.function_names + if (std::count(get_read_write.function_names.begin(), get_read_write.function_names.end(), + var_name)) { + + html << "
    "; + + html << "Function Call"; + html << ""; + + html << "
    "; + } +} +void IRVisualization::visit(const ProducerConsumer *op) { + html << open_box_div("ProducerConsumerBox", op); + + producer_consumer_count++; + string anchor_name = "producerConsumer" + std::to_string(producer_consumer_count); + + string header = (op->is_producer ? "Produce" : "Consume"); + header += " " + op->name; + + html << div_header(header, nullptr, anchor_name); + + op->body.accept(this); + + html << close_box_div(); +} +string IRVisualization::get_loop_iterator_binary(const IRNodeType &type, const Expr &a, + const Expr &b) const { + ostringstream extent_name; + extent_name << "("; + + // deal with a + if (a.node_type() == IRNodeType::IntImm) { + int64_t extent_value = a.as()->value; + extent_name << get_read_write.int_span(extent_value); + } else if (a.node_type() == IRNodeType::Variable) { + extent_name << get_read_write.string_span(a.as()->name); + } else { + extent_name << a; + } + + // operator + if (type == IRNodeType::Add) { + extent_name << " + "; + } else if (type == IRNodeType::Sub) { + extent_name << " - "; + } else if (type == IRNodeType::Mul) { + extent_name << " * "; + } else if (type == IRNodeType::Div) { + extent_name << " / "; + } else if (type == IRNodeType::Mod) { + extent_name << " % "; + } else { + internal_assert(false) << "Unknown IRNodeType: \n"; + } + + // deal with b + if (b.node_type() == IRNodeType::IntImm) { + int64_t extent_value = b.as()->value; + extent_name << get_read_write.int_span(extent_value); + } else if (b.node_type() == IRNodeType::Variable) { + extent_name << get_read_write.string_span(b.as()->name); + } else { + extent_name << b; + } + + extent_name << ")"; + + return extent_name.str(); +} +string IRVisualization::get_loop_iterator(const For *op) const { + Expr min = op->min; + Expr extent = op->extent; + + string loop_iterator; + + // if min is IntImm + if (min.node_type() == IRNodeType::IntImm) { + int64_t min_value = min.as()->value; + + // extent is IntImm + if (extent.node_type() == IRNodeType::IntImm) { + int64_t extent_value = extent.as()->value; + uint16_t range = uint16_t(extent_value - min_value); + loop_iterator = get_read_write.int_span(range); + } + + // extent is Variable + else if (extent.node_type() == IRNodeType::Variable) { + // int64_t min_value = min.as()->value; + string min_name = get_read_write.int_span(min_value); + string extent_name = get_read_write.string_span(extent.as()->name); + + if (min_value == 0) { + loop_iterator = extent_name; + } else { + loop_iterator = "(" + extent_name + " - " + min_name + ")"; + } + } + + // extent is binary op + else if (extent.node_type() == IRNodeType::Add) { + string min_name = get_read_write.int_span(min_value); + string extent_name = + get_loop_iterator_binary(IRNodeType::Add, extent.as()->a, extent.as()->b); + + if (min_value == 0) { + loop_iterator = extent_name; + } else { + loop_iterator = "(" + extent_name + " - " + min_name + ")"; + } + } else if (extent.node_type() == IRNodeType::Sub) { + string min_name = get_read_write.int_span(min_value); + string extent_name = + get_loop_iterator_binary(IRNodeType::Sub, extent.as()->a, extent.as()->b); + + if (min_value == 0) { + loop_iterator = extent_name; + } else { + loop_iterator = "(" + extent_name + " - " + min_name + ")"; + } + } else if (extent.node_type() == IRNodeType::Mul) { + string min_name = get_read_write.int_span(min_value); + string extent_name = + get_loop_iterator_binary(IRNodeType::Mul, extent.as()->a, extent.as()->b); + + if (min_value == 0) { + loop_iterator = extent_name; + } else { + loop_iterator = "(" + extent_name + " - " + min_name + ")"; + } + } else if (extent.node_type() == IRNodeType::Div) { + string min_name = get_read_write.int_span(min_value); + string extent_name = + get_loop_iterator_binary(IRNodeType::Div, extent.as
    ()->a, extent.as
    ()->b); + + if (min_value == 0) { + loop_iterator = extent_name; + } else { + loop_iterator = "(" + extent_name + " - " + min_name + ")"; + } + } else if (extent.node_type() == IRNodeType::Mod) { + string min_name = get_read_write.int_span(min_value); + string extent_name = + get_loop_iterator_binary(IRNodeType::Mod, extent.as()->a, extent.as()->b); + + if (min_value == 0) { + loop_iterator = extent_name; + } else { + loop_iterator = "(" + extent_name + " - " + min_name + ")"; + } + } + + // extent is something else + else { + ostringstream loop_it; + if (min_value == 0) { + loop_it << op->extent; + } else { + loop_it << "(" << op->extent << ") - (" << op->min << ")"; + } + loop_iterator = loop_it.str(); + } + } + + // min is not an IntImm + else { + ostringstream loop_it; + loop_it << "(" << op->extent << ") - (" << op->min << ")"; + loop_iterator = loop_it.str(); + } + + return loop_iterator; +} +void IRVisualization::visit(const For *op) { + html << open_box_div("ForBox", op); + + for_count++; + string anchor_name = "for" + std::to_string(for_count); + + string header = "For (" + op->name + ")"; + + html << for_loop_div_header(op, header, anchor_name); + + op->body.accept(this); + + html << close_box_div(); +} +void IRVisualization::visit(const IfThenElse *op) { + // open main if tree + html << "
    "; + html << "
      "; + html << "
    • "; + html << "If"; + html << ""; + html << "
        "; + + string if_header; + if_header += "if "; + + // anchor name + if_count++; + string anchor_name = "if" + std::to_string(if_count); + + while (true) { + ostringstream condition; + condition << op->condition; + + string condition_string = condition.str(); + // make condition smaller if it's too big + constexpr size_t max_condition_legth = 25; + if (condition_string.size() > max_condition_legth) { + condition.str(""); + condition << "(..."; + condition << info_button_with_tooltip("condition:
        " + condition_string, "", + "conditionTooltip"); + condition << ")"; + } + + if_header += condition.str(); + + html << if_tree(op, if_header, anchor_name); + + // then body + op->then_case.accept(this); + + html << close_if_tree(); + + // if there is no else case, we are done + if (!op->else_case.defined()) { + break; + } + + // if else case is another ifthenelse, recurse and reset op to else case + if (const IfThenElse *nested_if = op->else_case.as()) { + op = nested_if; + if_header = ""; + if_header += "else if "; + + // anchor name + if_count++; + anchor_name = "if" + std::to_string(if_count); + + } + + // if else case is not another ifthenelse + else { + + string else_header; + else_header += "else "; + + // anchor name + if_count++; + anchor_name = "if" + std::to_string(if_count); + + html << if_tree(op->else_case.get(), else_header, anchor_name); + + op->else_case.accept(this); + + html << close_if_tree(); + break; + } + } + + // close main if tree + html << "
      "; + html << "
    • "; + html << "
    "; + html << "
    "; +} +void IRVisualization::visit(const Store *op) { + StmtSize size = get_read_write.get_size(op); + + store_count++; + string anchor_name = "store" + std::to_string(store_count); + + string header = "Store " + op->name; + + vector> table_rows; + table_rows.emplace_back("Vector Size", std::to_string(op->index.type().lanes())); + table_rows.emplace_back("Bit Size", std::to_string(op->index.type().bits())); + + html << open_box_div("StoreBox", op); + + html << div_header(header, &size, anchor_name, table_rows); + + op->value.accept(this); + + html << close_box_div(); +} +void IRVisualization::visit(const Load *op) { + string header; + vector> table_rows; + + if (op->type.is_scalar()) { + header = "[Scalar] "; + } + + else if (op->type.is_vector()) { + if (op->index.node_type() == IRNodeType::Ramp) { + const Ramp *ramp = op->index.as(); + + table_rows.emplace_back("Ramp lanes", std::to_string(ramp->lanes)); + ostringstream ramp_stride; + ramp_stride << ramp->stride; + table_rows.emplace_back("Ramp stride", ramp_stride.str()); + + if (ramp->stride.node_type() == IRNodeType::IntImm) { + int64_t stride = ramp->stride.as()->value; + if (stride == 1) { + header = "[Dense, Vector] "; + } else { + header = "[Strided, Vector] "; + } + } else { + header = "[Dense, Vector] "; + } + } else { + header = "[Dense, Vector] "; + } + } + + else { + internal_assert(false) << "\n\nUnsupported type for Load: " << op->type << "\n\n"; + } + + header += "Load " + op->name + ""; + + table_rows.emplace_back("Bit Size", std::to_string(op->index.type().bits())); + table_rows.emplace_back("Vector Size", std::to_string(op->index.type().lanes())); + + if (op->param.defined()) { + table_rows.emplace_back("Parameter", op->param.name()); + } + + header += info_button_with_tooltip(tooltip_table(table_rows), ""); + + html << open_store_div(); + html << cost_colors(op); + html << header; + html << close_div(); +} +string IRVisualization::get_memory_type(MemoryType mem_type) const { + if (mem_type == MemoryType::Auto) { + return "Auto"; + } else if (mem_type == MemoryType::Heap) { + return "Heap"; + } else if (mem_type == MemoryType::Stack) { + return "Stack"; + } else if (mem_type == MemoryType::Register) { + return "Register"; + } else if (mem_type == MemoryType::GPUShared) { + return "GPUShared"; + } else if (mem_type == MemoryType::GPUTexture) { + return "GPUTexture"; + } else if (mem_type == MemoryType::LockedCache) { + return "LockedCache"; + } else if (mem_type == MemoryType::VTCM) { + return "VTCM"; + } else if (mem_type == MemoryType::AMXTile) { + return "AMXTile"; + } else { + internal_assert(false) << "\n\n" + << "Unknown memory type" + << "\n"; + return "Unknown Memory Type"; + } +} +void IRVisualization::visit(const Allocate *op) { + html << open_box_div("AllocateBox", op); + + allocate_count++; + string anchor_name = "allocate" + std::to_string(allocate_count); + + string header = "Allocate " + op->name; + + vector> table_rows; + table_rows.emplace_back("Memory Type", get_memory_type(op->memory_type)); + + if (!is_const_one(op->condition)) { + ostringstream condition_string; + condition_string << op->condition; + table_rows.emplace_back("Condition", condition_string.str()); + } + if (op->new_expr.defined()) { + ostringstream new_expr_string; + new_expr_string << op->new_expr; + table_rows.emplace_back("New Expr", new_expr_string.str()); + } + if (!op->free_function.empty()) { + + ostringstream free_func_string; + free_func_string << op->free_function; + table_rows.emplace_back("Free Function", free_func_string.str()); + } + + table_rows.emplace_back("Bit Size", std::to_string(op->type.bits())); + table_rows.emplace_back("Vector Size", std::to_string(op->type.lanes())); + + html << allocate_div_header(op, header, anchor_name, table_rows); + + op->body.accept(this); + + html << close_box_div(); +} + +string IRVisualization::generate_ir_visualization_js() { + ostringstream ir_viz_js; + + ir_viz_js << "\n// irViz JS\n" + << "for (let i = 1; i <= " << ir_viz_tooltip_count << "; i++) { \n" + << " const button = document.getElementById('irVizButton' + i); \n" + << " const tooltip = document.getElementById('irVizTooltip' + i); \n" + << " button.addEventListener('mouseenter', () => { \n" + << " showTooltip(button, tooltip); \n" + << " }); \n" + << " button.addEventListener('mouseleave', () => { \n" + << " hideTooltip(tooltip); \n" + << " } \n" + << " ); \n" + << " tooltip.addEventListener('focus', () => { \n" + << " showTooltip(button, tooltip); \n" + << " } \n" + << " ); \n" + << " tooltip.addEventListener('blur', () => { \n" + << " hideTooltip(tooltip); \n" + << " } \n" + << " ); \n" + << "} \n" + << "function toggleCollapse(id) {\n " + << " var buttonShow = document.getElementById('irViz' + id + '-show');\n" + << " var buttonHide = document.getElementById('irViz' + id + '-hide');\n" + << " var body = document.getElementById('irViz' + id);\n" + << " if (body.style.visibility != 'hidden') {\n" + << " body.style.visibility = 'hidden';\n" + << " body.style.height = '0px';\n" + << " body.style.width = '0px';\n" + << " buttonShow.style.display = 'block';\n" + << " buttonHide.style.display = 'none';\n" + << " } else {\n" + << " body.style = '';\n" + << " buttonShow.style.display = 'none';\n" + << " buttonHide.style.display = 'block';\n" + << " }\n" + << "}\n "; + + return ir_viz_js.str(); +} + +/* + * PRINT NODE + */ +string GetReadWrite::print_node(const IRNode *node) const { + ostringstream ss; + ss << "Node in question has type: "; + IRNodeType type = node->node_type; + if (type == IRNodeType::IntImm) { + ss << "IntImm type"; + const auto *node1 = dynamic_cast(node); + ss << "value: " << node1->value; + } else if (type == IRNodeType::UIntImm) { + ss << "UIntImm type"; + } else if (type == IRNodeType::FloatImm) { + ss << "FloatImm type"; + } else if (type == IRNodeType::StringImm) { + ss << "StringImm type"; + } else if (type == IRNodeType::Broadcast) { + ss << "Broadcast type"; + } else if (type == IRNodeType::Cast) { + ss << "Cast type"; + } else if (type == IRNodeType::Variable) { + ss << "Variable type"; + } else if (type == IRNodeType::Add) { + ss << "Add type"; + const auto *node1 = dynamic_cast(node); + ss << "a: " << print_node(node1->a.get()) << "\n"; + ss << "b: " << print_node(node1->b.get()) << "\n"; + } else if (type == IRNodeType::Sub) { + ss << "Sub type" + << "\n"; + const auto *node1 = dynamic_cast(node); + ss << "a: " << print_node(node1->a.get()) << "\n"; + ss << "b: " << print_node(node1->b.get()) << "\n"; + } else if (type == IRNodeType::Mod) { + ss << "Mod type" + << "\n"; + const auto *node1 = dynamic_cast(node); + ss << "a: " << print_node(node1->a.get()) << "\n"; + ss << "b: " << print_node(node1->b.get()) << "\n"; + } else if (type == IRNodeType::Mul) { + ss << "Mul type" + << "\n"; + const auto *node1 = dynamic_cast(node); + ss << "a: " << print_node(node1->a.get()) << "\n"; + ss << "b: " << print_node(node1->b.get()) << "\n"; + } else if (type == IRNodeType::Div) { + ss << "Div type" + << "\n"; + const auto *node1 = dynamic_cast(node); + ss << "a: " << print_node(node1->a.get()) << "\n"; + ss << "b: " << print_node(node1->b.get()) << "\n"; + } else if (type == IRNodeType::Min) { + ss << "Min type"; + } else if (type == IRNodeType::Max) { + ss << "Max type"; + } else if (type == IRNodeType::EQ) { + ss << "EQ type"; + } else if (type == IRNodeType::NE) { + ss << "NE type"; + } else if (type == IRNodeType::LT) { + ss << "LT type"; + } else if (type == IRNodeType::LE) { + ss << "LE type"; + } else if (type == IRNodeType::GT) { + ss << "GT type"; + } else if (type == IRNodeType::GE) { + ss << "GE type"; + } else if (type == IRNodeType::And) { + ss << "And type"; + } else if (type == IRNodeType::Or) { + ss << "Or type"; + } else if (type == IRNodeType::Not) { + ss << "Not type"; + } else if (type == IRNodeType::Select) { + ss << "Select type"; + } else if (type == IRNodeType::Load) { + ss << "Load type"; + } else if (type == IRNodeType::Ramp) { + ss << "Ramp type"; + } else if (type == IRNodeType::Call) { + ss << "Call type"; + } else if (type == IRNodeType::Let) { + ss << "Let type"; + } else if (type == IRNodeType::Shuffle) { + ss << "Shuffle type"; + } else if (type == IRNodeType::VectorReduce) { + ss << "VectorReduce type"; + } else if (type == IRNodeType::LetStmt) { + ss << "LetStmt type"; + } else if (type == IRNodeType::AssertStmt) { + ss << "AssertStmt type"; + } else if (type == IRNodeType::ProducerConsumer) { + ss << "ProducerConsumer type"; + } else if (type == IRNodeType::For) { + ss << "For type"; + } else if (type == IRNodeType::Acquire) { + ss << "Acquire type"; + } else if (type == IRNodeType::Store) { + ss << "Store type"; + } else if (type == IRNodeType::Provide) { + ss << "Provide type"; + } else if (type == IRNodeType::Allocate) { + ss << "Allocate type"; + } else if (type == IRNodeType::Free) { + ss << "Free type"; + } else if (type == IRNodeType::Realize) { + ss << "Realize type"; + } else if (type == IRNodeType::Block) { + ss << "Block type"; + } else if (type == IRNodeType::Fork) { + ss << "Fork type"; + } else if (type == IRNodeType::IfThenElse) { + ss << "IfThenElse type"; + } else if (type == IRNodeType::Evaluate) { + ss << "Evaluate type"; + } else if (type == IRNodeType::Prefetch) { + ss << "Prefetch type"; + } else if (type == IRNodeType::Atomic) { + ss << "Atomic type"; + } else { + ss << "Unknown type"; + } + + return ss.str(); +} + +const string IRVisualization::scroll_to_function_JS_viz_to_code = "\n \ +// scroll to function - viz to code\n \ +function makeVisible(element) { \n \ + if (!element) return; \n \ + if (element.className == 'mainContent') return; \n \ + if (element.style.visibility == 'hidden') { \n \ + element.style = ''; \n \ + show = document.getElementById(element.id + '-show'); \n \ + hide = document.getElementById(element.id + '-hide'); \n \ + show.style.display = 'none'; \n \ + hide.style.display = 'block'; \n \ + return; \n \ + } \n \ + makeVisible(element.parentNode); \n \ +} \n \ + \n \ +function scrollToFunctionVizToCode(id) { \n \ + var container = document.getElementById('IRCode-code'); \n \ + var scrollToObject = document.getElementById(id); \n \ + makeVisible(scrollToObject); \n \ + container.scrollTo({ \n \ + top: scrollToObject.offsetTop - 10, \n \ + behavior: 'smooth' \n \ + }); \n \ + scrollToObject.style.backgroundColor = 'yellow'; \n \ + scrollToObject.style.fontSize = '20px'; \n \ + \n \ + // change content for 1 second \n \ + setTimeout(function () { \n \ + scrollToObject.style.backgroundColor = 'transparent'; \n \ + scrollToObject.style.fontSize = '12px'; \n \ + }, 1000); \n \ +} \n \ +"; + +const string IRVisualization::ir_viz_CSS = "\n \ +/* irViz CSS */\n \ +.tf-custom-irViz .tf-nc { border-radius: 5px; border: 1px solid; }\n \ +.tf-custom-irViz .tf-nc:before, .tf-custom-irViz .tf-nc:after { border-left-width: 1px; }\n \ +.tf-custom-irViz li li:before { border-top-width: 1px; }\n \ +.tf-custom-irViz .end-node { border-style: dashed; }\n \ +.tf-custom-irViz .tf-nc { background-color: #e6eeff; }\n \ +.tf-custom-irViz { font-size: 12px; } \n \ +div.box { \n \ + border: 1px dashed grey; \n \ + border-radius: 5px; \n \ + margin: 5px; \n \ + padding: 5px; \n \ + display: flex; \n \ + width: max-content; \n \ +} \n \ +div.boxHeader { \n \ + padding: 5px; \n \ + display: flex; \n \ +} \n \ +div.memory-cost-div, \n \ +div.computation-cost-div { \n \ + border: 1px solid rgba(0, 0, 0, 0); \n \ + width: 7px; \n \ +} \n \ +div.FunctionCallBox { \n \ + background-color: #fabebe; \n \ +} \n \ +div.FunctionBox { \n \ + background-color: #f0f0f0; \n \ + border: 1px dashed grey; \n \ + border-radius: 5px; \n \ + margin-bottom: 15px; \n \ + padding: 5px; \n \ + width: max-content; \n \ +} \n \ +div.functionHeader { \n \ + display: flex; \n \ + margin-bottom: 10px; \n \ +} \n \ +div.ProducerConsumerBox { \n \ + background-color: #99bbff; \n \ +} \n \ +div.ForBox { \n \ + background-color: #b3ccff; \n \ +} \n \ +div.StoreBox { \n \ + background-color: #f4f8bf; \n \ +} \n \ +div.AllocateBox { \n \ + background-color: #f4f8bf; \n \ +} \n \ +div.IfBox { \n \ + background-color: #e6eeff; \n \ +} \n \ +div.memory-cost-div:hover, \n \ +div.computation-cost-div:hover { \n \ + border: 1px solid grey; \n \ +} \n \ +div.boxBody { \n \ + margin-left: 5px; \n \ +} \n \ +div.boxHeaderTable { \n \ + padding-left: 5px; \n \ + padding-bottom: 5px; \n \ +} \n \ +table { \n \ + border-radius: 5px; \n \ + font-size: 12px; \n \ + border: 1px dashed grey; \n \ + border-collapse: separate; \n \ + border-spacing: 0; \n \ +} \n \ +.ifElseTable { \n \ + border: 0px; \n \ +} \n \ +.costTable { \n \ + text-align: center; \n \ + border: 0px; \n \ + background-color: rgba(150, 150, 150, 0.2); \n \ +} \n \ +.costTable td { \n \ + border-top: 1px dashed grey; \n \ +} \n \ +.costTableHeader, \n \ +.costTableData { \n \ + border-collapse: collapse; \n \ + padding-top: 3px; \n \ + padding-bottom: 3px; \n \ + padding-left: 7px; \n \ + padding-right: 7px; \n \ +} \n \ +span.intType { color: #099; } \n \ +span.stringType { color: #990073; } \n \ +.middleCol { \n \ + border-right: 1px dashed grey; \n \ +} \n \ +div.content { \n \ + flex-grow: 1; \n \ +} \n \ +.irVizColorButton { \n \ + height: 15px; \n \ + width: 10px; \n \ + margin-right: 2px; \n \ + border: 1px solid rgba(0, 0, 0, 0); \n \ + vertical-align: middle; \n \ + border-radius: 2px; \n \ +} \n \ +.irVizColorButton:hover { \n \ + border: 1px solid grey; \n \ +} \n \ +div.boxHeaderTitle { \n \ + font-weight: bold; \n \ + margin-top: auto; \n \ + margin-bottom: auto; \n \ +} \n \ +.irVizToggle { \n \ + margin-right: 5px; \n \ + margin-left: 0px; \n \ +} \n \ +.dottedIconButton { \n \ + border: 1px dotted black; \n \ + border-radius: 3px; \n \ +} \n \ +.dottedIconButton:hover { \n \ + border: 1px dotted red; \n \ +} \n \ +.functionButton { \n \ + background-color: #fff; \n \ + border: 1px solid #d5d9d9; \n \ + border-radius: 8px; \n \ + box-shadow: rgba(213, 217, 217, .5) 0 2px 5px 0; \n \ + position: relative; \n \ + text-align: center; \n \ + vertical-align: middle; \n \ + margin-left: 5px; \n \ + font-size: 15px; \n \ + padding: 3px; \n \ +} \n \ +.functionButton:hover { \n \ + background-color: #f7fafa; \n \ +} \n \ +"; + +} // namespace Internal +} // namespace Halide diff --git a/src/IRVisualization.h b/src/IRVisualization.h new file mode 100644 index 000000000000..8d7eb9d6f4bf --- /dev/null +++ b/src/IRVisualization.h @@ -0,0 +1,186 @@ +#ifndef HALIDE_IR_VISUALIZATION_H +#define HALIDE_IR_VISUALIZATION_H + +#include "FindStmtCost.h" +#include "IRVisitor.h" + +#include +#include + +namespace Halide { +namespace Internal { + +struct StmtSize { + std::map writes; + std::map reads; + + bool empty() const { + return writes.empty() && reads.empty(); + } +}; + +/* + * GetReadWrite class + */ +class GetReadWrite : public IRVisitor { +public: + std::vector function_names; // used for figuring out whether variable is a function call + + // generates the reads/writes for the module + void generate_sizes(const Module &m); + + // returns the reads/writes for the given node + StmtSize get_size(const IRNode *node) const; + + // for coloring + std::string string_span(const std::string &var_name) const; + std::string int_span(int64_t int_val) const; + + // prints nodes in error messages + std::string print_node(const IRNode *node) const; + +private: + using IRVisitor::visit; + + std::unordered_map stmt_sizes; // stores the sizes + std::map curr_load_values; // used when calculating store reads + + // starts traversal of the module + void traverse(const Module &m); + + // used to simplify expressions with + and *, to not have too many parentheses + std::string get_simplified_string(const std::string &a, const std::string &b, const std::string &op); + + // sets reads/writes for the given node + void set_write_size(const IRNode *node, const std::string &write_var, std::string write_size); + void set_read_size(const IRNode *node, const std::string &read_var, std::string read_size); + + void visit(const Store *op) override; + void add_load_value(const std::string &name, int lanes); + void visit(const Load *op) override; +}; + +/* + * IRVisualization class + */ +class IRVisualization : public IRVisitor { + +public: + static const std::string ir_viz_CSS, scroll_to_function_JS_viz_to_code; + + IRVisualization(FindStmtCost find_stmt_cost_populated) + : find_stmt_cost(std::move(find_stmt_cost_populated)), ir_viz_tooltip_count(0), if_count(0), + producer_consumer_count(0), for_count(0), store_count(0), allocate_count(0), + function_count(0) { + } + + // generates the html for the IR Visualization + std::string generate_ir_visualization_html(const Module &m); + + // returns the JS for the IR Visualization + std::string generate_ir_visualization_js(); + + // generates tooltip tables based on given node + std::string generate_computation_cost_tooltip(const IRNode *op, const std::string &extraNote); + std::string generate_data_movement_cost_tooltip(const IRNode *op, const std::string &extraNote); + + // returns the range of the node's cost based on the other nodes' costs + int get_color_range(const IRNode *op, StmtCostModel cost_model) const; + + // returns color range when blocks are collapsed in code viz + int get_combined_color_range(const IRNode *op, bool is_compcost) const; + +private: + using IRVisitor::visit; + + std::ostringstream html; // main html string + GetReadWrite get_read_write; // generates the read/write sizes + FindStmtCost find_stmt_cost; // used to determine the color of each statement + int num_of_nodes; // keeps track of the number of nodes in the visualization + int ir_viz_tooltip_count; // tooltip count + + // used for getting anchor names + int if_count; + int producer_consumer_count; + int for_count; + int store_count; + int allocate_count; + int function_count; + + // for traversal of a Module object + void start_module_traversal(const Module &m); + + // opens and closes divs + std::string open_box_div(const std::string &class_name, const IRNode *op); + std::string close_box_div() const; + std::string open_function_box_div() const; + std::string close_function_box_div() const; + std::string open_header_div() const; + std::string open_box_header_title_div() const; + std::string open_box_header_table_div() const; + std::string open_store_div() const; + std::string open_body_div() const; + std::string close_div() const; + + // header functions + std::string open_header(const std::string &header, const std::string &anchor_name, + std::vector> info_tooltip_table); + std::string close_header() const; + std::string div_header(const std::string &header, StmtSize *size, const std::string &anchor_name, + std::vector> info_tooltip_table); + std::string function_div_header(const std::string &function_name, const std::string &anchor_name) const; + std::vector get_allocation_sizes(const Allocate *op) const; + std::string allocate_div_header(const Allocate *op, const std::string &header, const std::string &anchor_name, + std::vector> &info_tooltip_table); + std::string for_loop_div_header(const For *op, const std::string &header, const std::string &anchor_name); + + // opens and closes an if-tree + std::string if_tree(const IRNode *op, const std::string &header, const std::string &anchor_name); + std::string close_if_tree() const; + + // different cost tables + std::string read_write_table(StmtSize &size) const; + std::string allocate_table(std::vector &allocation_sizes) const; + std::string for_loop_table(const std::string &loop_size) const; + + // generates code for button that will scroll to associated IR code line + std::string see_code_button_div(const std::string &anchor_name, bool put_div = true) const; + + // info button with tooltip + std::string info_button_with_tooltip(const std::string &tooltip_text, const std::string &button_class_name, + const std::string &tooltip_class_name = ""); + + // for cost colors - side bars of boxes + std::string generate_computation_cost_div(const IRNode *op); + std::string generate_memory_cost_div(const IRNode *op); + std::string open_content_div() const; + + // gets cost percentages of a given node + int get_cost_percentage(const IRNode *node, StmtCostModel cost_model) const; + + // builds the tooltip cost table based on given input table + std::string tooltip_table(std::vector> &table, const std::string &extra_note = ""); + + // for cost colors - side boxes of Load nodes + std::string color_button(int color_range); + std::string computation_div(const IRNode *op); + std::string data_movement_div(const IRNode *op); + std::string cost_colors(const IRNode *op); + + void visit_function(const LoweredFunc &func); + void visit(const Variable *op) override; + void visit(const ProducerConsumer *op) override; + std::string get_loop_iterator_binary(const IRNodeType &type, const Expr &a, const Expr &b) const; + std::string get_loop_iterator(const For *op) const; + void visit(const For *op) override; + void visit(const IfThenElse *op) override; + void visit(const Store *op) override; + void visit(const Load *op) override; + std::string get_memory_type(MemoryType mem_type) const; + void visit(const Allocate *op) override; +}; + +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/Module.cpp b/src/Module.cpp index cbb31604b7e4..89277be30ea7 100644 --- a/src/Module.cpp +++ b/src/Module.cpp @@ -19,6 +19,7 @@ #include "Pipeline.h" #include "PythonExtensionGen.h" #include "StmtToHtml.h" +#include "StmtToViz.h" namespace Halide { namespace Internal { @@ -48,6 +49,7 @@ std::map get_output_info(const Target &target) {OutputFileType::static_library, {"static_library", is_windows_coff ? ".lib" : ".a", IsSingle}}, {OutputFileType::stmt, {"stmt", ".stmt", IsMulti}}, {OutputFileType::stmt_html, {"stmt_html", ".stmt.html", IsMulti}}, + {OutputFileType::stmt_viz, {"stmt_viz", ".stmt.viz.html", IsMulti}}, }; return ext; } @@ -598,6 +600,12 @@ void Module::compile(const std::map &output_files) compile_llvm_module_to_llvm_assembly(*llvm_module, *out); } } + + if (contains(output_files, OutputFileType::stmt_viz)) { + debug(1) << "Module.compile(): stmt_viz " << output_files.at(OutputFileType::stmt_viz) << "\n"; + Internal::print_to_viz(output_files.at(OutputFileType::stmt_viz), *this); + } + if (contains(output_files, OutputFileType::c_header)) { debug(1) << "Module.compile(): c_header " << output_files.at(OutputFileType::c_header) << "\n"; std::ofstream file(output_files.at(OutputFileType::c_header)); diff --git a/src/Module.h b/src/Module.h index a1dbff31c345..46b492325f7a 100644 --- a/src/Module.h +++ b/src/Module.h @@ -40,6 +40,7 @@ enum class OutputFileType { static_library, stmt, stmt_html, + stmt_viz, }; /** Type of linkage a function in a lowered Halide module can have. diff --git a/src/StmtToViz.cpp b/src/StmtToViz.cpp new file mode 100644 index 000000000000..a741c2e1e21d --- /dev/null +++ b/src/StmtToViz.cpp @@ -0,0 +1,2377 @@ +#include "StmtToViz.h" +#include "Debug.h" +#include "Error.h" +#include "FindStmtCost.h" +#include "GetAssemblyInfoViz.h" +#include "GetStmtHierarchy.h" +#include "IROperator.h" +#include "IRVisitor.h" +#include "IRVisualization.h" +#include "Module.h" +#include "Scope.h" +#include "Substitute.h" +#include "Util.h" + +#include +#include +#include +#include +#include +#include + +namespace Halide { +namespace Internal { + +using std::ostringstream; +using std::string; + +const char *StmtToViz_canIgnoreVariableName_string = "canIgnoreVariableName"; + +class StmtToViz : public IRVisitor { + + // CSS strings + static const string ir_code_css, code_viz_css, cost_colors_css, flexbox_div_css, + line_numbers_css, code_mirror_css, tooltip_css; + + // JS strings + static const string ir_code_js, scroll_to_function_code_to_viz_js, expand_code_js, + code_mirror_js; + + // This allows easier access to individual elements. + int id_count; + +private: + std::ofstream stream; + + FindStmtCost find_stmt_cost; // used for finding the cost of statements + GetStmtHierarchy get_stmt_hierarchy; // used for getting the hierarchy of + // statements + IRVisualization ir_visualization; // used for getting the IR visualization + GetAssemblyInfoViz get_assembly_info_viz; // used for getting the assembly line numbers + + int curr_line_num; // for accessing div of that line + + // used for getting anchor names + int if_count; + int producer_consumer_count; + int for_count; + int store_count; + int allocate_count; + int functionCount; + + // used for tooltip stuff + int tooltip_count; + + // used for get_stmt_hierarchy popup stuff + int popup_count; + string popups; + + int unique_id() { + return ++id_count; + } + + // All spans and divs will have an id of the form "x-y", where x + // is shared among all spans/divs in the same context, and y is unique. + std::vector context_stack; + std::vector context_stack_tags; + string open_tag(const string &tag, const string &cls, int id = -1) { + ostringstream s; + s << "<" << tag << " class='" << cls << "' id='"; + if (id == -1) { + s << context_stack.back() << "-"; + s << unique_id(); + } else { + s << id; + } + s << "'>"; + context_stack.push_back(unique_id()); + context_stack_tags.push_back(tag); + return s.str(); + } + string tag(const string &tag, const string &cls, const string &body, int id = -1) { + ostringstream s; + s << open_tag(tag, cls, id); + s << body; + s << close_tag(tag); + return s.str(); + } + string close_tag(const string &tag) { + internal_assert(!context_stack.empty() && tag == context_stack_tags.back()); + context_stack.pop_back(); + context_stack_tags.pop_back(); + return ""; + } + + StmtHierarchyInfo get_stmt_hierarchy_html(const Stmt &op) { + StmtHierarchyInfo stmt_hierarchy_info = get_stmt_hierarchy.get_hierarchy_html(op); + string &html = stmt_hierarchy_info.html; + string popup = generate_stmt_hierarchy_popup(html); + stmt_hierarchy_info.html = popup; + + return stmt_hierarchy_info; + } + StmtHierarchyInfo get_stmt_hierarchy_html(const Expr &op) { + StmtHierarchyInfo stmt_hierarchy_info = get_stmt_hierarchy.get_hierarchy_html(op); + string &html = stmt_hierarchy_info.html; + string popup = generate_stmt_hierarchy_popup(html); + stmt_hierarchy_info.html = popup; + + return stmt_hierarchy_info; + } + + string generate_stmt_hierarchy_popup(const string &hierarchy_HTML) { + ostringstream popup; + + popup_count++; + popup << "\n"; + + return popup.str(); + } + + string open_cost_span(const Stmt &stmt_op) { + StmtHierarchyInfo stmt_hierarchy_info = get_stmt_hierarchy_html(stmt_op); + + ostringstream s; + + s << cost_colors(stmt_op.get(), stmt_hierarchy_info); + + // popup window - will put them all at the end + popups += stmt_hierarchy_info.html + "\n"; + + s << ""; + return s.str(); + } + string open_cost_span(const Expr &stmt_op) { + StmtHierarchyInfo stmt_hierarchy_info = get_stmt_hierarchy_html(stmt_op); + + ostringstream s; + + s << cost_colors(stmt_op.get(), stmt_hierarchy_info); + + // popup window - will put them all at the end + popups += stmt_hierarchy_info.html + "\n"; + + s << ""; + return s.str(); + } + + string close_cost_span() { + return ""; + } + string open_cost_span_else_case(Stmt else_case) { + Stmt new_node = + IfThenElse::make(Variable::make(Int(32), StmtToViz_canIgnoreVariableName_string), std::move(else_case), nullptr); + + StmtHierarchyInfo stmt_hierarchy_info = get_stmt_hierarchy.get_else_hierarchy_html(); + string popup = generate_stmt_hierarchy_popup(stmt_hierarchy_info.html); + + // popup window - will put them all at the end + popups += popup + "\n"; + + ostringstream s; + + curr_line_num += 1; + + s << ""; + + s << computation_button(new_node.get(), stmt_hierarchy_info); + s << data_movement_button(new_node.get(), stmt_hierarchy_info); + + s << ""; + + s << ""; + return s.str(); + } + + string open_span(const string &cls, int id = -1) { + return open_tag("span", cls, id); + } + string close_span() { + return close_tag("span"); + } + string span(const string &cls, const string &body, int id = -1) { + return tag("span", cls, body, id); + } + string matched(const string &cls, const string &body, int id = -1) { + return span(cls + " Matched", body, id); + } + string matched(const string &body) { + return span("Matched", body); + } + + string color_button(const IRNode *op, bool is_computation, + const StmtHierarchyInfo &stmt_hierarchy_info) { + + int color_range_inclusive, color_range_exclusive; + + if (is_computation) { + color_range_inclusive = ir_visualization.get_combined_color_range(op, true); + color_range_exclusive = ir_visualization.get_color_range(op, StmtCostModel::Compute); + } else { + color_range_inclusive = ir_visualization.get_combined_color_range(op, false); + color_range_exclusive = ir_visualization.get_color_range(op, StmtCostModel::DataMovement); + } + tooltip_count++; + + ostringstream s; + s << ""; + + return s.str(); + } + + string computation_button(const IRNode *op, const StmtHierarchyInfo &stmt_hierarchy_info) { + ostringstream s; + s << color_button(op, true, stmt_hierarchy_info); + + string tooltip_text = + ir_visualization.generate_computation_cost_tooltip(op, "[Click to see full hierarchy]"); + + // tooltip span + s << ""; + s << tooltip_text; + s << ""; + + return s.str(); + } + string data_movement_button(const IRNode *op, const StmtHierarchyInfo &stmt_hierarchy_info) { + ostringstream s; + s << color_button(op, false, stmt_hierarchy_info); + + string tooltip_text = ir_visualization.generate_data_movement_cost_tooltip( + op, "[Click to see full hierarchy]"); + + // tooltip span + s << ""; + s << tooltip_text; + s << ""; + + return s.str(); + } + string cost_colors(const IRNode *op, const StmtHierarchyInfo &stmt_hierarchy_info) { + curr_line_num += 1; + + ostringstream s; + + if (op->node_type == IRNodeType::Allocate || op->node_type == IRNodeType::Evaluate || + op->node_type == IRNodeType::IfThenElse || op->node_type == IRNodeType::For || + op->node_type == IRNodeType::ProducerConsumer) { + s << ""; + } else { + s << ""; + } + + s << computation_button(op, stmt_hierarchy_info); + s << data_movement_button(op, stmt_hierarchy_info); + + s << ""; + + return s.str(); + } + + string open_div(const string &cls, int id = -1) { + return open_tag("div", cls, id) + "\n"; + } + string close_div() { + return close_tag("div") + "\n"; + } + + string open_anchor(const string &anchor_name) { + return ""; + } + string close_anchor() { + return ""; + } + + string see_viz_button(const string &anchor_name) { + ostringstream s; + + s << ""; + + return s.str(); + } + + string see_assembly_button(const int &assembly_line_num_start, + const int &assembly_line_num_end = -1) { + ostringstream s; + + // Generates the "code-square" icon from Boostrap: + // https://icons.getbootstrap.com/icons/code-square/ + tooltip_count++; + s << ""; + + // tooltip span + s << ""; + s << "Click to see assembly code"; + s << ""; + + return s.str(); + } + + string open_line() { + return "

    "; + } + string close_line() { + return "

    "; + } + + string keyword(const string &x) { + return span("Keyword", x); + } + string type(const string &x) { + return span("Type", x); + } + string symbol(const string &x) { + return span("Symbol", x); + } + + Scope scope; + string var(const string &x) { + int id; + if (scope.contains(x)) { + id = scope.get(x); + } else { + id = unique_id(); + scope.push(x, id); + } + + ostringstream s; + s << ""; + s << x; + s << ""; + return s.str(); + } + + void print_list(const std::vector &args) { + for (size_t i = 0; i < args.size(); i++) { + if (i > 0) { + stream << matched(",") << " "; + } + print(args[i]); + } + } + void print_list(const string &l, const std::vector &args, const string &r) { + stream << matched(l); + print_list(args); + stream << matched(r); + } + + string open_expand_button(int id) { + ostringstream button; + button << "" + << "
    " + << "" + << "
    " + << "
    "; + return button.str(); + } + + string close_expand_button() { + return "
    "; + } + + void visit(const IntImm *op) override { + stream << open_span("IntImm Imm"); + stream << Expr(op); + stream << close_span(); + } + + void visit(const UIntImm *op) override { + stream << open_span("UIntImm Imm"); + stream << Expr(op); + stream << close_span(); + } + + void visit(const FloatImm *op) override { + stream << open_span("FloatImm Imm"); + stream << Expr(op); + stream << close_span(); + } + + void visit(const StringImm *op) override { + stream << open_span("StringImm"); + stream << "\""; + for (auto c : op->value) { + if (c >= ' ' && c <= '~' && c != '\\' && c != '"') { + stream << c; + } else { + stream << "\\"; + switch (c) { + case '"': + stream << "\""; + break; + case '\\': + stream << "\\"; + break; + case '\t': + stream << "t"; + break; + case '\r': + stream << "r"; + break; + case '\n': + stream << "n"; + break; + default: + const char *hex_digits = "0123456789ABCDEF"; + stream << "x" << hex_digits[c >> 4] << hex_digits[c & 0xf]; + } + } + } + stream << "\"" << close_span(); + } + + void visit(const Variable *op) override { + + stream << var(op->name); + } + + void visit(const Cast *op) override { + stream << open_span("Cast"); + + stream << open_span("Matched"); + stream << open_span("Type") << op->type << close_span(); + stream << "("; + stream << close_span(); + print(op->value); + stream << matched(")"); + + stream << close_span(); + } + + void visit(const Reinterpret *op) override { + stream << open_span("Reinterpret"); + + stream << open_span("Matched"); + stream << open_span("Type") << op->type << close_span(); + stream << "("; + stream << close_span(); + print(op->value); + stream << matched(")"); + + stream << close_span(); + } + + void visit_binary_op(const Expr &a, const Expr &b, const char *op) { + stream << open_span("BinaryOp"); + + stream << matched("("); + print(a); + stream << " " << matched("Operator", op) << " "; + print(b); + stream << matched(")"); + + stream << close_span(); + } + + void visit(const Add *op) override { + visit_binary_op(op->a, op->b, "+"); + } + void visit(const Sub *op) override { + visit_binary_op(op->a, op->b, "-"); + } + void visit(const Mul *op) override { + visit_binary_op(op->a, op->b, "*"); + } + void visit(const Div *op) override { + visit_binary_op(op->a, op->b, "/"); + } + void visit(const Mod *op) override { + visit_binary_op(op->a, op->b, "%"); + } + void visit(const And *op) override { + visit_binary_op(op->a, op->b, "&&"); + } + void visit(const Or *op) override { + visit_binary_op(op->a, op->b, "||"); + } + void visit(const NE *op) override { + visit_binary_op(op->a, op->b, "!="); + } + void visit(const LT *op) override { + visit_binary_op(op->a, op->b, "<"); + } + void visit(const LE *op) override { + visit_binary_op(op->a, op->b, "<="); + } + void visit(const GT *op) override { + visit_binary_op(op->a, op->b, ">"); + } + void visit(const GE *op) override { + visit_binary_op(op->a, op->b, ">="); + } + void visit(const EQ *op) override { + visit_binary_op(op->a, op->b, "=="); + } + + void visit(const Min *op) override { + stream << open_span("Min"); + print_list(symbol("min") + "(", {op->a, op->b}, ")"); + stream << close_span(); + } + void visit(const Max *op) override { + stream << open_span("Max"); + print_list(symbol("max") + "(", {op->a, op->b}, ")"); + stream << close_span(); + } + void visit(const Not *op) override { + stream << open_span("Not"); + stream << "!"; + print(op->a); + stream << close_span(); + } + void visit(const Select *op) override { + stream << open_span("Select"); + print_list(symbol("select") + "(", {op->condition, op->true_value, op->false_value}, ")"); + stream << close_span(); + } + void visit(const Load *op) override { + stream << open_span("Load"); + stream << open_span("Matched"); + stream << var(op->name) << "["; + stream << close_span(); + print(op->index); + stream << matched("]"); + if (!is_const_one(op->predicate)) { + stream << " " << keyword("if") << " "; + print(op->predicate); + } + stream << close_span(); + } + void visit(const Ramp *op) override { + stream << open_span("Ramp"); + print_list(symbol("ramp") + "(", {op->base, op->stride, Expr(op->lanes)}, ")"); + stream << close_span(); + } + void visit(const Broadcast *op) override { + stream << open_span("Broadcast"); + stream << open_span("Matched"); + stream << symbol("x") << op->lanes << "("; + stream << close_span(); + print(op->value); + stream << matched(")"); + stream << close_span(); + } + void visit(const Call *op) override { + stream << open_span("Call"); + + print_list(symbol(op->name) + "(", op->args, ")"); + stream << close_span(); + } + + void visit(const Let *op) override { + + scope.push(op->name, unique_id()); + stream << open_span("Let"); + stream << open_span("Matched"); + stream << "(" << keyword("let") << " "; + stream << var(op->name); + stream << close_span(); + stream << " " << matched("Operator Assign", "=") << " "; + print(op->value); + + stream << " " << matched("Keyword", "in") << " "; + print(op->body); + stream << matched(")"); + stream << close_span(); + scope.pop(op->name); + } + void visit(const LetStmt *op) override { + + scope.push(op->name, unique_id()); + stream << open_div("LetStmt") << open_line(); + + stream << open_cost_span(op); + stream << open_span("Matched"); + stream << keyword("let") << " "; + stream << var(op->name); + stream << close_span(); + stream << " " << matched("Operator Assign", "=") << " "; + + print(op->value); + stream << close_cost_span(); + + stream << close_line(); + print(op->body); + stream << close_div(); + + scope.pop(op->name); + } + void visit(const AssertStmt *op) override { + stream << open_div("AssertStmt WrapLine"); + std::vector args; + args.push_back(op->condition); + args.push_back(op->message); + stream << open_cost_span(op); + print_list(symbol("assert") + "(", args, ")"); + stream << close_cost_span(); + stream << close_div(); + } + void visit(const ProducerConsumer *op) override { + scope.push(op->name, unique_id()); + stream << open_div(op->is_producer ? "Produce" : "Consumer"); + + // anchoring + producer_consumer_count++; + string anchor_name = "producerConsumer" + std::to_string(producer_consumer_count); + + // for assembly + int assembly_line_num = get_assembly_info_viz.get_line_number_prod_cons(op); + + int produce_id = unique_id(); + + stream << open_cost_span(op); + stream << open_span("Matched"); + stream << open_expand_button(produce_id); + stream << open_anchor(anchor_name); + stream << keyword(op->is_producer ? "produce" : "consume") << " "; + stream << var(op->name); + stream << close_expand_button() << " {"; + stream << close_span(); + stream << close_anchor(); + stream << close_cost_span(); + if (assembly_line_num != -1) { + stream << see_assembly_button(assembly_line_num); + } + stream << see_viz_button(anchor_name); + + stream << open_div(op->is_producer ? "ProduceBody Indent" : "ConsumeBody Indent", + produce_id); + print(op->body); + stream << close_div(); + stream << open_div("ClosingBrace"); + stream << matched("}"); + stream << close_div(); + stream << close_div(); + scope.pop(op->name); + } + + void visit(const For *op) override { + + scope.push(op->name, unique_id()); + stream << open_div("For"); + + // anchoring + for_count++; + string anchor_name = "for" + std::to_string(for_count); + + // for assembly + ForLoopLineNumber assembly_line_info = get_assembly_info_viz.get_line_numbers_for_loops(op); + int assembly_line_num_start = assembly_line_info.start_line; + int assembly_line_num_end = assembly_line_info.end_line; + + int id = unique_id(); + stream << open_cost_span(op); + stream << open_expand_button(id); + stream << open_anchor(anchor_name); + stream << open_span("Matched"); + if (op->for_type == ForType::Serial) { + stream << keyword("for"); + } else if (op->for_type == ForType::Parallel) { + stream << keyword("parallel"); + } else if (op->for_type == ForType::Vectorized) { + stream << keyword("vectorized"); + } else if (op->for_type == ForType::Unrolled) { + stream << keyword("unrolled"); + } else if (op->for_type == ForType::GPUBlock) { + stream << keyword("gpu_block"); + } else if (op->for_type == ForType::GPUThread) { + stream << keyword("gpu_thread"); + } else if (op->for_type == ForType::GPULane) { + stream << keyword("gpu_lane"); + } else { + internal_assert(false) << "\n" + << "Unknown for type: " << ((int)op->for_type) << "\n\n"; + } + stream << " ("; + stream << close_span(); + + print_list({Variable::make(Int(32), op->name), op->min, op->extent}); + + stream << matched(")"); + stream << close_expand_button(); + stream << " " << matched("{"); + stream << close_anchor(); + stream << close_cost_span(); + if (assembly_line_num_start != -1) { + stream << see_assembly_button(assembly_line_num_start, assembly_line_num_end); + } + stream << see_viz_button(anchor_name); + + stream << open_div("ForBody Indent", id); + print(op->body); + stream << close_div(); + stream << open_div("ClosingBrace"); + stream << matched("}"); + stream << close_div(); + stream << close_div(); + scope.pop(op->name); + } + + void visit(const Acquire *op) override { + stream << open_div("Acquire"); + int id = unique_id(); + stream << open_span("Matched"); + stream << open_expand_button(id); + stream << keyword("acquire ("); + stream << close_span(); + print(op->semaphore); + stream << ", "; + print(op->count); + stream << matched(")"); + stream << close_expand_button() << " {"; + stream << open_div("Acquire Indent", id); + print(op->body); + stream << close_div(); + stream << open_div("ClosingBrace"); + stream << matched("}"); + stream << close_div(); + stream << close_div(); + } + + void visit(const Store *op) override { + stream << open_div("Store WrapLine"); + + // anchoring + store_count++; + string anchor_name = "store" + std::to_string(store_count); + + stream << open_cost_span(op); + stream << open_anchor(anchor_name); + + stream << open_span("Matched"); + stream << var(op->name) << "["; + stream << close_span(); + + print(op->index); + stream << matched("]"); + + stream << " " << span("Operator Assign Matched", "=") << " "; + + stream << open_span("StoreValue"); + print(op->value); + if (!is_const_one(op->predicate)) { + stream << " " << keyword("if") << " "; + print(op->predicate); + } + stream << close_span(); + + stream << close_anchor(); + stream << close_cost_span(); + stream << see_viz_button(anchor_name); + stream << close_div(); + } + void visit(const Provide *op) override { + stream << open_div("Provide WrapLine"); + stream << open_span("Matched"); + stream << var(op->name) << "("; + stream << close_span(); + print_list(op->args); + stream << matched(")") << " "; + stream << matched("=") << " "; + if (op->values.size() > 1) { + print_list("{", op->values, "}"); + } else { + print(op->values[0]); + } + stream << close_div(); + } + void visit(const Allocate *op) override { + scope.push(op->name, unique_id()); + stream << open_div("Allocate"); + + // anchoring + allocate_count++; + string anchor_name = "allocate" + std::to_string(allocate_count); + stream << open_anchor(anchor_name); + + stream << open_cost_span(op); + + stream << open_span("Matched"); + stream << keyword("allocate") << " "; + stream << var(op->name) << "["; + stream << close_span(); + + stream << open_span("Type"); + stream << op->type; + stream << close_span(); + + for (const auto &extent : op->extents) { + stream << " * "; + print(extent); + } + stream << matched("]"); + if (!is_const_one(op->condition)) { + stream << " " << keyword("if") << " "; + print(op->condition); + } + if (op->new_expr.defined()) { + stream << open_span("Matched"); + stream << keyword("custom_new") << "{"; + print(op->new_expr); + stream << open_div("ClosingBrace"); + stream << matched("}"); + stream << close_div(); + } + if (!op->free_function.empty()) { + stream << open_span("Matched"); + stream << keyword("custom_delete") << "{ " << op->free_function << "(); "; + stream << open_div("ClosingBrace"); + stream << matched("}"); + stream << close_div(); + } + stream << close_cost_span(); + + stream << close_anchor(); + stream << see_viz_button(anchor_name); + + stream << open_div("AllocateBody"); + print(op->body); + stream << close_div(); + stream << close_div(); + scope.pop(op->name); + } + void visit(const Free *op) override { + stream << open_div("Free WrapLine"); + stream << open_cost_span(op); + stream << keyword("free") << " "; + stream << var(op->name); + stream << close_cost_span(); + stream << close_div(); + } + void visit(const Realize *op) override { + scope.push(op->name, unique_id()); + stream << open_div("Realize"); + int id = unique_id(); + stream << open_expand_button(id); + stream << keyword("realize") << " "; + stream << var(op->name); + stream << matched("("); + for (size_t i = 0; i < op->bounds.size(); i++) { + print_list("[", {op->bounds[i].min, op->bounds[i].extent}, "]"); + if (i < op->bounds.size() - 1) { + stream << ", "; + } + } + stream << matched(")"); + if (!is_const_one(op->condition)) { + stream << " " << keyword("if") << " "; + print(op->condition); + } + stream << close_expand_button(); + + stream << " " << matched("{"); + stream << open_div("RealizeBody Indent", id); + print(op->body); + stream << close_div(); + stream << open_div("ClosingBrace"); + stream << matched("}"); + stream << close_div(); + stream << close_div(); + scope.pop(op->name); + } + + void visit(const Prefetch *op) override { + stream << open_span("Prefetch"); + stream << keyword("prefetch") << " "; + stream << var(op->name); + stream << matched("("); + for (size_t i = 0; i < op->bounds.size(); i++) { + print_list("[", {op->bounds[i].min, op->bounds[i].extent}, "]"); + if (i < op->bounds.size() - 1) { + stream << ", "; + } + } + stream << matched(")"); + if (!is_const_one(op->condition)) { + stream << " " << keyword("if") << " "; + print(op->condition); + } + stream << close_span(); + + stream << open_div("PrefetchBody"); + print(op->body); + stream << close_div(); + } + + // To avoid generating ridiculously deep DOMs, we flatten blocks here. + void visit_block_stmt(const Stmt &stmt) { + if (const Block *b = stmt.as()) { + visit_block_stmt(b->first); + visit_block_stmt(b->rest); + } else if (stmt.defined()) { + print(stmt); + } + } + void visit(const Block *op) override { + stream << open_div("Block"); + visit_block_stmt(op->first); + visit_block_stmt(op->rest); + stream << close_div(); + } + + // We also flatten forks + void visit_fork_stmt(const Stmt &stmt) { + if (const Fork *f = stmt.as()) { + visit_fork_stmt(f->first); + visit_fork_stmt(f->rest); + } else if (stmt.defined()) { + stream << open_div("ForkTask"); + int id = unique_id(); + stream << open_expand_button(id); + stream << matched("task {"); + stream << close_expand_button(); + stream << open_div("ForkTask Indent", id); + print(stmt); + stream << close_div(); + stream << open_div("ClosingBrace"); + stream << matched("}"); + stream << close_div(); + stream << close_div(); + } + } + void visit(const Fork *op) override { + stream << open_div("Fork"); + int id = unique_id(); + stream << open_expand_button(id); + stream << keyword("fork") << " " << matched("{"); + stream << close_expand_button(); + stream << open_div("Fork Indent", id); + visit_fork_stmt(op->first); + visit_fork_stmt(op->rest); + stream << close_div(); + stream << open_div("ClosingBrace"); + stream << matched("}"); + stream << close_div(); + stream << close_div(); + } + + void visit(const IfThenElse *op) override { + stream << open_div("IfThenElse"); + + // anchoring + if_count++; + string anchor_name = "if" + std::to_string(if_count); + + int id = unique_id(); + stream << open_cost_span(op); + stream << open_expand_button(id); + stream << open_anchor(anchor_name); + stream << open_span("Matched"); + + // for line numbers + stream << open_span("IfSpan"); + stream << close_span(); + + stream << keyword("if") << " ("; + stream << close_span(); + + while (true) { + print(op->condition); + stream << matched(")"); + stream << close_expand_button() << " "; + stream << matched("{"); + stream << close_anchor(); + stream << close_cost_span(); + stream << see_viz_button(anchor_name); + + stream << open_div("ThenBody Indent", id); + print(op->then_case); + stream << close_div(); // close thenbody div + + if (!op->else_case.defined()) { + stream << open_div("ClosingBrace"); + stream << matched("}"); + stream << close_div(); + break; + } + + id = unique_id(); + + if (const IfThenElse *nested_if = op->else_case.as()) { + stream << open_div("ClosingBrace"); + stream << matched("}"); + stream << close_div(); + + stream << open_cost_span(nested_if); + stream << open_expand_button(id); + stream << open_span("Matched"); + + // for line numbers + stream << open_span("IfSpan"); + stream << close_span(); + + // anchoring + if_count++; + string anchor_name = "if" + std::to_string(if_count); + stream << open_anchor(anchor_name); + + stream << keyword("else if") << " ("; + stream << close_span(); + op = nested_if; + } else { + stream << open_div("ClosingBrace"); + stream << matched("}"); + stream << close_div(); + + stream << open_cost_span_else_case(op->else_case); + stream << open_expand_button(id); + + // for line numbers + stream << open_span("IfSpan"); + stream << close_span(); + + // anchoring + if_count++; + string anchor_name = "if" + std::to_string(if_count); + stream << open_anchor(anchor_name); + + stream << keyword("else "); + stream << close_expand_button() << "{"; + stream << close_anchor(); + stream << close_cost_span(); + stream << see_viz_button(anchor_name); + + stream << open_div("ElseBody Indent", id); + print(op->else_case); + stream << close_div(); + stream << open_div("ClosingBrace"); + stream << matched("}"); + stream << close_div(); + break; + } + } + stream << close_div(); // Closing ifthenelse div. + } + + void visit(const Evaluate *op) override { + stream << open_div("Evaluate"); + stream << open_cost_span(op); + print(op->value); + stream << close_cost_span(); + + stream << close_div(); + } + + void visit(const Shuffle *op) override { + stream << open_span("Shuffle"); + if (op->is_concat()) { + print_list(symbol("concat_vectors("), op->vectors, ")"); + } else if (op->is_interleave()) { + print_list(symbol("interleave_vectors("), op->vectors, ")"); + } else if (op->is_extract_element()) { + std::vector args = op->vectors; + args.emplace_back(op->slice_begin()); + print_list(symbol("extract_element("), args, ")"); + } else if (op->is_slice()) { + std::vector args = op->vectors; + args.emplace_back(op->slice_begin()); + args.emplace_back(op->slice_stride()); + args.emplace_back(static_cast(op->indices.size())); + print_list(symbol("slice_vectors("), args, ")"); + } else { + std::vector args = op->vectors; + for (int i : op->indices) { + args.emplace_back(i); + } + print_list(symbol("shuffle("), args, ")"); + } + stream << close_span(); + } + + void visit(const VectorReduce *op) override { + stream << open_span("VectorReduce"); + stream << open_span("Type") << op->type << close_span(); + print_list(symbol("vector_reduce") + "(", {op->op, op->value}, ")"); + stream << close_span(); + } + + void visit(const Atomic *op) override { + stream << open_div("Atomic"); + int id = unique_id(); + stream << open_expand_button(id); + stream << open_span("Matched"); + if (op->mutex_name.empty()) { + stream << keyword("atomic") << matched("{"); + } else { + stream << keyword("atomic") << " ("; + stream << symbol(op->mutex_name); + stream << ")" << matched("{"); + } + stream << close_span(); + stream << open_div("Atomic Body Indent", id); + print(op->body); + stream << close_div(); + stream << open_div("ClosingBrace"); + stream << matched("}"); + stream << close_div(); + stream << close_div(); + } + +public: + FindStmtCost generate_costs(const Module &m) { + find_stmt_cost.generate_costs(m); + return find_stmt_cost; + } + + string generate_ir_visualization(const Module &m) { + return ir_visualization.generate_ir_visualization_html(m); + } + + void print(const Expr &ir) { + ir.accept(this); + } + + void print(const Stmt &ir) { + ir.accept(this); + } + + void print(const LoweredFunc &op) { + scope.push(op.name, unique_id()); + stream << open_div("Function"); + + // anchoring + functionCount++; + string anchor_name = "loweredFunc" + std::to_string(functionCount); + + int id = unique_id(); + stream << open_expand_button(id); + stream << open_anchor(anchor_name); + stream << open_span("Matched"); + stream << keyword("func"); + stream << " " << op.name << "("; + stream << close_span(); + for (size_t i = 0; i < op.args.size(); i++) { + if (i > 0) { + stream << matched(",") << " "; + } + stream << var(op.args[i].name); + } + stream << matched(")"); + stream << close_anchor(); + stream << close_expand_button(); + stream << " " << matched("{"); + stream << see_viz_button(anchor_name); + + stream << open_div("FunctionBody Indent", id); + + print(op.body); + + stream << close_div(); + stream << open_div("ClosingBrace"); + stream << matched("}"); + stream << close_div(); + + stream << close_div(); + scope.pop(op.name); + } + + void print_cuda_gpu_source_kernels(const string &str) { + std::istringstream ss(str); + int current_id = -1; + stream << ""; + bool in_braces = false; + bool in_func_signature = false; + string current_kernel; + for (string line; std::getline(ss, line);) { + if (line.empty()) { + stream << "\n"; + continue; + } + line = replace_all(line, "&", "&"); + line = replace_all(line, "<", "<"); + line = replace_all(line, ">", ">"); + line = replace_all(line, "\"", """); + line = replace_all(line, "/", "/"); + line = replace_all(line, "'", "'"); + + if (starts_with(line, ".visible .entry")) { + std::vector parts = split_string(line, " "); + if (parts.size() == 3) { + in_func_signature = true; + current_id = unique_id(); + stream << open_expand_button(current_id); + + string kernel_name = parts[2].substr(0, parts[2].length() - 1); + line = keyword(".visible") + " " + keyword(".entry") + " "; + line += var(kernel_name) + " " + matched("("); + current_kernel = kernel_name; + } + } else if (starts_with(line, ")") && in_func_signature) { + stream << close_expand_button(); + in_func_signature = false; + line = matched(")") + line.substr(1); + } else if (starts_with(line, "{") && !in_braces) { + in_braces = true; + stream << matched("{"); + stream << close_expand_button(); + internal_assert(current_id != -1); + stream << open_div("Indent", current_id); + current_id = -1; + line = line.substr(1); + scope.push(current_kernel, unique_id()); + } else if (starts_with(line, "}") && in_braces) { + stream << close_div(); + line = matched("}") + line.substr(1); + in_braces = false; + scope.pop(current_kernel); + } + + bool indent = false; + + if (line[0] == '\t') { + // Replace first tab with four spaces. + line = line.substr(1); + indent = true; + } + + line = replace_all(line, ".f32", ".f32"); + line = replace_all(line, ".f64", ".f64"); + + line = replace_all(line, ".s8", ".s8"); + line = replace_all(line, ".s16", ".s16"); + line = replace_all(line, ".s32", ".s32"); + line = replace_all(line, ".s64", ".s64"); + + line = replace_all(line, ".u8", ".u8"); + line = replace_all(line, ".u16", ".u16"); + line = replace_all(line, ".u32", ".u32"); + line = replace_all(line, ".u64", ".u64"); + + line = replace_all(line, ".b8", ".b8"); + line = replace_all(line, ".b16", ".b16"); + line = replace_all(line, ".b32", ".b32"); + line = replace_all(line, ".b64", ".b64"); + + line = replace_all(line, ".v2", ".v2"); + line = replace_all(line, ".v4", ".v4"); + + line = replace_all(line, "ld.", "ld."); + line = replace_all(line, "st.", "st."); + + size_t idx; + if ((idx = line.find("//")) != string::npos) { + line.insert(idx, ""); + line += ""; + } + + // Predicated instructions + if (line.front() == '@' && indent) { + idx = line.find(' '); + string pred = line.substr(1, idx - 1); + line = "@" + var(pred) + "" + line.substr(idx); + } + + // Labels + if (line.front() == 'L' && !indent && (idx = line.find(':')) != string::npos) { + string label = line.substr(0, idx); + line = "" + var(label) + ":" + line.substr(idx + 1); + } + + // Highlight operands + if ((idx = line.find(" \t")) != string::npos && line.back() == ';') { + string operands_str = line.substr(idx + 2); + operands_str = operands_str.substr(0, operands_str.length() - 1); + std::vector operands = split_string(operands_str, ", "); + operands_str = ""; + for (size_t opidx = 0; opidx < operands.size(); ++opidx) { + string op = operands[opidx]; + internal_assert(!op.empty()); + if (opidx != 0) { + operands_str += ", "; + } + if (op.back() == '}') { + string reg = op.substr(0, op.size() - 1); + operands_str += var(reg) + '}'; + } else if (op.front() == '%') { + operands_str += var(op); + } else if (op.find_first_not_of("-0123456789") == string::npos) { + operands_str += open_span("IntImm Imm"); + operands_str += op; + operands_str += close_span(); + } else if (starts_with(op, "0f") && + op.find_first_not_of("0123456789ABCDEF", 2) == string::npos) { + operands_str += open_span("FloatImm Imm"); + operands_str += op; + operands_str += close_span(); + } else if (op.front() == '[' && op.back() == ']') { + size_t idx = op.find('+'); + if (idx == string::npos) { + string reg = op.substr(1, op.size() - 2); + operands_str += '[' + var(reg) + ']'; + } else { + string reg = op.substr(1, idx - 1); + string offset = op.substr(idx + 1); + offset = offset.substr(0, offset.size() - 1); + operands_str += '[' + var(reg) + "+"; + operands_str += open_span("IntImm Imm"); + operands_str += offset; + operands_str += close_span(); + operands_str += ']'; + } + } else if (op.front() == '{') { + string reg = op.substr(1); + operands_str += '{' + var(reg); + } else if (op.front() == 'L') { + // Labels + operands_str += "" + var(op) + ""; + } else { + operands_str += op; + } + } + operands_str += ";"; + line = line.substr(0, idx + 2) + operands_str; + } + + if (indent) { + stream << " "; + } + stream << line << "\n"; + } + stream << ""; + } + + void print(const Buffer<> &op) { + bool include_data = ends_with(op.name(), "_gpu_source_kernels"); + int id = 0; + if (include_data) { + id = unique_id(); + stream << open_expand_button(id); + } + stream << open_div("Buffer<>"); + stream << keyword("buffer ") << var(op.name()); + if (include_data) { + stream << " = "; + stream << matched("{"); + stream << close_expand_button(); + stream << open_div("BufferData Indent", id); + string str((const char *)op.data(), op.size_in_bytes()); + if (starts_with(op.name(), "cuda_")) { + print_cuda_gpu_source_kernels(str); + } else { + stream << "
    \n";
    +                stream << str;
    +                stream << "
    \n"; + } + stream << close_div(); + + stream << " "; + internal_assert(false) << "\n\n\nvoid print(const Buffer<> &op): look at this line!!! make " + "sure the closing brace is correct! \n\n\n"; + stream << open_div("ClosingBrace"); + stream << matched("}"); + stream << close_div(); + } + stream << close_div(); + } + + void print(const Module &m) { + scope.push(m.name(), unique_id()); + + // doesn't currently support submodules - could comment out error, no guarantee it'll work + // as expected + for (const auto &s : m.submodules()) { + internal_assert(false) << "\n\nStmtToViz does not support submodules yet\n\n"; + print(s); + } + + int id = unique_id(); + stream << open_expand_button(id); + stream << open_div("Module"); + stream << open_span("Matched"); + stream << keyword("module") << " name=" << m.name() + << ", target=" << m.target().to_string(); + stream << close_span(); + stream << close_expand_button(); + stream << " " << matched("{"); + + stream << open_div("ModuleBody Indent", id); + + for (const auto &b : m.buffers()) { + print(b); + } + + // print main function first + for (const auto &f : m.functions()) { + if (f.name == m.name()) { + print(f); + } + } + + // print the rest of the functions + for (const auto &f : m.functions()) { + if (f.name != m.name()) { + print(f); + } + } + + stream << close_div(); + stream << open_div("ClosingBrace"); + stream << matched("}"); + stream << close_div(); + stream << close_div(); + scope.pop(m.name()); + } + + void start_stream(const string &filename) { + stream.open(filename.c_str()); + stream << ""; + + // bootstrap links + stream << "\n"; + stream << "\n"; + stream + << "\n"; + stream + << "\n"; + stream << "\n"; + stream << "\n"; + + // tooltip links + stream << "\n"; + stream << ""; + stream << ""; + stream << "\n"; + + // hierarchy links + stream << "\n"; + stream + << "\n"; + stream << "\n"; + + // expand button links + stream << "\n"; + stream << "\n"; + stream << "\n"; + stream << "\n"; + + // assembly code links + stream << "\n"; + stream << "\n"; + stream << "\n"; + stream << "\n"; + stream << "\n"; + stream << "\n"; + stream << "\n"; + + stream << "\n\n"; + stream << "\n"; + stream << "\n"; + } + + void end_stream() { + stream << "
    \n"; + stream << popups; + stream << "
    \n"; + + stream << "\n"; + stream << ""; + } + + string information_popup() { + + ostringstream popup; + + popup_count++; + popup << "\n"; + popup << "\n"; + + return popup.str(); + } + + string information_bar(const Module &m) { + popups += information_popup(); + + ostringstream info_bar_ss; + + info_bar_ss << "
    \n" + << "
    \n" + << "

    " << m.name() << "

    \n" + << "
    \n" + << "
    \n" + << "
    \n" + << "

    \n" + << "

    \n" + << "
    \n" + << "
    \n"; + + return info_bar_ss.str(); + } + + string resize_bar() { + ostringstream resize_bar_ss; + + resize_bar_ss << "
    \n" + << "
    \n" + << "
    \n" + << "" + << "
    \n" + << "
    \n" + << "" + << "
    \n" + << "
    \n" + << "
    \n"; + + return resize_bar_ss.str(); + } + + string resize_bar_assembly() { + ostringstream resize_bar_ss; + + resize_bar_ss << "
    \n" + << "
    \n" + << "
    \n" + << "" + << "
    \n" + << "
    \n" + << "" + << "
    \n" + << "
    \n" + << "
    \n"; + + return resize_bar_ss.str(); + } + + void generate_html(const string &filename, const Module &m) { + get_assembly_info_viz.generate_assembly_information(m, filename); + + // opening parts of the html + start_stream(filename); + + stream << "
    \n"; + + stream << information_bar(m); + + stream << "
    \n"; + + // print main html page + stream << "
    \n"; + print(m); + stream << "
    \n"; + + // for resizing the code and visualization divs + stream << resize_bar(); + + stream << "
    \n"; + stream << generate_ir_visualization(m); + stream << "
    \n"; + + // for resizing the visualization and assembly code divs + stream << resize_bar_assembly(); + + // assembly content + stream << "
    \n"; + stream << "
    \n"; + + stream << "
    \n"; // close mainContent div + stream << "
    \n"; // close outerDiv div + + // put assembly code in an invisible div + stream << get_assembly_info_viz.get_assembly_html(); + + // closing parts of the html + end_stream(); + } + + StmtToViz(const string &filename, const Module &m) + : id_count(0), get_stmt_hierarchy(generate_costs(m)), ir_visualization(find_stmt_cost), + if_count(0), producer_consumer_count(0), for_count(0), store_count(0), allocate_count(0), + functionCount(0), tooltip_count(0), popup_count(0), context_stack(1, 0) { + } + + string generate_tooltip_JS(int &tooltip_count) { + ostringstream tooltip_JS; + tooltip_JS << "\n// Tooltip JS\n" + << "function update(buttonElement, tooltipElement) { \n" + << " window.FloatingUIDOM.computePosition(buttonElement, tooltipElement, { \n" + << " placement: 'top', \n" + << " middleware: [ \n" + << " window.FloatingUIDOM.offset(6), \n" + << " window.FloatingUIDOM.flip(), \n" + << " window.FloatingUIDOM.shift({ padding: 5 }), \n" + << " ], \n" + << " }).then(({ x, y, placement, middlewareData }) => { \n" + << " Object.assign(tooltipElement.style, { \n" + << " left: `${x}px`, \n" + << " top: `${y}px`, \n" + << " }); \n" + << " // Accessing the data \n" + << " const staticSide = { \n" + << " top: 'bottom', \n" + << " right: 'left', \n" + << " bottom: 'top', \n" + << " left: 'right', \n" + << " }[placement.split('-')[0]]; \n" + << " }); \n" + << "} \n" + << "function showTooltip(buttonElement, tooltipElement) { \n" + << " tooltipElement.style.display = 'block'; \n" + << " tooltipElement.style.opacity = '1'; \n" + << " update(buttonElement, tooltipElement); \n" + << "} \n" + << "function hideTooltip(tooltipElement) { \n" + << " tooltipElement.style.display = ''; \n" + << " tooltipElement.style.opacity = '0'; \n" + << "} \n" + << "for (let i = 1; i <= " << tooltip_count << "; i++) { \n" + << " const button = document.getElementById('button' + i); \n" + << " const tooltip = document.getElementById('tooltip' + i); \n" + << " if (!button) { \n" + << " console.log('button' + i + ' not found'); \n" + << " } \n" + << " button.addEventListener('mouseenter', () => { \n" + << " showTooltip(button, tooltip); \n" + << " }); \n" + << " button.addEventListener('mouseleave', () => { \n" + << " hideTooltip(tooltip); \n" + << " } \n" + << " ); \n" + << " tooltip.addEventListener('focus', () => { \n" + << " showTooltip(button, tooltip); \n" + << " } \n" + << " ); \n" + << " tooltip.addEventListener('blur', () => { \n" + << " hideTooltip(tooltip); \n" + << " } \n" + << " ); \n" + << "} \n"; + + return tooltip_JS.str(); + } +}; + +const string StmtToViz::ir_code_css = "\n \ +/* Normal CSS */\n \ +body { font-family: Consolas, 'Liberation Mono', Menlo, Courier, monospace; font-size: 12px; background: #f8f8f8; margin-left:15px; } \n \ +a, a:hover, a:visited, a:active { color: inherit; text-decoration: none; } \n \ +b { font-weight: normal; }\n \ +p.WrapLine { margin: 0px; margin-left: 30px; text-indent:-30px; } \n \ +div.WrapLine { margin-left: 30px; text-indent:-30px; } \n \ +div.Indent { padding-left: 15px; }\n \ +div.ShowHide { position:absolute; left:-12px; width:12px; height:12px; } \n \ +span.Comment { color: #998; font-style: italic; }\n \ +span.Keyword { color: #333; font-weight: bold; }\n \ +span.Assign { color: #d14; font-weight: bold; }\n \ +span.Symbol { color: #990073; }\n \ +span.Type { color: #445588; font-weight: bold; }\n \ +span.StringImm { color: #d14; }\n \ +span.IntImm { color: #099; }\n \ +span.FloatImm { color: #099; }\n \ +b.Highlight { font-weight: bold; background-color: #DDD; }\n \ +span.Highlight { font-weight: bold; background-color: #FF0; }\n \ +span.OpF32 { color: hsl(106deg 100% 40%); font-weight: bold; }\n \ +span.OpF64 { color: hsl(106deg 100% 30%); font-weight: bold; }\n \ +span.OpB8 { color: hsl(208deg 100% 80%); font-weight: bold; }\n \ +span.OpB16 { color: hsl(208deg 100% 70%); font-weight: bold; }\n \ +span.OpB32 { color: hsl(208deg 100% 60%); font-weight: bold; }\n \ +span.OpB64 { color: hsl(208deg 100% 50%); font-weight: bold; }\n \ +span.OpI8 { color: hsl( 46deg 100% 45%); font-weight: bold; }\n \ +span.OpI16 { color: hsl( 46deg 100% 40%); font-weight: bold; }\n \ +span.OpI32 { color: hsl( 46deg 100% 34%); font-weight: bold; }\n \ +span.OpI64 { color: hsl( 46deg 100% 27%); font-weight: bold; }\n \ +span.OpVec2 { background-color: hsl(100deg 100% 90%); font-weight: bold; }\n \ +span.OpVec4 { background-color: hsl(100deg 100% 80%); font-weight: bold; }\n \ +span.Memory { color: #d22; font-weight: bold; }\n \ +span.Pred { background-color: #ffe8bd; font-weight: bold; }\n \ +span.Label { background-color: #bde4ff; font-weight: bold; }\n \ +code.ptx { tab-size: 26; white-space: pre; }\n \ +.tf-tree { overflow: unset; }\n \ +"; + +const string StmtToViz::code_viz_css = "\n \ +/* Additional Code Visualization CSS */\n \ +span.ButtonSpacer { width: 5px; color: transparent; display: inline-block; }\n \ +.infoButton { \n \ + background-color: rgba(113, 113, 113, 0.1); \n \ + border: 1px solid rgb(113, 113, 113); \n \ + color: rgb(113, 113, 113); \n \ + border-radius: 8px; \n \ + box-shadow: rgba(213, 217, 217, .5) 0 2px 5px 0; \n \ + box-sizing: border-box; \n \ + text-align: center; \n \ + vertical-align: middle; \n \ + margin-left: 5px; \n \ + margin-right: 5px; \n \ + font-size: 15px; \n \ +} \n \ +.infoButton:hover { \n \ + background-color: #f7fafa; \n \ +} \n \ +.colorButton { \n \ + height: 15px; \n \ + width: 5px; \n \ + margin-right: 2px; \n \ + border: 1px solid rgba(0, 0, 0, 0); \n \ + vertical-align: middle; \n \ + border-radius: 2px; \n \ +} \n \ +.colorButton:hover { \n \ + border: 1px solid grey; \n \ +} \n \ +.iconButton { \n \ + border: 0px; \n \ + background: transparent; \n \ + color: black; \n \ + font-size: 20px; \n \ + display: inline-block; \n \ + vertical-align: middle; \n \ + margin-right: 5px; \n \ + margin-left: 5px; \n \ +} \n \ +.iconButton:hover { \n \ + color: red; \n \ + background: transparent; \n \ +} \n \ +.resizeButton { \n \ + margin: 0px; \n \ +} \n \ +.assemblyIcon { \n \ + color: red; \n \ +} \n \ +.informationBarButton { \n \ + border: 0px; \n \ + background: transparent; \n \ + display: inline-block; \n \ + vertical-align: middle; \n \ + margin-right: 5px; \n \ + margin-top: 5px; \n \ +} \n \ +.informationBarButton:hover { \n \ + color: blue; \n \ +} \n \ +.assemblyIcon { \n \ + color: red; \n \ +} \n \ +"; + +const string StmtToViz::cost_colors_css = "\n \ +/* Cost Colors CSS */\n \ +span.CostColor19, div.CostColor19, .CostColor19 { background-color: rgb(130,31,27); } \n \ +span.CostColor18, div.CostColor18, .CostColor18 { background-color: rgb(145,33,30); } \n \ +span.CostColor17, div.CostColor17, .CostColor17 { background-color: rgb(160,33,32); } \n \ +span.CostColor16, div.CostColor16, .CostColor16 { background-color: rgb(176,34,34); } \n \ +span.CostColor15, div.CostColor15, .CostColor15 { background-color: rgb(185,47,32); } \n \ +span.CostColor14, div.CostColor14, .CostColor14 { background-color: rgb(193,59,30); } \n \ +span.CostColor13, div.CostColor13, .CostColor13 { background-color: rgb(202,71,27); } \n \ +span.CostColor12, div.CostColor12, .CostColor12 { background-color: rgb(210,82,22); } \n \ +span.CostColor11, div.CostColor11, .CostColor11 { background-color: rgb(218,93,16); } \n \ +span.CostColor10, div.CostColor10, .CostColor10 { background-color: rgb(226,104,6); } \n \ +span.CostColor9, div.CostColor9, .CostColor9 { background-color: rgb(229,118,9); } \n \ +span.CostColor8, div.CostColor8, .CostColor8 { background-color: rgb(230,132,15); } \n \ +span.CostColor7, div.CostColor7, .CostColor7 { background-color: rgb(231,146,20); } \n \ +span.CostColor6, div.CostColor6, .CostColor6 { background-color: rgb(232,159,25); } \n \ +span.CostColor5, div.CostColor5, .CostColor5 { background-color: rgb(233,172,30); } \n \ +span.CostColor4, div.CostColor4, .CostColor4 { background-color: rgb(233,185,35); } \n \ +span.CostColor3, div.CostColor3, .CostColor3 { background-color: rgb(233,198,40); } \n \ +span.CostColor2, div.CostColor2, .CostColor2 { background-color: rgb(232,211,45); } \n \ +span.CostColor1, div.CostColor1, .CostColor1 { background-color: rgb(231,223,50); } \n \ +span.CostColor0, div.CostColor0, .CostColor0 { background-color: rgb(236,233,89); } \n \ +span.CostColorSpacer { width: 2px; color: transparent; display: inline-block; }\n \ +span.CostComputation { width: 13px; display: inline-block; color: transparent; } \n \ +span.CostMovement { width: 13px; display: inline-block; color: transparent; } \n \ +span.smallColorIndent { position: absolute; left: 35px; } \n \ +span.bigColorIndent { position: absolute; left: 65px; } \n \ +"; + +const string StmtToViz::flexbox_div_css = "\n \ +/* Flexbox Div Styling CSS */ \n \ +div.outerDiv { \n \ + height: 100vh; \n \ + display: flex; \n \ + flex-direction: column; \n \ +} \n \ +div.informationBar { \n \ + display: flex; \n \ +} \n \ +div.mainContent { \n \ + display: flex; \n \ + flex-grow: 1; \n \ + width: 100%; \n \ + overflow: hidden; \n \ + border-top: 1px solid rgb(200,200,200) \n \ +} \n \ +div.IRCode-code { \n \ + counter-reset: line; \n \ + padding-left: 50px; \n \ + padding-top: 20px; \n \ + overflow-y: scroll; \n \ + position: relative; \n \ +} \n \ +div.IRVisualization { \n \ + overflow-y: scroll; \n \ + padding-top: 20px; \n \ + padding-left: 20px; \n \ + position: relative; \n \ +} \n \ +div.ResizeBar { \n \ + background: rgb(201, 231, 190); \n \ + cursor: col-resize; \n \ + border-left: 1px solid rgb(0, 0, 0); \n \ + border-right: 1px solid rgb(0, 0, 0); \n \ +} \n \ +div.collapseButtons { \n \ + position: relative; \n \ + top: 50%; \n \ +} \n \ +"; + +const string StmtToViz::line_numbers_css = "\n \ +/* Line Numbers CSS */\n \ +p.WrapLine,\n\ +div.WrapLine,\n\ +div.Consumer,\n\ +div.Produce,\n\ +div.For,\n\ +span.IfSpan,\n\ +div.Evaluate,\n\ +div.Allocate,\n\ +div.ClosingBrace,\n\ +div.Module,\n\ +div.Function {\n\ + counter-increment: line;\n\ +}\n\ +p.WrapLine:before,\n\ +div.WrapLine:before {\n\ + content: counter(line) '. ';\n\ + display: inline-block;\n\ + position: absolute;\n\ + left: 30px;\n\ + color: rgb(175, 175, 175);\n\ + user-select: none;\n\ + -webkit-user-select: none;\n\ +}\n\ +div.Consumer:before,\n\ +div.Produce:before,\n\ +div.For:before,\n\ +span.IfSpan:before,\n\ +div.Evaluate:before,\n\ +div.Allocate:before, \n\ +div.ClosingBrace:before,\n\ +div.Module:before, \n\ +div.Function:before {\n\ + content: counter(line) '. ';\n\ + display: inline-block;\n\ + position: absolute;\n\ + left: 0px;\n\ + color: rgb(175, 175, 175);\n\ + user-select: none;\n\ + -webkit-user-select: none;\n\ +}\n\ +"; + +const string StmtToViz::code_mirror_css = "\n \ +/* CodeMirror */ \n \ +.CodeMirror { \n \ + height: 100%; \n \ + width: 100%; \n \ +} \n \ +.styled-background { \n \ + background-color: #ff7; \n \ +} \n \ +"; + +const string StmtToViz::tooltip_css = "\n \ +/* Tooltip CSS */\n \ +.left-table { text-align: right; color: grey; vertical-align: middle; font-size: 12px; }\n \ +.right-table { text-align: left; vertical-align: middle; font-size: 12px; font-weight: bold; padding-left: 3px; }\n \ +.tooltipTable { border: 0px; margin-left: auto; margin-right: auto; } \n \ +.tooltip { \n \ + display: none; \n \ + position: absolute; \n \ + top: 0; \n \ + left: 0; \n \ + background: white; \n \ + padding: 5px; \n \ + font-size: 90%; \n \ + pointer-events: none; \n \ + border-radius: 5px; \n \ + border: 1px dashed #aaa; \n \ + z-index: 9999; \n \ + box-shadow: rgba(100, 100, 100, 0.8) 0 2px 5px 0; \n \ +} \n \ +.CostTooltip { \n \ + min-width: max-content; \n \ +} \n \ +.conditionTooltip { \n \ + width: 300px; \n \ + padding: 5px; \n \ + font-family: Consolas, 'Liberation Mono', Menlo, Courier, monospace; \n \ +} \n \ +span.tooltipHelperText { \n \ + color: red; \n \ + margin-top: 5px; \n \ +} \n \ +"; + +const string StmtToViz::ir_code_js = "\n \ +/* Expand/Collapse buttons */\n \ +function toggle(id, buttonId) { \n \ + e = document.getElementById(id); \n \ + show = document.getElementById(id + '-show'); \n \ + hide = document.getElementById(id + '-hide'); \n \ + button1 = document.getElementById('button' + buttonId); \n \ + button2 = document.getElementById('button' + (buttonId - 1)); \n \ + if (e.style.visibility != 'hidden') { \n \ + e.style.height = '0px'; \n \ + e.style.visibility = 'hidden'; \n \ + show.style.display = 'block'; \n \ + hide.style.display = 'none'; \n \ + // make inclusive \n \ + if (button1 != null && button2 != null) { \n \ + inclusiverange1 = button1.getAttribute('inclusiverange'); \n \ + newClassName = button1.className.replace(/CostColor\\d+/, 'CostColor' + inclusiverange1); \n \ + button1.className = newClassName; \n \ + inclusiverange2 = button2.getAttribute('inclusiverange'); \n \ + newClassName = button2.className.replace(/CostColor\\d+/, 'CostColor' + inclusiverange2); \n \ + button2.className = newClassName; \n \ + } \n \ + } else { \n \ + e.style = ''; \n \ + show.style.display = 'none'; \n \ + hide.style.display = 'block'; \n \ + // make exclusive \n \ + if (button1 != null && button2 != null) { \n \ + exclusiverange1 = button1.getAttribute('exclusiverange'); \n \ + newClassName = button1.className.replace(/CostColor\\d+/, 'CostColor' + exclusiverange1); \n \ + button1.className = newClassName; \n \ + exclusiverange2 = button2.getAttribute('exclusiverange'); \n \ + newClassName = button2.className.replace(/CostColor\\d+/, 'CostColor' + exclusiverange2); \n \ + button2.className = newClassName; \n \ + } \n \ + } \n \ + return false; \n \ +} \n \ +"; + +const string StmtToViz::scroll_to_function_code_to_viz_js = "\n \ +// scroll to function - code to viz \n \ +function makeVisibleViz(element) { \n \ + if (!element) return; \n \ + if (element.className == 'mainContent') return; \n \ + if (element.style.visibility == 'hidden') { \n \ + element.style = ''; \n \ + show = document.getElementById(element.id + '-show'); \n \ + hide = document.getElementById(element.id + '-hide'); \n \ + show.style.display = 'none'; \n \ + hide.style.display = 'block'; \n \ + return; \n \ + } \n \ + makeVisibleViz(element.parentNode); \n \ +} \n \ +function getOffsetTop(element) { \n \ + if (!element) return 0; \n \ + if (element.id == 'IRVisualization') return 0; \n \ + return getOffsetTop(element.offsetParent) + element.offsetTop; \n \ +} \n \ +function getOffsetLeft(element) { \n \ + if (!element) return 0; \n \ + if (element.id == 'IRVisualization') return 0; \n \ + return getOffsetLeft(element.offsetParent) + element.offsetLeft; \n \ +} \n \ +function scrollToFunctionCodeToViz(id) { \n \ + var container = document.getElementById('IRVisualization'); \n \ + var scrollToObject = document.getElementById(id); \n \ + makeVisibleViz(scrollToObject); \n \ + container.scrollTo({ \n \ + top: getOffsetTop(scrollToObject) - 20, \n \ + left: getOffsetLeft(scrollToObject) - 40, \n \ + behavior: 'smooth' \n \ + }); \n \ + scrollToObject.style.backgroundColor = 'yellow'; \n \ + scrollToObject.style.fontSize = '20px'; \n \ + // change content for 1 second \n \ + setTimeout(function () { \n \ + scrollToObject.style.backgroundColor = 'transparent'; \n \ + scrollToObject.style.fontSize = '12px'; \n \ + }, 1000); \n \ +} \n \ +"; + +const string StmtToViz::expand_code_js = "\n \ +// expand code div\n \ +var codeDiv = document.getElementById('IRCode-code'); \n \ +var resizeBar = document.getElementById('ResizeBar'); \n \ +var irVizDiv = document.getElementById('IRVisualization'); \n \ +var resizeBarAssembly = document.getElementById('ResizeBarAssembly'); \n \ +var assemblyCodeDiv = document.getElementById('assemblyCode'); \n \ + \n \ +codeDiv.style.flexGrow = '0'; \n \ +resizeBar.style.flexGrow = '0'; \n \ +irVizDiv.style.flexGrow = '0'; \n \ +resizeBarAssembly.style.flexGrow = '0'; \n \ +assemblyCodeDiv.style.flexGrow = '0'; \n \ + \n \ +codeDiv.style.flexBasis = 'calc(50% - 6px)'; \n \ +resizeBar.style.flexBasis = '6px'; \n \ +irVizDiv.style.flexBasis = 'calc(50% - 3px)'; \n \ +resizeBarAssembly.style.flexBasis = '6px'; \n \ + \n \ +resizeBar.addEventListener('mousedown', (event) => { \n \ + document.addEventListener('mousemove', resize, false); \n \ + document.addEventListener('mouseup', () => { \n \ + document.removeEventListener('mousemove', resize, false); \n \ + }, false); \n \ +}); \n \ + \n \ +resizeBarAssembly.addEventListener('mousedown', (event) => { \n \ + document.addEventListener('mousemove', resizeAssembly, false); \n \ + document.addEventListener('mouseup', () => { \n \ + document.removeEventListener('mousemove', resizeAssembly, false); \n \ + }, false); \n \ +}); \n \ +function resize(e) { \n \ + if (e.x < 25) { \n \ + collapseCode(); \n \ + return; \n \ + } \n \ + \n \ + const size = `${e.x}px`; \n \ + var rect = resizeBarAssembly.getBoundingClientRect(); \n \ + \n \ + if (e.x > rect.left) { \n \ + collapseViz(); \n \ + return; \n \ + } \n \ + \n \ + codeDiv.style.display = 'block'; \n \ + irVizDiv.style.display = 'block'; \n \ + codeDiv.style.flexBasis = size; \n \ + irVizDiv.style.flexBasis = `calc(${rect.left}px - ${size})`; \n \ +} \n \ +function resizeAssembly(e) { \n \ + if (e.x > screen.width - 25) { \n \ + collapseAssembly(); \n \ + return; \n \ + } \n \ + \n \ + var rect = resizeBar.getBoundingClientRect(); \n \ + \n \ + if (e.x < rect.right) {\n \ + collapseViz();\n \ + return;\n \ + }\n \ + \n \ + const size = `${e.x}px`; \n \ + irVizDiv.style.display = 'block'; \n \ + assemblyCodeDiv.style.display = 'block'; \n \ + irVizDiv.style.flexBasis = `calc(${size} - ${rect.right}px)`; \n \ + assemblyCodeDiv.style.flexBasis = `calc(100% - ${size})`; \n \ + \n \ +} \n \ +function collapseCode() { \n \ + irVizDiv.style.display = 'block'; \n \ + var rect = resizeBarAssembly.getBoundingClientRect(); \n \ + irVizDiv.style.flexBasis = `${rect.left}px`; \n \ + codeDiv.style.display = 'none'; \n \ +} \n \ +function collapseViz() { \n \ + codeDiv.style.display = 'block'; \n \ + var rect = resizeBarAssembly.getBoundingClientRect(); \n \ + codeDiv.style.flexBasis = `${rect.left}px`; \n \ + irVizDiv.style.display = 'none'; \n \ +} \n \ +function collapseVizAssembly() { \n \ + assemblyCodeDiv.style.display = 'block'; \n \ + var rect = resizeBar.getBoundingClientRect(); \n \ + assemblyCodeDiv.style.flexBasis = `calc(100% - ${rect.right}px)`; \n \ + irVizDiv.style.display = 'none'; \n \ +} \n \ +function collapseAssembly() { \n \ + irVizDiv.style.display = 'block'; \n \ + var rect = resizeBar.getBoundingClientRect(); \n \ + irVizDiv.style.flexBasis = `calc(100% - ${rect.right}px)`; \n \ + assemblyCodeDiv.style.display = 'none'; \n \ +} \n \ +"; + +const string StmtToViz::code_mirror_js = "\n \ +// CodeMirror \n \ +function jumpToLine(myCodeMirror, start, end) {\n \ + start -= 1;\n \ + end -= 1;\n \ + var t = myCodeMirror.charCoords({ line: start, ch: 0 }, 'local').top;\n \ + var middleHeight = myCodeMirror.getScrollerElement().offsetHeight / 2;\n \ + myCodeMirror.scrollIntoView({ line: start+40, ch: 0 });\n \ + for(var i = start; i <= end; i++) {\n \ + myCodeMirror.markText({ line: i, ch: 0 }, { line: i, ch: 200 }, { className: 'styled-background' });\n \ + }\n \ + myCodeMirror.markText({ line: start, ch: 0 }, { line: start, ch: 200 }, { className: 'styled-background' });\n \ + myCodeMirror.markText({ line: end, ch: 0 }, { line: end, ch: 200 }, { className: 'styled-background' });\n \ + myCodeMirror.focus();\n \ + myCodeMirror.setCursor({line: start, ch: 0});\n \ +}\n \ +function populateCodeMirror(lineNumStart, lineNumberEnd) { \n \ + assemblyCodeDiv.style.display = 'block'; \n \ + var codeHTML = document.getElementById('assemblyContent'); \n \ + var code = codeHTML.textContent; \n \ + code = code.trimLeft(); \n \ + document.getElementById('assemblyCode').innerHTML = ''; \n \ + var myCodeMirror = CodeMirror(document.getElementById('assemblyCode'), { value: code, lineNumbers: true, lineWrapping: true, mode: { name: 'gas', architecture: 'ARMv6' }, readOnly: true, }); \n \ + jumpToLine(myCodeMirror, lineNumStart, lineNumberEnd); \n \ + document.getElementsByClassName('CodeMirror-sizer')[0].style.minWidth = '0px'; \n \ +} \n \ +populateCodeMirror(1, 1); \n \ +collapseAssembly(); \n \ +"; + +void print_to_viz(const string &filename, const Stmt &s) { + internal_assert(false) << "\n\n" + << "Exiting early: print_to_viz cannot be called from a Stmt node - it must be " + "called from a Module node.\n" + << "\n\n\n"; +} + +void print_to_viz(const string &filename, const Module &m) { + + StmtToViz sth(filename, m); + + sth.generate_html(filename, m); + debug(1) << "Done generating HTML IR Visualization - printed to: " << filename << "\n"; +} + +} // namespace Internal +} // namespace Halide diff --git a/src/StmtToViz.h b/src/StmtToViz.h new file mode 100644 index 000000000000..d3ca3f782175 --- /dev/null +++ b/src/StmtToViz.h @@ -0,0 +1,31 @@ +#ifndef HALIDE_STMT_TO_VIZ +#define HALIDE_STMT_TO_VIZ + +/** \file + * Defines a function to dump an HTML-formatted visualization to a file. + */ + +#include + +namespace Halide { + +class Module; + +namespace Internal { + +struct Stmt; + +/** + * Dump an HTML-formatted visualization of a Stmt to filename. + */ +void print_to_viz(const std::string &filename, const Stmt &s); + +/** Dump an HTML-formatted visualization of a Module to filename. */ +void print_to_viz(const std::string &filename, const Module &m); + +extern const char *StmtToViz_canIgnoreVariableName_string; + +} // namespace Internal +} // namespace Halide + +#endif