From 8b30ccd9d7da2412366d5e57db6fdb7fd8cab3bc Mon Sep 17 00:00:00 2001 From: 0x45f Date: Wed, 1 Sep 2021 09:52:24 +0000 Subject: [PATCH] modify dy2stat error message in runtime and format error message --- .../fluid/dygraph/dygraph_to_static/error.py | 66 ++++++++++++++++--- .../unittests/dygraph_to_static/test_error.py | 31 ++++++++- 2 files changed, 87 insertions(+), 10 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/error.py b/python/paddle/fluid/dygraph/dygraph_to_static/error.py index 66d3b58f4c2dd..ffcc8c95bbc81 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/error.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/error.py @@ -17,6 +17,7 @@ import sys import traceback import linecache +import re from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginInfo, global_origin_info_map @@ -106,22 +107,34 @@ def __init__(self, location, function_name): begin_lineno = max(1, self.location.lineno - int(SOURCE_CODE_RANGE / 2)) for i in range(begin_lineno, begin_lineno + SOURCE_CODE_RANGE): - line = linecache.getline(self.location.filepath, i) - line_lstrip = line.strip() + line = linecache.getline(self.location.filepath, i).rstrip('\n') + line_lstrip = line.lstrip() self.source_code.append(line_lstrip) - blank_count.append(len(line) - len(line_lstrip)) + if not line_lstrip: # empty line from source code + blank_count.append(-1) + else: + blank_count.append(len(line) - len(line_lstrip)) if i == self.location.lineno: hint_msg = '~' * len(self.source_code[-1]) + ' <--- HERE' self.source_code.append(hint_msg) blank_count.append(blank_count[-1]) linecache.clearcache() - - min_black_count = min(blank_count) + # remove top and bottom empty line in source code + while len(self.source_code) > 0 and not self.source_code[0]: + self.source_code.pop(0) + blank_count.pop(0) + while len(self.source_code) > 0 and not self.source_code[-1]: + self.source_code.pop(-1) + blank_count.pop(-1) + + min_black_count = min([i for i in blank_count if i >= 0]) for i in range(len(self.source_code)): - self.source_code[i] = ' ' * (blank_count[i] - min_black_count + - BLANK_COUNT_BEFORE_FILE_STR * 2 - ) + self.source_code[i] + # if source_code[i] is empty line between two code line, dont add blank + if self.source_code[i]: + self.source_code[i] = ' ' * (blank_count[i] - min_black_count + + BLANK_COUNT_BEFORE_FILE_STR * 2 + ) + self.source_code[i] def formated_message(self): msg = ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'File "{}", line {}, in {}\n'.format( @@ -212,6 +225,7 @@ def _simplify_error_value(self): 1. Need a more robust way because the code of start_trace may change. 2. Set the switch to determine whether to simplify error_value """ + assert self.in_runtime is True error_value_lines = str(self.error_value).split("\n") @@ -219,9 +233,43 @@ def _simplify_error_value(self): start_trace = "outputs = static_func(*inputs)" start_idx = error_value_lines_strip.index(start_trace) + error_value_lines = error_value_lines[start_idx + 1:] + error_value_lines_strip = error_value_lines_strip[start_idx + 1:] + + # use empty line to locate the bottom_error_message + empty_line_idx = error_value_lines_strip.index('') + bottom_error_message = error_value_lines[empty_line_idx + 1:] + + filepath = '' + error_from_user_code = [] + pattern = 'File "(?P.+)", line (?P.+), in (?P.+)' + for i in range(0, len(error_value_lines_strip), 2): + if error_value_lines_strip[i].startswith("File "): + re_result = re.search(pattern, error_value_lines_strip[i]) + tmp_filepath, lineno_str, function_name = re_result.groups() + code = error_value_lines_strip[i + 1] if i + 1 < len( + error_value_lines_strip) else '' + if i == 0: + filepath = tmp_filepath + if tmp_filepath == filepath: + error_from_user_code.append( + (tmp_filepath, int(lineno_str), function_name, code)) + + error_frame = [] + whether_source_range = True + for filepath, lineno, funcname, code in error_from_user_code[::-1]: + loc = Location(filepath, lineno) + if whether_source_range: + traceback_frame = TraceBackFrameRange(loc, funcname) + whether_source_range = False + else: + traceback_frame = TraceBackFrame(loc, funcname, code) + + error_frame.insert(0, traceback_frame.formated_message()) - error_value_str = '\n'.join(error_value_lines) + error_frame.extend(bottom_error_message) + error_value_str = '\n'.join(error_frame) self.error_value = self.error_type(error_value_str) def raise_new_exception(self): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py index aafb02870990d..6dd8c8e0766bf 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py @@ -98,6 +98,16 @@ def test_func(self): return +@paddle.jit.to_static +def func_error_in_runtime_with_empty_line(x): + x = fluid.dygraph.to_variable(x) + two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32") + + x = fluid.layers.reshape(x, shape=[1, two]) + + return x + + class TestFlags(unittest.TestCase): def setUp(self): self.reset_flags_to_default() @@ -293,7 +303,26 @@ def set_message(self): self.expected_message = \ [ 'File "{}", line 54, in func_error_in_runtime'.format(self.filepath), - 'x = fluid.layers.reshape(x, shape=[1, two])' + 'x = fluid.dygraph.to_variable(x)', + 'two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")', + 'x = fluid.layers.reshape(x, shape=[1, two])', + '<--- HERE', + 'return x' + ] + + +class TestErrorStaticLayerCallInRuntime2(TestErrorStaticLayerCallInRuntime): + def set_func(self): + self.func = func_error_in_runtime_with_empty_line + + def set_message(self): + self.expected_message = \ + [ + 'File "{}", line 106, in func_error_in_runtime_with_empty_line'.format(self.filepath), + 'two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")', + 'x = fluid.layers.reshape(x, shape=[1, two])', + '<--- HERE', + 'return x' ]