@@ -130,17 +130,24 @@ RecursiveCompilabilityChecker::FindUncompilableNodes(
130130 return uncompilable_nodes;
131131}
132132
133- bool RecursiveCompilabilityChecker::HasXLAKernel (const Node& node) const {
133+ bool RecursiveCompilabilityChecker::HasXLAKernel (
134+ const Node& node, string* uncompilable_reason) const {
134135 // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
135136 // is really a kind of function call and will be handled by
136137 // IsCompilableCall().
137- if (node.type_string () == " SymbolicGradient" ) return false ;
138+ if (node.type_string () == " SymbolicGradient" ) {
139+ *uncompilable_reason =
140+ " SymbolicGradient should be handled by IsCompilableCall()." ;
141+ return false ;
142+ }
138143 if (node.type_string () == " Const" ) {
139144 // Skip Const op with type DT_STRING, since XLA doesn't support it, but the
140145 // registered Const KernelDef says that it does, to support no-op Assert for
141146 // tfcompile.
142147 const AttrValue* attr = node.attrs ().Find (" dtype" );
143148 if (attr != nullptr && attr->type () == DT_STRING) {
149+ *uncompilable_reason =
150+ " Const op with type DT_STRING is not supported by XLA." ;
144151 return false ;
145152 }
146153 }
@@ -150,10 +157,16 @@ bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) const {
150157 // such nodes out of XLA clusters.
151158 if (HasForwardedRefInput (node)) {
152159 VLOG (2 ) << " Rejecting " << node.name () << " : Identity with unsafe cast." ;
160+ *uncompilable_reason = " Identity with unsafe cast." ;
153161 return false ;
154162 }
155163
156- return FindKernelDef (jit_device_type_, node.def (), nullptr , nullptr ).ok ();
164+ Status s = FindKernelDef (jit_device_type_, node.def (), nullptr , nullptr );
165+ if (!s.ok ()) {
166+ *uncompilable_reason = s.error_message ();
167+ return false ;
168+ }
169+ return true ;
157170}
158171
159172// Tests whether 'if_node' is compilable. Every operator in the then_branch and
@@ -336,16 +349,17 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
336349 return false ;
337350 }
338351
352+ string uncompilable_reason;
339353 if (IsFunctionCall (*lib_runtime->GetFunctionLibraryDefinition (), node)) {
340354 if (!IsCompilableCall (node.def (), lib_runtime, stack_trace,
341355 encapsulating_function, uncompilable_nodes)) {
342356 LogNotCompilable (node, " unsupported function" );
343357 return false ;
344358 }
345- } else if (!HasXLAKernel (node)) {
346- absl::string_view uncompilable_reason = " unsupported op " ;
347- MaybeMarkUncompilableNode ( uncompilable_reason, *stack_trace,
348- encapsulating_function, uncompilable_nodes);
359+ } else if (!HasXLAKernel (node, &uncompilable_reason )) {
360+ MaybeMarkUncompilableNode (
361+ absl::StrCat ( " unsupported op: " , uncompilable_reason) , *stack_trace,
362+ encapsulating_function, uncompilable_nodes);
349363 LogNotCompilable (node, uncompilable_reason);
350364 return false ;
351365 }
0 commit comments