@@ -126,7 +126,10 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule:
126126# (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall
127127# on pass that failed accuracy check.
128128def validate_inference (
129- rtol = None , atol = None , device = torch .device (torch .cuda .current_device ())
129+ rtol = None ,
130+ atol = None ,
131+ device = torch .device (torch .cuda .current_device ()),
132+ suppress_accuracy_check_failure = True ,
130133):
131134 def _validate_inference (pass_ : PassFunc ) -> PassFunc :
132135 """
@@ -141,48 +144,51 @@ def pass_with_validation(
141144 * args ,
142145 ** kwargs ,
143146 ) -> fx .GraphModule :
144- input_tensors = extract_example_tensors_from_input (input , device )
145- res0 = module (* input_tensors )
146- processed_module = pass_ (module , input , * args , ** kwargs )
147- res1 = processed_module (* input_tensors )
148- tensor_res_0 = _collect_tensors (res0 )
149- tensor_res_1 = _collect_tensors (res1 )
150- relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE
151-
152- for kk , (x , y ) in enumerate (zip (tensor_res_0 , tensor_res_1 )):
153- kwargs2 = {"equal_nan" : True }
154- if rtol :
155- kwargs2 ["rtol" ] = rtol
156- if atol :
157- kwargs2 ["atol" ] = atol
158- kwargs2 [
159- "msg"
160- ] = (
161- lambda msg : f"Pass { pass_ } failed correctness check due at output { kk } :\n { msg } "
162- )
163- # If tensors are on different devices, make sure to compare
164- # their copies that are on the same device.
165- if x .get_device () != y .get_device ():
166- x = x .cpu ()
167- y = y .cpu ()
168- try :
169- torch .testing .assert_close (x , y , ** kwargs2 )
170- except Exception as e :
171- if relax_accuracy_check_failure :
172- _LOGGER .error (f"{ e } " )
173- kwargs2 ["rtol" ] *= FINAL_CHECK_RTOL_MULTIPLIER
174- kwargs2 ["atol" ] *= FINAL_CHECK_ATOL_MULTIPLIER
175- new_atol = kwargs2 ["atol" ]
176- new_rtol = kwargs2 ["rtol" ]
177- _LOGGER .info (
178- f"Do a sanity check to see whether things are completely wrong with { new_atol = } , { new_rtol = } "
179- )
147+ if suppress_accuracy_check_failure :
148+ return pass_ (module , input , * args , ** kwargs )
149+ else :
150+ input_tensors = extract_example_tensors_from_input (input , device )
151+ res0 = module (* input_tensors )
152+ processed_module = pass_ (module , input , * args , ** kwargs )
153+ res1 = processed_module (* input_tensors )
154+ tensor_res_0 = _collect_tensors (res0 )
155+ tensor_res_1 = _collect_tensors (res1 )
156+ relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE
157+
158+ for kk , (x , y ) in enumerate (zip (tensor_res_0 , tensor_res_1 )):
159+ kwargs2 = {"equal_nan" : True }
160+ if rtol :
161+ kwargs2 ["rtol" ] = rtol
162+ if atol :
163+ kwargs2 ["atol" ] = atol
164+ kwargs2 [
165+ "msg"
166+ ] = (
167+ lambda msg : f"Pass { pass_ } failed correctness check due at output { kk } :\n { msg } "
168+ )
169+ # If tensors are on different devices, make sure to compare
170+ # their copies that are on the same device.
171+ if x .get_device () != y .get_device ():
172+ x = x .cpu ()
173+ y = y .cpu ()
174+ try :
180175 torch .testing .assert_close (x , y , ** kwargs2 )
181- return processed_module
182- else :
183- raise e
184-
185- return processed_module
176+ except Exception as e :
177+ if relax_accuracy_check_failure :
178+ _LOGGER .error (f"{ e } " )
179+ kwargs2 ["rtol" ] *= FINAL_CHECK_RTOL_MULTIPLIER
180+ kwargs2 ["atol" ] *= FINAL_CHECK_ATOL_MULTIPLIER
181+ new_atol = kwargs2 ["atol" ]
182+ new_rtol = kwargs2 ["rtol" ]
183+ _LOGGER .info (
184+ f"Do a sanity check to see whether things are completely wrong with { new_atol = } , { new_rtol = } "
185+ )
186+ torch .testing .assert_close (x , y , ** kwargs2 )
187+ return processed_module
188+ else :
189+ raise e
190+
191+ return processed_module
186192
187193 return pass_with_validation
188194
0 commit comments