@@ -39,7 +39,7 @@ void NotateModuleForFallback(
3939 if (n->kind () == torch::jit::prim::GetAttr) {
4040 auto out_type = unmangle_cls_name (c10::toString (n->output (0 )->type ()));
4141 if (forced_fallback_modules.find (out_type) != forced_fallback_modules.end ()) {
42- LOG_DEBUG (
42+ LOG_GRAPH (
4343 " Notating module for fallback: " << n->s (c10::attr::name) << " (" << out_type << " ) [owner: " << mod_name
4444 << " (" << cls_name << " )]" );
4545 auto uses = n->output (0 )->uses ();
@@ -58,11 +58,32 @@ void NotateModuleForFallback(
5858 }
5959
6060 if (changed_mod) {
61- LOG_DEBUG (" Notated graph: " << *g);
61+ LOG_GRAPH (" Notated graph: " << *g);
6262 }
6363
64- for (const auto sub_mod : mod.named_children ()) {
65- NotateModuleForFallback (sub_mod.value , sub_mod.name , method_name, forced_fallback_modules);
64+ if (mod.named_children ().size () > 0 ) {
65+ for (const auto n : nodes) {
66+ std::string sub_method_name = " " ;
67+ if (n->kind () == torch::jit::prim::CallMethod) {
68+ sub_method_name = n->s (c10::Symbol::attr (" name" ));
69+ auto sub_mod_val = n->input (0 );
70+ auto sub_mod_src_n = sub_mod_val->node ();
71+ if (!sub_mod_src_n->hasAttributeS (" name" )) {
72+ LOG_GRAPH (" Node: " << util::node_info (sub_mod_src_n) << " manages a module with no name, skipping" );
73+ break ;
74+ }
75+ auto sub_mod_name = sub_mod_src_n->s (c10::Symbol::attr (" name" ));
76+ for (const auto sub_mod : mod.named_children ()) {
77+ // Theres probably a way to directly access the module we care about
78+ if (sub_mod.name == sub_mod_name) {
79+ LOG_GRAPH (
80+ " Looking at <module>.<method>() next: " << sub_mod_name << " ." << sub_method_name
81+ << " () (lowering.passes.NotateModuleForFallback)" );
82+ NotateModuleForFallback (sub_mod.value , sub_mod.name , sub_method_name, forced_fallback_modules);
83+ }
84+ }
85+ }
86+ }
6687 }
6788}
6889
@@ -74,23 +95,23 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_del
7495 auto n = *it;
7596 if (!mark.top () && n->kind () == torch::jit::prim::Enter && n->hasAttributeS (" compilation_edge" )) {
7697 if (n->s (c10::Symbol::attr (" compilation_edge" )) == " start" ) {
77- LOG_DEBUG (" Starting to mark new segmented block targeted for torch" );
98+ LOG_GRAPH (" Starting to mark new segmented block targeted for torch" );
7899 mark.push (true );
79100 if (delete_delims) {
80101 it.destroyCurrent ();
81102 }
82103 }
83104 } else if (mark.top () && n->kind () == torch::jit::prim::Enter && n->hasAttributeS (" compilation_edge" )) {
84105 if (n->s (c10::Symbol::attr (" compilation_edge" )) == " start" ) {
85- LOG_DEBUG (" Found the start of another segmented block targeted for torch while actively marking a block" );
106+ LOG_GRAPH (" Found the start of another segmented block targeted for torch while actively marking a block" );
86107 mark.push (true );
87108 if (delete_delims) {
88109 it.destroyCurrent ();
89110 }
90111 }
91112 } else if (mark.top () && n->kind () == torch::jit::prim::Exit && n->hasAttributeS (" compilation_edge" )) {
92113 if (n->s (c10::Symbol::attr (" compilation_edge" )) == " end" ) {
93- LOG_DEBUG (" Found the end of segmented block targeted for torch while actively marking a block" );
114+ LOG_GRAPH (" Found the end of segmented block targeted for torch while actively marking a block" );
94115 mark.pop ();
95116 if (delete_delims) {
96117 it.destroyCurrent ();
@@ -106,7 +127,7 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_del
106127 }
107128 }
108129
109- LOG_DEBUG (" After marking operations for torch fallback: " << *g);
130+ LOG_GRAPH (" After marking operations for torch fallback: " << *g);
110131}
111132
112133} // namespace passes
0 commit comments