diff --git a/src/super_gradients/common/crash_handler/crash_tips.py b/src/super_gradients/common/crash_handler/crash_tips.py index c55359f216..b4b2c59333 100644 --- a/src/super_gradients/common/crash_handler/crash_tips.py +++ b/src/super_gradients/common/crash_handler/crash_tips.py @@ -2,6 +2,8 @@ from typing import Union, Tuple, List, Type from types import TracebackType +import omegaconf + from super_gradients.common.crash_handler.utils import indent_string, fmt_txt, json_str_to_dict from super_gradients.common.abstractions.abstract_logger import get_logger @@ -197,6 +199,28 @@ def _get_tips(cls, exc_type: type, exc_value: Exception, exc_traceback: Tracebac return [tip] +class InterpolationKeyErrorTip(CrashTip): + @classmethod + def is_relevant(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType): + expected_str = "Interpolation key " + return isinstance(exc_value, omegaconf.errors.InterpolationKeyError) and expected_str in str(exc_value) + + @classmethod + def _get_tips(cls, exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> List[str]: + variable = re.search("'(.*?)'", str(exc_value)).group(1) + tip = ( + f"It looks like you encountered an error related to interpolation of the variable '{variable}'.\n" + "It's possible that this error is caused by not using the full path of the variable in your subfolder configuration.\n" + f"Please make sure that you are referring to the variable using the " + f"{fmt_txt('full path starting from the main configuration file', color='green')}.\n" + f"Try to replace '{fmt_txt(f'${{{variable}}}', color='red')}' with '{fmt_txt(f'${{full.path.to.{variable}}}', color='green')}', \n" + f" where 'full.path.to' is the actual path to reach '{variable}', starting from the root configuration file.\n" + f"Example: '{fmt_txt('${dataset_params.train_dataloader_params.batch_size}', color='green')}' " + f"instead of '{fmt_txt('${train_dataloader_params.batch_size}', color='red')}'.\n" + ) + return [tip] + + def get_relevant_crash_tip_message(exc_type: type, exc_value: Exception, exc_traceback: TracebackType) -> Union[None, str]: """Get a CrashTip class if relevant for input exception""" for crash_tip in CrashTip.get_sub_classes(): diff --git a/tests/unit_tests/crash_tips_test.py b/tests/unit_tests/crash_tips_test.py index d5541e0324..e2bc8b0683 100644 --- a/tests/unit_tests/crash_tips_test.py +++ b/tests/unit_tests/crash_tips_test.py @@ -2,6 +2,7 @@ import unittest import dataclasses from typing import Type +import omegaconf from super_gradients.common.crash_handler.crash_tips import ( get_relevant_crash_tip_message, CrashTip, @@ -9,6 +10,7 @@ RecipeFactoryFormatTip, DDPNotInitializedTip, WrongHydraVersionTip, + InterpolationKeyErrorTip, ) @@ -46,6 +48,10 @@ def setUp(self) -> None: exc_value=TypeError("__init__() got an unexpected keyword argument 'version_base'"), expected_crash_tip=WrongHydraVersionTip, ), + DocumentedException( + exc_value=omegaconf.errors.InterpolationKeyError("omegaconf.errors.InterpolationKeyError: Interpolation key 'x' not found"), + expected_crash_tip=InterpolationKeyErrorTip, + ), ] def test_found_exceptions(self):