diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0eba24f61d14..6d329f5f1079 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -633,6 +633,10 @@ if(USE_CUDA)
   else()
     list(APPEND CUDA_INCLUDE_DIRS ${INCLUDE_DIRECTORIES})
     # define preprocessor macro so that we will not include the generated forcelink header
+    if(ENABLE_CUDA_RTC)
+      add_definitions(-DMXNET_ENABLE_CUDA_RTC=1)
+    endif()
+    # Create '.cmake' files for cuda compiles given definitions added thus far
     mshadow_cuda_compile(cuda_objs ${CUDA})
     if(MSVC)
       if(ENABLE_CUDA_RTC)
@@ -640,7 +644,6 @@ if(USE_CUDA)
         list(APPEND mxnet_LINKER_LIBS ${CUDA_nvrtc_LIBRARY})
         set(CUDA_cuda_LIBRARY "${CUDA_nvrtc_LIBRARY}/../cuda.lib")
         list(APPEND mxnet_LINKER_LIBS ${CUDA_cuda_LIBRARY})
-        add_definitions(-DMXNET_ENABLE_CUDA_RTC=1)
       endif()
       FIND_LIBRARY(CUDA_cufft_LIBRARY nvrtc "${CUDA_TOOLKIT_ROOT_DIR}/lib/x64" "${CUDA_TOOLKIT_ROOT_DIR}/lib/win32")
       list(APPEND mxnet_LINKER_LIBS "${CUDA_cufft_LIBRARY}/../cufft.lib") # For fft operator
@@ -652,7 +655,6 @@ if(USE_CUDA)
     list(APPEND mxnet_LINKER_LIBS cufft cusolver)
     if(ENABLE_CUDA_RTC)
       list(APPEND mxnet_LINKER_LIBS nvrtc cuda)
-      add_definitions(-DMXNET_ENABLE_CUDA_RTC=1)
     endif()
     link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64")
   endif()
diff --git a/appveyor.yml b/appveyor.yml
index d44f52a0a9a9..9fa495002a1f 100644
--- a/appveyor.yml
+++ b/appveyor.yml
@@ -69,7 +69,7 @@ before_build:
 
       set OpenCV_DIR=%APPVEYOR_BUILD_FOLDER%/%MXNET_OPENCV_DIR%/build
 
-      cmake .. -DOPENCV_DIR=%OpenCV_DIR% -DUSE_PROFILER=1 -DUSE_CUDA=0 -DUSE_CUDNN=0 -DUSE_NVRTC=0 -DUSE_OPENCV=1 -DUSE_OPENMP=1 -DUSE_BLAS=open -DUSE_LAPACK=1 -DUSE_DIST_KVSTORE=0 -G "Visual Studio 12 2013 Win64"
+      cmake .. -DOPENCV_DIR=%OpenCV_DIR% -DUSE_PROFILER=1 -DUSE_CUDA=0 -DUSE_CUDNN=0 -DENABLE_CUDA_RTC=0 -DUSE_OPENCV=1 -DUSE_OPENMP=1 -DUSE_BLAS=open -DUSE_LAPACK=1 -DUSE_DIST_KVSTORE=0 -G "Visual Studio 12 2013 Win64"
 
 build_script:
   - cmd: >-
