diff --git a/conftest.py b/conftest.py index 196d98716239..dd63a629d2c5 100644 --- a/conftest.py +++ b/conftest.py @@ -16,6 +16,7 @@ # by pytest before any tests are run import doctest +import os import sys import warnings from os.path import abspath, dirname, join @@ -27,6 +28,7 @@ HfDoctestModule, HfDocTestParser, is_torch_available, + patch_testing_methods_to_collect_info, patch_torch_compile_force_graph, ) @@ -145,3 +147,8 @@ def check_output(self, want, got, optionflags): # patch `torch.compile`: if `TORCH_COMPILE_FORCE_FULLGRAPH=1` (or values considered as true, e.g. yes, y, etc.), # the patched version will always run with `fullgraph=True`. patch_torch_compile_force_graph() + + + +if os.environ.get("PATCH_TESTING_METHODS_TO_COLLECT_OUTPUTS", "").lower() in ("yes", "true", "on", "y", "1"): + patch_testing_methods_to_collect_info() diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index afc0c3e6d794..15b32c5fe45c 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ast import collections import contextlib import copy @@ -31,6 +32,7 @@ import tempfile import threading import time +import traceback import types import unittest from collections import UserDict, defaultdict @@ -3433,6 +3435,384 @@ def patched(*args, **kwargs): torch.compile = patched +def _get_test_info(): + """ + Collect some information about the current test. + + For example, test full name, line number, stack, traceback, etc. + """ + + full_test_name = os.environ.get("PYTEST_CURRENT_TEST", "").split(" ")[0] + test_file, test_class, test_name = full_test_name.split("::") + + # from the most recent frame to the top frame + stack_from_inspect = inspect.stack() + # but visit from the top frame to the most recent frame + + test_frame, test_obj, test_method = None, None, None + for frame in reversed(stack_from_inspect): + if test_file in str(frame).replace(r"\\", "/"): + if test_name == frame.frame.f_locals["self"]._testMethodName: + test_frame = frame + # The test instance + test_obj = frame.frame.f_locals["self"] + test_method = getattr(test_obj, test_name) + break + + if test_frame is not None: + line_number = test_frame.lineno + + # most inner (recent) to most outer () frames + captured_frames = [] + to_capture = False + # up to the test method being called + for frame in reversed(stack_from_inspect): + if test_file in str(frame).replace(r"\\", "/"): + if "self" in frame.frame.f_locals and test_name == frame.frame.f_locals["self"]._testMethodName: + to_capture = True + elif "patched" in frame.frame.f_code.co_name: + to_capture = False + break + if to_capture: + captured_frames.append(frame) + + tb_next = None + for frame_info in reversed(captured_frames): + tb = types.TracebackType(tb_next, frame_info.frame, frame_info.frame.f_lasti, frame_info.frame.f_lineno) + tb_next = tb + test_traceback = tb + + stack = traceback.extract_stack() + + # The frame that calls this patched method (it may not be the test method) + # -1: `_get_test_info`; -2: `patched_xxx`; -3: the caller to `patched_xxx` + caller_frame = stack[-3] + caller_path = os.path.relpath(caller_frame.filename) + caller_lineno = caller_frame.lineno + + test_lineno = line_number + + # Get the code context in the test function/method. + from _pytest._code.source import Source + + with open(test_file) as fp: + s = fp.read() + source = Source(s) + test_code_context = "\n".join(source.getstatement(test_lineno - 1).lines) + + # Get the code context in the caller (to the patched function/method). + with open(caller_path) as fp: + s = fp.read() + source = Source(s) + caller_code_context = "\n".join(source.getstatement(caller_lineno - 1).lines) + + test_info = ( + f"test:\n\n{full_test_name}\n\n{'-' * 80}\n\ntest context: {test_file}:{test_lineno}\n\n{test_code_context}" + ) + test_info = f"{test_info}\n\n{'-' * 80}\n\ncaller context: {caller_path}:{caller_lineno}\n\n{caller_code_context}" + + return ( + full_test_name, + test_file, + test_lineno, + test_obj, + test_method, + test_frame, + test_traceback, + test_code_context, + caller_path, + caller_lineno, + caller_code_context, + test_info, + ) + + +def _get_call_arguments(code_context): + """ + Analyze the positional and keyword arguments in a call expression. + + This will extract the expressions of the positional and kwyword arguments, and associate them to the positions and + the keyword arugment names. + """ + + def get_argument_name(node): + """Extract the name/expression from an AST node""" + if isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Attribute): + return ast.unparse(node) + elif isinstance(node, ast.Constant): + return repr(node.value) + else: + return ast.unparse(node) + + indent = len(code_context) - len(code_context.lstrip()) + code_context = code_context.replace(" " * indent, "") + + try: + # Parse the line + tree = ast.parse(code_context, mode="eval") + + assert isinstance(tree.body, ast.Call) + call_node = tree.body + + if call_node: + result = { + "positional_args": [], + "keyword_args": {}, + "starargs": None, # *args + "kwargs": None, # **kwargs + } + + # Extract positional arguments + for arg in call_node.args: + arg_name = get_argument_name(arg) + result["positional_args"].append(arg_name) + + # Extract keyword arguments + for keyword in call_node.keywords: + if keyword.arg is None: + # This is **kwargs + result["kwargs"] = get_argument_name(keyword.value) + else: + # Regular keyword argument + arg_name = get_argument_name(keyword.value) + result["keyword_args"][keyword.arg] = arg_name + + return result + + except (SyntaxError, AttributeError) as e: + print(f"Error parsing: {e}") + + return None + + +def _prepare_debugging_info(test_info, info): + """Combine the information about the test and the call information to a patched function/method within it.""" + + info = f"{test_info}\n\n{info}" + p = os.path.join(os.environ.get("_PATCHED_TESTING_METHODS_OUTPUT_DIR", ""), "captured_info.txt") + # TODO (ydshieh): This is not safe when we use pytest-xdist with more than 1 worker. + with open(p, "a") as fp: + fp.write(f"{info}\n\n{'=' * 120}\n\n") + + return info + + +def _patched_tearDown(self, *args, **kwargs): + """Used to report a test that has failures captured and handled by patched functions/methods (without re-raise). + + The patched functions/methods refer to the `patched` defined in `_patch_with_call_info`, which is applied to + `torch.testing.assert_close` and `unittest.case.TestCase.assertEqual`. + + The objective is to avoid a failure being silence after being processed. + + If there is any failure that is not handled by the patched functions/methods, we add custom error message for them + along with the usual pytest failure report. + """ + + # Check for regular failures before clearing: + # when `_patched_tearDown` is called, the current test fails due to an assertion error given by a method being + # patched by `_patch_with_call_info`. The patched method catches such an error and continue running the remaining + # statements within the test. If the test fails with another error not handled by the patched methods, we don't let + # pytest to fail and report it but the original failure (the first one that was processed) instead. + # We still record those failures not handled by the patched methods, and add custom messages along with the usual + # pytest failure report. + regular_failures_info = [] + if hasattr(self, "_outcome") and self._outcome.errors: + for error_entry in self._outcome.errors: + test_instance, (exc_type, exc_obj, exc_tb) = error_entry + # breakpoint() + regular_failures_info.append( + { + "message": f"{str(exc_obj)}\n\n", + "type": exc_type.__name__, + "file": "test_modeling_vit.py", + "line": 237, # get_deepest_frame_line(exc_tb) # Your helper function + } + ) + + # Clear the regular failure (i.e. that is not from any of our patched assertion methods) from pytest's records. + self._outcome.errors.clear() + + # reset back to the original tearDown method, so `_patched_tearDown` won't be run by the subsequent tests if they + # have only test failures that are not handle by the patched methods (or no test failure at all). + orig_tearDown = _patched_tearDown.orig_tearDown + type(self).tearDown = orig_tearDown + + # Call the original tearDown + orig_tearDown(self, *args, **kwargs) + + # Get the failure + test_method = getattr(self, self._testMethodName) + captured_failures = test_method.__func__.captured_failures[id(test_method)] + + # TODO: How could we show several exceptions in a sinigle test on the terminal? (Maybe not a good idea) + captured_exceptions = captured_failures[0]["exception"] + captured_traceback = captured_failures[0]["traceback"] + # Show the cpatured information on the terminal. + capturued_info = [x["info"] for x in captured_failures] + capturued_info_str = f"\n\n{'=' * 80}\n\n".join(capturued_info) + + # Enhance the exception message if there were suppressed failures + if regular_failures_info: + enhanced_message = f"""{str(captured_exceptions)} + +{"=" * 80} +Handled Failures: ({len(capturued_info)} handled): +{"-" * 80}\n +{capturued_info_str} + +{"=" * 80} +Unhandled Failures: ({len(regular_failures_info)} unhandled): +{"-" * 80}\n +{", ".join(f"{info['type']}: {info['message']}{info['file']}:{info['line']}" for info in regular_failures_info)} + +{"-" * 80} +Note: This failure occurred after other failures analyzed by the patched assertion methods. +To see the full details, temporarily disable assertion patching. +{"=" * 80}""" + + # Create new exception with enhanced message + enhanced_exception = type(captured_exceptions)(enhanced_message) + enhanced_exception.__cause__ = captured_exceptions.__cause__ + enhanced_exception.__context__ = captured_exceptions.__context__ + + # Raise with your existing traceback reconstruction + captured_exceptions = enhanced_exception + + # clean up the recorded status + del test_method.__func__.captured_failures + + raise captured_exceptions.with_traceback(captured_traceback) + + +def _patch_with_call_info(module_or_class, attr_name, _parse_call_info_func, target_args): + """ + Patch a callerable `attr_name` of a module or class `module_or_class`. + + This will allow us to collect the call information, e.g. the argument names and values, also the literal expressions + passed as the arguments. + """ + orig_method = getattr(module_or_class, attr_name) + if not callable(orig_method): + return + + def patched(*args, **kwargs): + # If the target callable is not called within a test, simply call it without modification. + if not os.environ.get("PYTEST_CURRENT_TEST", ""): + return orig_method(*args, **kwargs) + + try: + orig_method(*args, **kwargs) + except AssertionError as e: + captured_exception = e + # captured_traceback = e.__traceback__ + ( + full_test_name, + test_file, + test_lineno, + test_obj, + test_method, + test_frame, + test_traceback, + test_code_context, + caller_path, + caller_lineno, + caller_code_context, + test_info, + ) = _get_test_info() + test_info = f"{test_info}\n\n{'-' * 80}\n\npatched method: {orig_method.__module__}.{orig_method.__name__}" + call_argument_expressions = _get_call_arguments(caller_code_context) + + # This is specific + info = _parse_call_info_func(orig_method, args, kwargs, call_argument_expressions, target_args) + info = _prepare_debugging_info(test_info, info) + + # Save this, so we can raise at the end of the current test + captured_failure = { + "result": "failed", + "exception": captured_exception, + "traceback": test_traceback, + "info": info, + } + + # Record the failure status and its information, so we can raise it later. + # We are modifying the (unbound) function at class level: not its logic but only adding a new extra + # attribute. + if getattr(test_method.__func__, "captured_failures", None) is None: + test_method.__func__.captured_failures = {} + if id(test_method) not in test_method.__func__.captured_failures: + test_method.__func__.captured_failures[id(test_method)] = [] + test_method.__func__.captured_failures[id(test_method)].append(captured_failure) + + # This modifies the `tearDown` which will be called after every tests, but we reset it back inside + # `_patched_tearDown`. + if not hasattr(type(test_obj).tearDown, "orig_tearDown"): + orig_tearDown = type(test_obj).tearDown + _patched_tearDown.orig_tearDown = orig_tearDown + type(test_obj).tearDown = _patched_tearDown + + setattr(module_or_class, attr_name, patched) + + +def _parse_call_info(func, args, kwargs, call_argument_expressions, target_args): + """ + Prepare a string containing the call info to `func`, e.g. argument names/values/expressions. + """ + signature = inspect.signature(func) + signature_names = [param.name for param_name, param in signature.parameters.items()] + + # called as `self.method_name()` or `xxx.method_name()`. + if len(args) == len(call_argument_expressions["positional_args"]) + 1: + # We simply add "self" as the expression despite it might not be the actual argument name. + # (This part is very unlikely what a user would be interest to know) + call_argument_expressions["positional_args"] = ["self"] + call_argument_expressions["positional_args"] + + param_position_mapping = {param_name: idx for idx, param_name in enumerate(signature_names)} + + arg_info = {} + for arg_name in target_args: + if arg_name in kwargs: + arg_value = kwargs[arg_name] + arg_expr = call_argument_expressions["keyword_args"][arg_name] + else: + arg_pos = param_position_mapping[arg_name] + arg_value = args[arg_pos] + arg_expr = call_argument_expressions["positional_args"][arg_pos] + + arg_value_str = _format_py_obj(arg_value) + arg_info[arg_name] = {"arg_expr": arg_expr, "arg_value_str": arg_value_str} + + info = "" + for arg_name in arg_info: + arg_expr, arg_value_str = arg_info[arg_name]["arg_expr"], arg_info[arg_name]["arg_value_str"] + info += f"{'-' * 80}\n\nargument name: `{arg_name}`\nargument expression: `{arg_expr}`\n\nargument value:\n\n{arg_value_str}\n\n" + + # remove the trailing \n\n + info = info[:-2] + + return info + + +def patch_testing_methods_to_collect_info(): + """ + Patch some methods (`torch.testing.assert_close`, `unittest.case.TestCase.assertEqual`, etc). + + This will allow us to collect the call information, e.g. the argument names and values, also the literal expressions + passed as the arguments. + """ + p = os.path.join(os.environ.get("_PATCHED_TESTING_METHODS_OUTPUT_DIR", ""), "captured_info.txt") + Path(p).unlink(missing_ok=True) + + if is_torch_available(): + import torch + + _patch_with_call_info(torch.testing, "assert_close", _parse_call_info, target_args=("actual", "expected")) + + _patch_with_call_info(unittest.case.TestCase, "assertEqual", _parse_call_info, target_args=("first", "second")) + + def torchrun(script: str, nproc_per_node: int, is_torchrun: bool = True, env: Optional[dict] = None): """Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary.""" with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp: @@ -3451,3 +3831,313 @@ def torchrun(script: str, nproc_per_node: int, is_torchrun: bool = True, env: Op _ = subprocess.run(cmd, capture_output=True, env=env, text=True, check=True) except subprocess.CalledProcessError as e: raise Exception(f"The following error was captured: {e.stderr}") + + +def _format_tensor(t, indent_level=0, sci_mode=None): + """Format torch's tensor in a pretty way to be shown 👀 in the test report.""" + + # `torch.testing.assert_close` could accept python int/float numbers. + if not isinstance(t, torch.Tensor): + t = torch.tensor(t) + + # Simply make the processing below simpler (not to hande both case) + is_scalar = False + if t.ndim == 0: + t = torch.tensor([t]) + is_scalar = True + + # For scalar or one-dimensional tensor, keep it as one-line. If there is only one element along any dimension except + # the last one, we also keep it as one-line. + if t.ndim <= 1 or set(t.shape[0:-1]) == {1}: + # Use `detach` to remove `grad_fn=<...>`, and use `to("cpu")` to remove `device='...'` + t = t.detach().to("cpu") + + # We work directly with the string representation instead the tensor itself + t_str = str(t) + + # remove `tensor( ... )` so keep only the content + t_str = t_str.replace("tensor(", "").replace(")", "") + + # Sometimes there are extra spaces between `[` and the first digit of the first value (for alignment). + # For example `[[ 0.06, -0.51], [-0.76, -0.49]]`. It may have multiple consecutive spaces. + # Let's remove such extra spaces. + while "[ " in t_str: + t_str = t_str.replace("[ ", "[") + + # Put everything in a single line. We replace `\n` by a space ` ` so we still keep `,\n` as `, `. + t_str = t_str.replace("\n", " ") + + # Remove repeated spaces (introduced by the previous step) + while " " in t_str: + t_str = t_str.replace(" ", " ") + + # remove leading `[` and `]` for scalar tensor + if is_scalar: + t_str = t_str[1:-1] + + t_str = " " * 4 * indent_level + t_str + + return t_str + + # Otherwise, we separte the representations of every elements along an outer dimension by new lines (after a `,`). + # The representatioin each element is obtained by calling this function recursively with corrent `indent_level`. + else: + t_str = str(t) + + # (For the recursive calls should receive this value) + if sci_mode is None: + sci_mode = "e+" in t_str or "e-" in t_str + + # Use the original content to determine the scientific mode to use. This is required as the representation of + # t[index] (computed below) maybe have different format regarding scientific notation. + torch.set_printoptions(sci_mode=sci_mode) + + t_str = " " * 4 * indent_level + "[\n" + # Keep the ending `,` for all outer dimensions whose representations are not put in one-line, even if there is + # only one element along that dimension. + t_str += ",\n".join(_format_tensor(x, indent_level=indent_level + 1, sci_mode=sci_mode) for x in t) + t_str += ",\n" + " " * 4 * indent_level + "]" + + torch.set_printoptions(sci_mode=None) + + return t_str + + +def _quote_string(s): + """Given a string `s`, return a python literal expression that give `s` when it is used in a python source code. + + For example, if `s` is the string `abc`, the return value is `"abc"`. + + We choice double quotes over single quote despite `str(s)` would give `'abc'` instead of `"abc"`. + """ + has_single_quote = "'" in s + has_double_quote = '"' in s + + if has_single_quote and has_double_quote: + # replace any double quote by the raw string r'\"'. + s = s.replace('"', r"\"") + return f'"{s}"' + elif has_single_quote: + return f'"{s}"' + elif has_double_quote: + return f"'{s}'" + else: + return f'"{s}"' + + +def _format_py_obj(obj, indent=0, mode="", cache=None, prefix=""): + """Format python objects of basic built-in type in a pretty way so we could copy-past them to code editor easily. + + Currently, this support int, float, str, list, tuple, and dict. + + It also works with `torch.Tensor` via calling `format_tesnor`. + """ + + if cache is None: + cache = {} + else: + if (id(obj), indent, mode, prefix) in cache: + return cache[(id(obj), indent, mode, prefix)] + + # special format method for `torch.Tensor` + if str(obj.__class__) == "": + return _format_tensor(obj) + + elif obj.__class__.__name__ == "str": + quoted_string = _quote_string(obj) + # we don't want the newline being interpreted + quoted_string = quoted_string.replace("\n", r"\n") + output = quoted_string + + elif obj.__class__.__name__ in ["int", "float"]: + # for float like `1/3`, we will get `0.3333333333333333` + output = str(obj) + + elif obj.__class__.__name__ in ["list", "tuple", "dict"]: + parenthesis = { + "list": "[]", + "tuple": "()", + "dict": "{}", + } + p1, p2 = parenthesis[obj.__class__.__name__] + + elements_without_indent = [] + if isinstance(obj, dict): + for idx, (k, v) in enumerate(obj.items()): + last_element = idx == len(obj) - 1 + ok = _format_py_obj(k, indent=indent + 1, mode="one-line", cache=cache) + ov = _format_py_obj( + v, + indent=indent + 1, + mode=mode, + cache=cache, + prefix=ok.lstrip() + ": " + "," if not last_element else "", + ) + # Each element could be multiple-line, but the indent of its first line is removed + elements_without_indent.append(f"{ok.lstrip()}: {ov.lstrip()}") + + else: + for idx, x in enumerate(obj): + last_element = idx == len(obj) - 1 + o = _format_py_obj( + x, indent=indent + 1, mode=mode, cache=cache, prefix="," if not last_element else "" + ) + # Each element could be multiple-line, but the indent of its first line is removed + elements_without_indent.append(o.lstrip()) + + groups = [] + buf = [] + for idx, x in enumerate(elements_without_indent): + buf.append(x) + + x_expanded = "\n" in buf[-1] + not_last_element = idx != len(elements_without_indent) - 1 + # if `x` should be separated from subsequent elements + should_finalize_x = x_expanded or len(f"{' ' * (4 * (indent + 1))}") + len( + ", ".join(buf[-1:]) + ) > 120 - int(not_last_element) + + # if `buf[:-1]` (i.e. without `x`) should be combined together (into one line) + should_finalize_buf = x_expanded + + # the recursive call returns single line, so we can use it to determine if we can fit the width limit + if not should_finalize_buf: + buf_not_fit_into_one_line = len(f"{' ' * (4 * (indent + 1))}") + len(", ".join(buf)) > 120 - int( + not_last_element + ) + should_finalize_buf = buf_not_fit_into_one_line + + # any element of iterable type need to be on its own line + if (type(obj[idx]) if type(obj) is not dict else type(list(obj.values())[idx])) in [list, tuple, dict]: + should_finalize_x = True + should_finalize_buf = True + + # any type change --> need to be added after a new line + prev_type = None + current_type = type(obj[idx]) if type(obj) is not dict else type(list(obj.values())[idx]) + if len(buf) > 1: + prev_type = type(obj[idx - 1]) if type(obj) is not dict else type(list(obj.values())[idx - 1]) + type_changed = current_type != prev_type + if type_changed: + should_finalize_buf = True + + # all elements in the buf are string --> don't finalize the buf by width limit + if prev_type is None or (prev_type is str and current_type is str): + should_finalize_buf = False + + # collect as many elements of string type as possible (without width limit). + # These will be examined as a whole (if not fit into the width, each element would be in its own line) + if current_type is str: + should_finalize_x = False + # `len(buf) == 1` or `obj[idx-1]` is a string + if prev_type in [None, str]: + should_finalize_buf = False + + if should_finalize_buf: + orig_buf_len = len(buf) + + if orig_buf_len > 1: + not_fit_into_one_line = None + + # all elements in `obj` that give `buf[:-1]` are string. + if prev_type is str: + # `-1` at the end: because buf[-2] is not the last element + not_fit_into_one_line = len(f"{' ' * (4 * (indent + 1))}") + len(", ".join(buf[:-1])) > 120 - 1 + + if not_fit_into_one_line: + for x in buf[:-1]: + groups.append([x]) + else: + groups.append(buf[:-1]) + + buf = buf[-1:] + + if should_finalize_x: + groups.append(buf) + buf = [] + + # The last buf + if len(buf) > 0: + not_fit_into_one_line = None + if current_type is str: + # no `-1` at the end: because buf[-1] is the last element + not_fit_into_one_line = len(f"{' ' * (4 * (indent + 1))}") + len(", ".join(buf)) > 120 + + if not_fit_into_one_line: + for x in buf: + groups.append([x]) + else: + groups.append(buf) + + output = f"{' ' * 4 * indent}{p1}\n" + element_strings = [f"{' ' * (4 * (indent + 1))}" + ", ".join(buf) for buf in groups] + output += ",\n".join(element_strings) + output += f"\n{' ' * 4 * indent}{p2}" + + # if all elements are in one-line + no_new_line_in_elements = all("\n" not in x for x in element_strings) + # if yes, we can form a one-line representation of `obj` + could_use_one_line = no_new_line_in_elements + + # if mode == "one-line", this function always returns one-line representation, so `no_new_line_in_elements` + # will be `True`. + if could_use_one_line: + one_line_form = ", ".join([x.lstrip() for x in element_strings]) + one_line_form = f"{p1}{one_line_form}{p2}" + + if mode == "one-line": + return output + + # check with the width limit + could_use_one_line = len(f"{' ' * 4 * indent}") + len(prefix) + len(one_line_form) <= 120 + + # extra conditions for returning one-line representation + def use_one_line_repr(obj): + # interable types + if type(obj) in (list, tuple, dict): + # get all types + element_types = [] + if type(obj) is dict: + element_types.extend(type(x) for x in obj.values()) + elif type(obj) in [list, tuple]: + element_types.extend(type(x) for x in obj) + + # At least one element is of iterable type + if any(x in (list, tuple, dict) for x in element_types): + # If `obj` has more than one element and at least one of them is iterable --> no one line repr. + if len(obj) > 1: + return False + + # only one element that is iterable, but not the same type as `obj` --> no one line repr. + if type(obj) is not type(obj[0]): + return False + + # one-line repr. if possible, without width limit + return no_new_line_in_elements + + # all elements are of simple types, but more than one type --> no one line repr. + if len(set(element_types)) > 1: + return False + + # all elements are of the same simple type + if element_types[0] in [int, float]: + # one-line repr. without width limit + return no_new_line_in_elements + elif element_types[0] in [str]: + if len(obj) == 1: + # one single string element --> one-line repr. without width limit + return no_new_line_in_elements + else: + # multiple string elements --> one-line repr. if fit into width limit + return could_use_one_line + + # simple types (int, flat, string) + return True + + # width condition combined with specific mode conditions + if use_one_line_repr(obj): + output = f"{' ' * 4 * indent}{one_line_form}" + + cache[(id(obj), indent, mode, prefix)] = output + + return output