diff --git a/ci/build_windows.py b/ci/build_windows.py
index 4673bd535e3e..ce77c316ab20 100755
--- a/ci/build_windows.py
+++ b/ci/build_windows.py
@@ -54,7 +54,7 @@ class BuildFlavour(Enum):
     'WIN_CPU': (
         '-DUSE_CUDA=OFF '
         '-DUSE_CUDNN=OFF '
-        '-DUSE_NVRTC=OFF '
+        '-DENABLE_CUDA_RTC=OFF '
         '-DUSE_OPENCV=ON '
         '-DUSE_OPENMP=ON '
         '-DUSE_BLAS=open '
@@ -67,7 +67,7 @@ class BuildFlavour(Enum):
     , 'WIN_CPU_MKLDNN': (
         '-DUSE_CUDA=OFF '
         '-DUSE_CUDNN=OFF '
-        '-DUSE_NVRTC=OFF '
+        '-DENABLE_CUDA_RTC=OFF '
         '-DUSE_OPENCV=ON '
         '-DUSE_OPENMP=ON '
         '-DUSE_BLAS=open '
@@ -80,7 +80,7 @@ class BuildFlavour(Enum):
     , 'WIN_CPU_MKLDNN_MKL': (
         '-DUSE_CUDA=OFF '
         '-DUSE_CUDNN=OFF '
-        '-DUSE_NVRTC=OFF '
+        '-DENABLE_CUDA_RTC=OFF '
         '-DUSE_OPENCV=ON '
         '-DUSE_OPENMP=ON '
         '-DUSE_BLAS=mkl '
@@ -93,7 +93,7 @@ class BuildFlavour(Enum):
     , 'WIN_CPU_MKL': (
         '-DUSE_CUDA=OFF '
         '-DUSE_CUDNN=OFF '
-        '-DUSE_NVRTC=OFF '
+        '-DENABLE_CUDA_RTC=OFF '
         '-DUSE_OPENCV=ON '
         '-DUSE_OPENMP=ON '
         '-DUSE_BLAS=mkl '
@@ -106,7 +106,7 @@ class BuildFlavour(Enum):
     , 'WIN_GPU': (
         '-DUSE_CUDA=ON '
         '-DUSE_CUDNN=ON '
-        '-DUSE_NVRTC=ON '
+        '-DENABLE_CUDA_RTC=ON '
         '-DUSE_OPENCV=ON '
         '-DUSE_OPENMP=ON '
         '-DUSE_BLAS=open '
@@ -122,7 +122,7 @@ class BuildFlavour(Enum):
     , 'WIN_GPU_MKLDNN': (
         '-DUSE_CUDA=ON '
         '-DUSE_CUDNN=ON '
-        '-DUSE_NVRTC=ON '
+        '-DENABLE_CUDA_RTC=ON '
         '-DUSE_OPENCV=ON '
         '-DUSE_OPENMP=ON '
         '-DUSE_BLAS=open '
diff --git a/make/maven/maven_darwin_mkl.mk b/make/maven/maven_darwin_mkl.mk
index a7f2bdb027d4..9bf3fc46ce0b 100644
--- a/make/maven/maven_darwin_mkl.mk
+++ b/make/maven/maven_darwin_mkl.mk
@@ -77,7 +77,7 @@ USE_CUDNN = 0
 # CUDA_ARCH :=
 
 # whether use cuda runtime compiling for writing kernels in native language (i.e. Python)
-USE_NVRTC = 0
+ENABLE_CUDA_RTC = 0
 
 # use openmp for parallelization
 USE_OPENMP = 0
diff --git a/make/maven/maven_linux_cu90mkl.mk b/make/maven/maven_linux_cu90mkl.mk
index e9ba46509973..e8caf73f186e 100644
--- a/make/maven/maven_linux_cu90mkl.mk
+++ b/make/maven/maven_linux_cu90mkl.mk
@@ -80,7 +80,7 @@ USE_NCCL = 1
 
 # whether use cuda runtime compiling for writing kernels in native language (i.e. Python)
 USE_NVTX=1
-USE_NVRTC = 1
+ENABLE_CUDA_RTC = 1
 
 # use openmp for parallelization
 USE_OPENMP = 1
diff --git a/make/maven/maven_linux_cu92mkl.mk b/make/maven/maven_linux_cu92mkl.mk
index caa1c59c01d5..930341e71cb1 100644
--- a/make/maven/maven_linux_cu92mkl.mk
+++ b/make/maven/maven_linux_cu92mkl.mk
@@ -80,7 +80,7 @@ USE_NCCL = 1
 
 # whether use cuda runtime compiling for writing kernels in native language (i.e. Python)
 USE_NVTX=1
-USE_NVRTC = 1
+ENABLE_CUDA_RTC = 1
 
 # use openmp for parallelization
 USE_OPENMP = 1
diff --git a/make/maven/maven_linux_mkl.mk b/make/maven/maven_linux_mkl.mk
index 3c8534a7e2aa..10aee5f35a46 100644
--- a/make/maven/maven_linux_mkl.mk
+++ b/make/maven/maven_linux_mkl.mk
@@ -76,7 +76,7 @@ USE_CUDNN = 0
 # CUDA_ARCH :=
 
 # whether use cuda runtime compiling for writing kernels in native language (i.e. Python)
-USE_NVRTC = 0
+ENABLE_CUDA_RTC = 0
 
 # use openmp for parallelization
 USE_OPENMP = 1
diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index a5f125affcb0..55d431cf3298 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -221,6 +221,11 @@ Graph FusePointwiseForward(Graph&& g);
  */
 Graph FusePointwiseBackward(Graph&& g);
 
+/*!
+ * \brief Issue a one-time warning that fusion is not possible for this platform or build.
+ */
+void WarnFusionNotSupported();
+
 /*!
  * \brief Infer shapes in the graph given the information.
  * \param graph The input graph.
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 7fa1de373d07..508fbba97be3 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -50,7 +50,7 @@ static const std::string GetDefaultSubgraphBackend() {
 #endif
 }
 
-GraphExecutor::GraphExecutor() {
+GraphExecutor::GraphExecutor(const nnvm::Symbol& symbol) {
   log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false);
   need_grad_ = false;
   is_dynamic_ = false;
@@ -60,6 +60,7 @@ GraphExecutor::GraphExecutor() {
     LOG(INFO) << "MXNET_SUBGRAPH_BACKEND=NONE is detected, subgraph backend is not in use";
   }
   engine_ref_ = Engine::_GetSharedRef();
+  symbol_ = symbol.Copy();
 }
 
 GraphExecutor::~GraphExecutor() {
@@ -890,10 +891,9 @@ Executor* GraphExecutor::Reshape(const bool partial_shaping,
                                  std::vector<NDArray>* arg_grads,
                                  std::vector<NDArray>* aux_states) {
   nnvm::Graph g;
-  g.outputs = std::vector<nnvm::NodeEntry>(graph_.outputs.begin(),
-    graph_.outputs.begin() + num_forward_outputs_);
   nnvm::Symbol symbol;
-  symbol.outputs = g.outputs;
+  symbol.outputs = symbol_.outputs;
+  g.outputs = symbol_.outputs;
   const nnvm::IndexedGraph& idx = g.indexed_graph();
   mxnet::ShapeVector arg_shapes(idx.input_nodes().size(), mxnet::TShape());
   for (size_t i = 0; i < num_forward_inputs_; ++i) {
@@ -977,8 +977,8 @@ Executor* GraphExecutor::Reshape(const bool partial_shaping,
       }
     }
   }
-  auto exec = new GraphExecutor();
-  exec->Init(symbol, default_ctx, ctx_map,
+  auto exec = new GraphExecutor(symbol);
+  exec->Init(symbol.Copy(), default_ctx, ctx_map,
              *in_args, *arg_grads, grad_req_types, *aux_states,
              this);
   return exec;
@@ -1001,7 +1001,7 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
   // setup gradient
   nnvm::Graph g = InitFullGraph(symbol, grad_req_types);
 
-#if MXNET_USE_CUDA && !defined(_WIN32)
+#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC && !defined(_WIN32)
   if (default_ctx.dev_mask() == Context::kGPU && dmlc::GetEnv("MXNET_USE_FUSION", true)) {
     nnvm::Graph unoptimized_graph;
     common::CopyGraph(&unoptimized_graph, g, false);
@@ -1034,7 +1034,12 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
           << "Graph contains duplicate names for some of its inputs - fusion is NOT enabled!";
       }
   }
-#endif  // MXNET_USE_CUDA
+#else
+  // Only warn user if MXNET_USE_FUSION env var is explicitly set
+  if (default_ctx.dev_mask() == Context::kGPU && dmlc::GetEnv("MXNET_USE_FUSION", false)) {
+    WarnFusionNotSupported();
+  }
+#endif  // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC && !defined(_WIN32)
 
   // create "device" and "context" attrs for the graph
   g = AssignContext(g, default_ctx, ctx_map,
@@ -1969,7 +1974,7 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
                                std::vector<NDArray>* aux_states,
                                std::unordered_map<std::string, NDArray>* shared_buffer,
                                Executor* shared_exec) {
-  auto exec = new exec::GraphExecutor();
+  auto exec = new exec::GraphExecutor(symbol);
   bool init = false;
   if (!exec->subgraph_property().empty()) {
     static int verbose = dmlc::GetEnv("MXNET_SUBGRAPH_VERBOSE", 1);
@@ -1989,6 +1994,8 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
         symbol = exec::BuildSubgraph(symbol, backend, arg_shape_map, arg_dtype_map, arg_stype_map,
                                      default_ctx, group2ctx, &tmp_in_arg_ctxes, &tmp_arg_grad_ctxes,
                                      &tmp_grad_req_types, &tmp_aux_state_ctxes, verbose);
+        // Subgraph cannot be recreated from unoptimized symbol
+        exec = new exec::GraphExecutor(symbol);
         exec->Init(symbol.Copy(), default_ctx, group2ctx, tmp_in_arg_ctxes, tmp_arg_grad_ctxes,
                    tmp_aux_state_ctxes, arg_shape_map, arg_dtype_map, arg_stype_map,
                    tmp_grad_req_types, shared_arg_names, &tmp_in_args, &tmp_arg_grads,
@@ -2043,7 +2050,7 @@ Executor *Executor::Bind(nnvm::Symbol symbol,
                         const std::vector<OpReqType> &grad_req_type,
                         const std::vector<NDArray> &aux_states,
                         Executor* shared_exec) {
-  auto exec = new exec::GraphExecutor();
+  auto exec = new exec::GraphExecutor(symbol);
   static int verbose = dmlc::GetEnv("MXNET_SUBGRAPH_VERBOSE", 1);
   std::vector<NDArray> tmp_in_args = in_args;
   std::vector<NDArray> tmp_arg_grad_store = arg_grad_store;
@@ -2058,6 +2065,8 @@ Executor *Executor::Bind(nnvm::Symbol symbol,
       symbol = exec::BuildSubgraph(symbol, backend, default_ctx, group2ctx, &tmp_in_args,
                                    &tmp_arg_grad_store, &tmp_grad_req_type, &tmp_aux_states,
                                    verbose);
+      // Subgraph cannot be recreated from unoptimized symbol
+      exec = new exec::GraphExecutor(symbol);
     }
   }
   exec->Init(symbol.Copy(), default_ctx, group2ctx, tmp_in_args, tmp_arg_grad_store,
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index f150165796ad..bfa6980a8e29 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -58,7 +58,7 @@ class GraphExecutor : public Executor {
  public:
   using Executor::MonitorCallback;
 
-  GraphExecutor();
+  explicit GraphExecutor(const nnvm::Symbol& symbol);
   virtual ~GraphExecutor();
   void Forward(bool is_train) override;
   void PartialForward(bool is_train, int step, int *step_left) override;
@@ -267,6 +267,9 @@ class GraphExecutor : public Executor {
   std::string subgraph_property_;
   // ref of engine
   std::shared_ptr<Engine> engine_ref_;
+  // Unoptimized copy of the symbol for sharing with
+  // child executors
+  nnvm::Symbol symbol_;
 };
 
 }  // namespace exec
diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc
index 80e4084c478e..4b6ee2e1dc0d 100644
--- a/src/executor/infer_graph_attr_pass.cc
+++ b/src/executor/infer_graph_attr_pass.cc
@@ -67,6 +67,7 @@ template <typename AttrType, typename IsNone>
 inline void GetAttrFromForwardNode(const uint32_t nid,
                                    const nnvm::IndexedGraph &idx,
                                    std::vector<AttrType>* rshape_ptr,
+                                   std::vector<bool>* inference_finished,
                                    IsNone fis_none) {
   std::vector<AttrType>& rshape = *rshape_ptr;
   const nnvm::IndexedGraph::Node& inode = idx[nid];
@@ -83,18 +84,23 @@ inline void GetAttrFromForwardNode(const uint32_t nid,
   // input gradient list
   const std::vector<nnvm::NodeEntry>& igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
   const nnvm::Node* igrad_node = nullptr;
+  bool all_attrs_known = true;
   // Input gradient assignement
   for (size_t i = 0; i < igrad.size(); ++i) {
     if (igrad[i].node->op() == inode.source->op()) {
       uint32_t eid = idx.entry_id(nid, igrad[i].index);
-      if (fis_none(rshape[eid])) {
-        rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
-      } else if (!fis_none(rshape[idx.entry_id(fnode.inputs[i])])) {
+      if (fis_none(rshape[idx.entry_id(fnode.inputs[i])])) {
         // Need to skip empty forward shape, because it may not be
         // available now and it is possible to infer the forward
         // shape in one of the next a few passes
-        CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
-          << "Backward shape inconsistent with the forward shape";
+        all_attrs_known = false;
+      } else {
+        if (fis_none(rshape[eid])) {
+          rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
+        } else {
+          CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
+            << "Backward shape inconsistent with the forward shape";
+        }
       }
       if (igrad_node == nullptr) {
         igrad_node = igrad[i].node.get();
@@ -113,14 +119,20 @@ inline void GetAttrFromForwardNode(const uint32_t nid,
       if (fis_none(rshape[eid])) {
        rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)];
       }
+      if (fis_none(rshape[eid])) {
+        // If the attr is still unknown
+        all_attrs_known = false;
+      }
     }
   }
+  (*inference_finished)[nid] = all_attrs_known;
 }
 
 template <typename FAccessSubgraphType, typename AttrType, typename IsNone>
 void GetAttrFromFusedNode(uint32_t nid,
                           const nnvm::IndexedGraph& idx,
                           std::vector<AttrType>* rshape_ptr,
+                          std::vector<bool>* inference_finished,
                           IsNone fis_none,
                           const std::string& infer_fusion_name) {
   std::vector<AttrType>& rshape = *rshape_ptr;
@@ -147,19 +159,24 @@ void GetAttrFromFusedNode(uint32_t nid,
   // input gradient list
   const std::vector<nnvm::NodeEntry>& igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
   const nnvm::Node* igrad_node = nullptr;
+  bool all_attrs_known = true;
   // Set the attributes of output gradients
   // using attributes of forward node inputs
   for (size_t i = 0; i < igrad.size(); ++i) {
     if (igrad[i].node->op() == inode.source->op()) {
       uint32_t eid = idx.entry_id(nid, igrad[i].index);
-      if (fis_none(rshape[eid])) {
-        rshape[eid] = input_attrs[i];
-      } else if (!fis_none(input_attrs[i])) {
+      if (fis_none(input_attrs[i])) {
         // Need to skip empty forward shape, because it may not be
         // available now and it is possible to infer the forward
         // shape in one of the next a few passes
-        CHECK_EQ(rshape[eid], input_attrs[i])
-          << "Backward shape inconsistent with the forward shape";
+        all_attrs_known = false;
+      } else {
+        if (fis_none(rshape[eid])) {
+          rshape[eid] = input_attrs[i];
+        } else {
+          CHECK_EQ(rshape[eid], input_attrs[i])
+            << "Backward shape inconsistent with the forward shape";
+        }
       }
       if (igrad_node == nullptr) {
         igrad_node = igrad[i].node.get();
@@ -180,8 +197,13 @@ void GetAttrFromFusedNode(uint32_t nid,
       if (fis_none(rshape[eid])) {
         rshape[eid] = output_attrs[e.index];
       }
+      if (fis_none(rshape[eid])) {
+        // If the attr is still unknown
+        all_attrs_known = false;
+      }
     }
   }
+  (*inference_finished)[nid] = all_attrs_known;
 }
 
 template <typename FProvideSubgraphType, typename AttrType>
@@ -270,6 +292,9 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
       Op::GetAttr<nnvm::TIsBackward>("TIsBackward");
   // reshape shape vector
   AttrVector rshape;
+  // vector holding information which operators
+  // finished attribute inference
+  std::vector<bool> inference_finished(idx.num_nodes(), false);
   // dispatch mode vector
   DispatchModeVector dispatch_modes;
   if (ret.attrs.count(attr_name) != 0) {
@@ -340,6 +365,7 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
 
   // inference step function for nid
   auto infer_step = [&](uint32_t nid, bool last_iter) {
+    if (inference_finished[nid]) return;
    const auto& inode = idx[nid];
     const uint32_t num_inputs = inode.inputs.size();
     const uint32_t num_outputs = inode.source->num_outputs();
@@ -355,6 +381,9 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
           CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
         }
       }
+      if (!fis_none(rshape[out_ent_id])) {
+        inference_finished[nid] = true;
+      }
       // assign a default value to node attribute
       if (dispatch_mode_name != nullptr) {
         op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val);
@@ -370,47 +399,66 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
       static auto& is_fusion_helper = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
       if (!is_fusion_helper.get(fwd_ptr->op(), false)) {
-        GetAttrFromForwardNode(nid, idx, &rshape, fis_none);
+        GetAttrFromForwardNode(nid, idx, &rshape, &inference_finished, fis_none);
       } else {
-        GetAttrFromFusedNode<FAccessSubgraphType>(nid, idx, &rshape, fis_none, infer_fusion_name);
+        GetAttrFromFusedNode<FAccessSubgraphType>(nid, idx, &rshape, &inference_finished,
+                                                  fis_none, infer_fusion_name);
       }
     } else {
       DispatchMode* dispatch_mode = nullptr;
-      bool forward_known = true;
       // Forward operator inference.
       ishape.resize(num_inputs, empty_val);
       for (uint32_t i = 0; i < ishape.size(); ++i) {
         ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
-        if (fis_none(ishape[i])) forward_known = false;
       }
       oshape.resize(num_outputs, empty_val);
       for (uint32_t i = 0; i < oshape.size(); ++i) {
         oshape[i] = rshape[idx.entry_id(nid, i)];
-        if (fis_none(oshape[i])) forward_known = false;
       }
       if (dispatch_mode_name != nullptr) {
         dispatch_mode = &dispatch_modes[nid];
-        if (dispatch_modes[nid] == DispatchMode::kUndefined) forward_known = false;
       }
       auto finfer = finfer_shape.get(inode.source->op(), fdefault);
-      if (!forward_known) {
-        if (finfer != nullptr) {
-          // Call inference function of the operator.
-          try {
-            static auto& is_fusion = Op::GetAttr<exec::TIsFusion>("TIsFusion");
-            if (is_fusion.get(inode.source->op(), false)) {
-              ProvideAttrToFusion<FProvideSubgraphType>(nid, idx, rshape, provide_fusion_name);
-            }
-            forward_known = ApplyOpInferAttr(ret, finfer, inode.source->attrs,
-                                             nid, &ishape, &oshape, dispatch_mode);
-          } catch (const std::exception& e) {
-            throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
+      if (finfer != nullptr) {
+        // Call inference function of the operator.
+        try {
+          static auto& is_fusion = Op::GetAttr<exec::TIsFusion>("TIsFusion");
+          if (is_fusion.get(inode.source->op(), false)) {
+            ProvideAttrToFusion<FProvideSubgraphType>(nid, idx, rshape, provide_fusion_name);
           }
-        } else {
+          ApplyOpInferAttr(ret, finfer, inode.source->attrs,
+                           nid, &ishape, &oshape, dispatch_mode);
+          bool finished = true;
+          for (const auto& attr : ishape) {
+            if (fis_none(attr)) finished = false;
+          }
+          for (const auto& attr : oshape) {
+            if (fis_none(attr)) finished = false;
+          }
+          inference_finished[nid] = finished;
+        } catch (const std::exception& e) {
+          throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
+        }
+      } else {
+        // Operator does not provide attribute inference function,
+        // so we need to test if everything was inferred by other operators
+        bool all_attrs_known = true;
+        for (const auto& attr : ishape) {
+          if (fis_none(attr)) {
+            all_attrs_known = false;
+          }
+        }
+        for (const auto& attr : oshape) {
+          if (fis_none(attr)) {
+            all_attrs_known = false;
+          }
+        }
+        inference_finished[nid] = all_attrs_known;
+        if (!all_attrs_known) {
           CHECK(!last_iter)
               << "Attribute " << infer_name
-              << " is not registed by op " << inode.source->op()->name
-              << " we are not able to complete the inference because of this";
+              << " is not registered by op " << inode.source->op()->name
+              << ". We are not able to complete the inference because of this";
         }
       }
       // Save to the result map.
@@ -427,16 +475,18 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
     size_t num_unknown_dispatch_mode = dispatch_mode_name ? node_end - node_start : 0;
     size_t num_unknown_entry_attr = entry_end - entry_start;
     size_t num_unknown = num_unknown_entry_attr + num_unknown_dispatch_mode;
+    bool last_iter = false;
+    bool do_next_iteration = true;
     int i = 0;
     do {
       if (i % 2 == 0) {
        for (uint32_t nid = node_start; nid < node_end; ++nid) {
-          infer_step(nid, false);
+          infer_step(nid, last_iter);
        }
      } else {
        // backward inference
        for (uint32_t i = node_end; i != node_start; --i) {
-          infer_step(i - 1, false);
+          infer_step(i - 1, last_iter);
        }
      }
      last_num_unknown = num_unknown;
@@ -451,8 +501,18 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
          if (dispatch_modes[i] == DispatchMode::kUndefined) ++num_unknown;
        }
      }
+      do_next_iteration = num_unknown > 0 && last_num_unknown > num_unknown;
+      if (!do_next_iteration && !last_iter) {
+        // Check if every op agrees that it should be
+        // the end of attribute inference. If not,
+        // perform one final step
+        for (const bool done : inference_finished) {
+          do_next_iteration = do_next_iteration || !done;
+        }
+        last_iter = true;
+      }
      ++i;
-    } while (num_unknown > 0 && last_num_unknown > num_unknown);
+    } while (do_next_iteration);
   // set the shapes
   ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
   // set the shapes
@@ -517,6 +577,9 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
       Op::GetAttr<nnvm::TIsBackward>("TIsBackward");
   // reshape shape vector
   AttrVector rshape;
+  // vector holding information which operators
+  // finished attribute inference
+  std::vector<bool> inference_finished(idx.num_nodes(), false);
   // dispatch mode vector
   DispatchModeVector dispatch_modes;
   if (ret.attrs.count(attr_name) != 0) {
@@ -594,6 +657,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
 
   // inference step function for nid
   auto infer_step = [&](uint32_t nid, bool last_iter) {
+    if (inference_finished[nid]) return;
    const auto& inode = idx[nid];
     const std::string name = inode.source->attrs.name;
     const uint32_t num_inputs = inode.inputs.size();
@@ -613,6 +677,9 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
           }
         }
       }
+      if (!fis_none(rshape[out_ent_id])) {
+        inference_finished[nid] = true;
+      }
       // assign a default value to node attribute
       if (dispatch_mode_name != nullptr) {
         op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val);
@@ -628,14 +695,15 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
       static auto& is_fusion_helper = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
       if (!is_fusion_helper.get(fwd_ptr->op(), false)) {
-        GetAttrFromForwardNode(nid, idx, &rshape, fis_none);
+        GetAttrFromForwardNode(nid, idx, &rshape, &inference_finished, fis_none);
       } else {
-        GetAttrFromFusedNode<exec::FAccessSubgraphShape>(nid, idx, &rshape, fis_none,
+        GetAttrFromFusedNode<exec::FAccessSubgraphShape>(nid, idx, &rshape,
+                                                         &inference_finished,
+                                                         fis_none,
                                                          "FAccessSubgraphShape");
       }
     } else {
       DispatchMode* dispatch_mode = nullptr;
-      bool forward_known = true;
       // Forward operator inference.
       ishape.resize(num_inputs, empty_val);
       bool is_input_dynamic_shape = false;
@@ -644,16 +712,13 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
         if (!mxnet::ndim_is_known(ishape[i]) && is_dynamic[idx.entry_id(inode.inputs[i])]) {
           is_input_dynamic_shape = true;
         }
-        if (fis_none(ishape[i])) forward_known = false;
       }
       oshape.resize(num_outputs, empty_val);
       for (uint32_t i = 0; i < oshape.size(); ++i) {
         oshape[i] = rshape[idx.entry_id(nid, i)];
-        if (fis_none(oshape[i])) forward_known = false;
       }
       if (dispatch_mode_name != nullptr) {
         dispatch_mode = &dispatch_modes[nid];
-        if (dispatch_modes[nid] == DispatchMode::kUndefined) forward_known = false;
       }
       auto finfer = finfer_shape.get(inode.source->op(), fdefault);
       if (finfer == nullptr || is_input_dynamic_shape) {
@@ -662,25 +727,27 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
           is_dynamic[idx.entry_id(nid, i)] = 1;
         }
       }
-      } else if (!forward_known) {
-        if (finfer != nullptr) {
-          // Call inference function of the operator.
-          try {
-            static auto& is_fusion = Op::GetAttr<exec::TIsFusion>("TIsFusion");
-            if (is_fusion.get(inode.source->op(), false)) {
-              ProvideAttrToFusion<exec::FProvideSubgraphShape>(nid, idx, rshape,
-                                                               "FProvideSubgraphShape");
-            }
-            forward_known = ApplyOpInferAttr(ret, finfer, inode.source->attrs,
-                                             nid, &ishape, &oshape, dispatch_mode);
-          } catch (const std::exception& e) {
-            throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
+        inference_finished[nid] = true;
+      } else {
+        // Call inference function of the operator.
+        try {
+          static auto& is_fusion = Op::GetAttr<exec::TIsFusion>("TIsFusion");
+          if (is_fusion.get(inode.source->op(), false)) {
+            ProvideAttrToFusion<exec::FProvideSubgraphShape>(nid, idx, rshape,
+                                                             "FProvideSubgraphShape");
          }
-      } else {
-        CHECK(!last_iter)
-            << "Attribute " << infer_name
-            << " is not registed by op " << inode.source->op()->name
-            << " we are not able to complete the inference because of this";
+          ApplyOpInferAttr(ret, finfer, inode.source->attrs,
+                           nid, &ishape, &oshape, dispatch_mode);
+          bool finished = true;
+          for (const auto& attr : ishape) {
+            if (fis_none(attr)) finished = false;
+          }
+          for (const auto& attr : oshape) {
+            if (fis_none(attr)) finished = false;
+          }
+          inference_finished[nid] = finished;
+        } catch (const std::exception& e) {
+          throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
        }
      }
      // Save to the result map.
@@ -695,18 +762,20 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
   size_t last_num_unknown;
   size_t num_unknown = static_cast<size_t>(-1);  // Infinity
+  bool last_iter = false;
+  bool do_next_iteration = true;
   int i = 0;
   do {
     if (i % 2 == 0) {
       // forward inference
       for (uint32_t nid = node_start; nid < node_end; ++nid) {
-        infer_step(nid, false);
+        infer_step(nid, last_iter);
       }
     } else {
       // backward inference
       for (uint32_t i = node_end; i != node_start; --i) {
-        infer_step(i - 1, false);
+        infer_step(i - 1, last_iter);
       }
     }
     last_num_unknown = num_unknown;
@@ -723,8 +792,18 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
        }
      }
    }
+    do_next_iteration = num_unknown > 0 && last_num_unknown > num_unknown;
+    if (!do_next_iteration && !last_iter) {
+      // Check if every op agrees that it should be
+      // the end of attribute inference. If not,
+      // perform one final step
+      for (const bool done : inference_finished) {
+        do_next_iteration = do_next_iteration || !done;
+      }
+      last_iter = true;
+    }
    ++i;
-  } while (num_unknown > 0 && last_num_unknown > num_unknown);
+  } while (do_next_iteration);
   // set the shapes
   ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
   // set the shapes
diff --git a/src/executor/pointwise_fusion_pass.cc b/src/executor/pointwise_fusion_pass.cc
index 6fe21402cb3a..6a0d5f4efe87 100644
--- a/src/executor/pointwise_fusion_pass.cc
+++ b/src/executor/pointwise_fusion_pass.cc
@@ -36,10 +36,26 @@
 #include "../operator/fusion/fused_op.h"
 #include "../operator/operator_common.h"
 
-#if MXNET_USE_CUDA
-
 namespace mxnet {
 namespace exec {
+
+void WarnFusionNotSupported() {
+  static bool issued_warning = false;
+  if (!issued_warning) {
+    issued_warning = true;
+#if defined(_WIN32)
+    LOG(WARNING) << "Omitting dynamic fused op creation - not enabled on Windows. "
+                 << "Unset env var MXNET_USE_FUSION=1 to quiet this message.";
+#else
+    LOG(WARNING) << "Omitting dynamic fused op creation - needs MXNet lib built with "
+                 << "USE_CUDA=1 and ENABLE_CUDA_RTC=1. Unset env var MXNET_USE_FUSION=1 "
+                 << "to quiet this message.";
+#endif  // defined(_WIN32)
+  }
+}
+
+#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
+
 namespace {
   bool IsFusionCompatible(nnvm::Node* n) {
     using namespace mxnet::fusion;
@@ -304,8 +320,8 @@ Graph FusePointwiseBackward(Graph &&g) {
   ret.outputs = g.outputs;
   return ret;
 }
+#endif  // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
 
 }  // namespace exec
 }  // namespace mxnet
-
-#endif  // MXNET_USE_CUDA
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 269729c18f58..24270f210888 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -167,10 +167,8 @@ void SetRefCounts(nnvm::Graph* fwd_graph, const nnvm::Graph& full_graph) {
 void OptimizeGraph(nnvm::Graph * full_graph, nnvm::Graph * fwd_graph, nnvm::Graph * grad_graph,
                    const Context& context, size_t num_forward_outputs, const bool inlining) {
-#if MXNET_USE_CUDA && !defined(_WIN32)
-  if (context.dev_mask() == kGPU &&
-      !inlining &&
-      dmlc::GetEnv("MXNET_USE_FUSION", true)) {
+#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC && !defined(_WIN32)
+  if (context.dev_mask() == kGPU && !inlining && dmlc::GetEnv("MXNET_USE_FUSION", true)) {
     nnvm::Graph unoptimized_graph;
     common::CopyGraph(&unoptimized_graph, *full_graph, false);
 
@@ -202,7 +200,12 @@ void OptimizeGraph(nnvm::Graph * full_graph, nnvm::Graph * fwd_graph, nnvm::Graph * grad_graph,
         << "Graph contains duplicate names for some of its inputs - fusion is NOT enabled!";
     }
   }
-#endif  // MXNET_USE_CUDA
+#else
+  // Only warn user if MXNET_USE_FUSION env var is explicitly set
+  if (context.dev_mask() == kGPU && !inlining && dmlc::GetEnv("MXNET_USE_FUSION", false)) {
+    exec::WarnFusionNotSupported();
+  }
+#endif  // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC && !defined(_WIN32)
 
   *fwd_graph = nnvm::Graph();
   fwd_graph->outputs = std::vector<nnvm::NodeEntry>(full_graph->outputs.begin(),
diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc
index fd087ef39679..a9e9038e6c51 100644
--- a/src/operator/control_flow.cc
+++ b/src/operator/control_flow.cc
@@ -761,7 +761,7 @@ static bool WhileLoopType(const nnvm::NodeAttrs& attrs,
   std::vector<int> func_in_type;
   extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type);
   extract_by_loc(*in_type, params.func_input_locs, &func_in_type);
-  std::vector<int> cond_out_type = {0};
+  std::vector<int> cond_out_type = {-1};
   CHECK(params.sync_in_out(in_type, out_type, is_udf));
   bool succ_0 = InferSubgraphDataType(*attrs.subgraphs[0], &cond_in_type, &cond_out_type);
   CHECK(params.sync_in_out(in_type, out_type, is_udf));
diff --git a/src/operator/fusion/fused_op-inl.h b/src/operator/fusion/fused_op-inl.h
index 2966fe2ae910..e86ce7682ad8 100644
--- a/src/operator/fusion/fused_op-inl.h
+++ b/src/operator/fusion/fused_op-inl.h
@@ -24,7 +24,7 @@
 #include
 #include
 
-#if MXNET_USE_CUDA
+#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
 
 namespace mxnet {
 
@@ -992,6 +992,6 @@ const char kernel_end[] = R"code(}
 
 }  // namespace mxnet
 
-#endif  // MXNET_USE_CUDA
+#endif  // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
 
 #endif  // MXNET_OPERATOR_FUSION_FUSED_OP_INL_H_
diff --git a/src/operator/fusion/fused_op.cc b/src/operator/fusion/fused_op.cc
index 5c83c30308c7..5e2d782dd9e0 100644
--- a/src/operator/fusion/fused_op.cc
+++ b/src/operator/fusion/fused_op.cc
@@ -23,7 +23,7 @@
 #include "../operator_common.h"
 #include "../../executor/exec_pass.h"
 
-#if MXNET_USE_CUDA
+#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
 
 namespace mxnet {
 
@@ -302,4 +302,4 @@ NNVM_REGISTER_OP(_FusedOpOutHelper)
 
 }  // namespace mxnet
 
-#endif  // MXNET_USE_CUDA
+#endif  // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
diff --git a/src/operator/fusion/fused_op.cu b/src/operator/fusion/fused_op.cu
index 78988f13510e..62f340d0e00b 100644
--- a/src/operator/fusion/fused_op.cu
+++ b/src/operator/fusion/fused_op.cu
@@ -17,6 +17,9 @@
  * under the License.
  */
 
+// Additional use of MXNET_USE_CUDA is not needed to guard a '.cu' file.
+#if MXNET_ENABLE_CUDA_RTC
+
 #include
 #include
 #include
@@ -787,3 +790,5 @@ NNVM_REGISTER_OP(_FusedOp)
 .set_attr<FCompute>("FCompute<gpu>", FusedOpForwardGPU);
 
 }  // namespace mxnet
+
+#endif  // MXNET_ENABLE_CUDA_RTC
diff --git a/src/operator/fusion/fused_op.h b/src/operator/fusion/fused_op.h
index 24603ac1932f..7d714677e941 100644
--- a/src/operator/fusion/fused_op.h
+++ b/src/operator/fusion/fused_op.h
@@ -20,7 +20,6 @@
 #ifndef MXNET_OPERATOR_FUSION_FUSED_OP_H_
 #define MXNET_OPERATOR_FUSION_FUSED_OP_H_
 
-
 #include
 #include
 #include
@@ -29,8 +28,7 @@
 #include
 #include
 
-#if MXNET_USE_CUDA
-
+#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
 
 namespace mxnet {
 
@@ -202,5 +200,6 @@ using FusedOpHelperParamPtr = std::shared_ptr<FusedOpHelperParam>;
 
 }  // namespace mxnet
 
-#endif  // MXNET_USE_CUDA
+#endif  // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
+
 #endif  // MXNET_OPERATOR_FUSION_FUSED_OP_H_
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 801e4e7126b4..0fee2a26c0ed 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -420,11 +420,9 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape& shp = (*in_attrs)[0];
   mxnet::TShape& out_shp = (*out_attrs)[0];
   CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
-  CHECK_NE(shp.ndim(), 0) << "Number of dimensions cannot be 0";
-  CHECK_NE(out_shp.ndim(), 0) << "Number of dimensions cannot be 0";
   if (shp.ndim() == -1 && out_shp.ndim() == -1)
     return false;  // none of the shapes is known
-  if (out_shp.ndim() > 0 && shp.ndim() > 0)
+  if (out_shp.ndim() >= 0 && shp.ndim() >= 0)
     CHECK_EQ(out_shp.ndim(), shp.ndim());
   mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1);
   mxnet::TShape ret(std::max(shp.ndim(), out_shp.ndim()), -1);
diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py
index 5606eb19a9c5..24e33019f617 100644
--- a/tests/python/gpu/test_fusion.py
+++ b/tests/python/gpu/test_fusion.py
@@ -172,7 +172,7 @@ def check_binary_ops():
     check_fused_symbol(3-a, a=arr1)
     check_fused_symbol(a*b, a=arr1, b=arr2)
     check_fused_symbol(a*3, a=arr1)
-    check_fused_symbol(a/b, a=arr1, b=arr2)
+    check_fused_symbol(a/(b+1), a=arr1, b=arr2)
     check_fused_symbol(a/3, a=arr1)
     check_fused_symbol(3/a, a=arr1)
     check_fused_symbol(a**b, a=arr1, b=arr2)
@@ -239,6 +239,31 @@ def test_fusion_compiler_cache():
         check_fused_symbol(a+b, ctx=mx.gpu(1), a=arr1, b=arr2)
 
+@with_seed()
+def test_fusion_reshape_executor():
+    a = mx.sym.Variable("data1")
+    b = mx.sym.Variable("data2")
+    c = a + b + 1
+    sym = mx.sym.relu(c)
+    orig_shape = (10,10)
+    e = sym.simple_bind(ctx=mx.gpu(), data1=orig_shape, data2=orig_shape)
+    data = mx.nd.zeros(orig_shape, ctx=mx.gpu())
+    out = e.forward(is_train=False)
+    assert out[0].sum().asscalar() == 100
+    changed_shape = (80, 2)
+    new_shape = {'data1': changed_shape, 'data2': changed_shape}
+    data = mx.nd.zeros(new_shape['data1'], ctx=mx.gpu())
+    f = e.reshape(allow_up_sizing=True, **new_shape)
+    out = f.forward(is_train=False, data1=data, data2=data)
+    assert out[0].sum().asscalar() == 160
+    # Reshape again
+    changed_shape = (30, 5)
+    new_shape = {'data1': changed_shape, 'data2': changed_shape}
+    data = mx.nd.zeros(new_shape['data1'], ctx=mx.gpu())
+    f = e.reshape(allow_up_sizing=True, **new_shape)
+    out = f.forward(is_train=False, data1=data, data2=data)
+    assert out[0].sum().asscalar() == 150
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index a2aad2c079fc..8e4fe11905cf 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -413,7 +413,7 @@ def test_gen_atomic_symbol_multiple_outputs():
     p = mx.sym.Variable('param')
     h0 = mx.sym.Variable('h0')
     h1 = mx.sym.Variable('h1')
-    s = mx.sym.RNN(data, p, h0, h1, state_size=10, num_layers=2, 
+    s = mx.sym.RNN(data, p, h0, h1, state_size=10, num_layers=2,
                    bidirectional=True, state_outputs=True, mode='lstm')
     atomic_sym = s._gen_atomic_symbol()
 
@@ -542,6 +542,21 @@ def get_net():
     assert out_shapes[0] == (batch_size, num_hdidden)  # output
     assert len(aux_shapes) == 0
 
+def test_infershape_happens_for_all_ops_in_graph():
+    v = mx.sym.Variable('V')
+    s = mx.sym.transpose(v)
+    x = mx.sym.Variable('x')
+    s2 = x + v
+    s3 = s + s2
+    with discard_stderr():
+        try:
+            # This should throw an exception as you cannot add arrays
+            # with shapes [2,3] and [3,2]
+            e = s3.simple_bind(ctx=mx.cpu(), x=(2,3), grad_req='null')
+        except:
+            return
+
+    assert False
 
 if __name__ == '__main__':
     import nose