Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Typing][B-35] Add type annotations for python/paddle/amp/debugging.py #66127

Merged
merged 5 commits into from
Jul 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 69 additions & 36 deletions python/paddle/amp/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import contextlib
import random
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Callable,
Sequence,
TypeVar,
)

import numpy as np
from typing_extensions import ParamSpec

import paddle
from paddle import _C_ops
from paddle.base import core

from ..framework import LayerHelper, in_dynamic_or_pir_mode

if TYPE_CHECKING:
from typing import Generator

from paddle import Tensor

_InputT = ParamSpec("_InputT")
_RetT = TypeVar("_RetT")
__all__ = [
"DebugMode",
"TensorCheckerConfig",
Expand Down Expand Up @@ -60,7 +76,9 @@ class DebugMode(Enum):
# DUMP_ALL = 5


def check_layer_numerics(func):
def check_layer_numerics(
func: Callable[_InputT, _RetT]
) -> Callable[_InputT, _RetT]:
"""
This decorator is used to check the numerical values of the layer's input and output data.

Expand Down Expand Up @@ -110,7 +128,7 @@ def check_layer_numerics(func):
>>> # RuntimeError: (PreconditionNotMet) There are NAN or INF (num_nan=0, num_inf=4, num_zero=0) in [device=gpu:0, op=divide, tensor=, dtype=fp32].
"""

def wrapper(self, *args, **kwargs):
def wrapper(self, *args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
if args:
# Set temp data and temp.gradient = False
start_data = args[0]
Expand All @@ -133,7 +151,7 @@ def wrapper(self, *args, **kwargs):
return wrapper


def set_checked_op_list(checked_op_list):
def set_checked_op_list(checked_op_list: Sequence[str] | None) -> None:
# check checked_op_list
if checked_op_list is not None:
if isinstance(checked_op_list, (list, tuple)):
Expand All @@ -143,7 +161,7 @@ def set_checked_op_list(checked_op_list):
raise ValueError("checked_op_list must be list or tuple")


def set_skipped_op_list(skipped_op_list):
def set_skipped_op_list(skipped_op_list: Sequence[str] | None) -> None:
# check skipped_op_list
if skipped_op_list is not None:
if isinstance(skipped_op_list, (list, tuple)):
Expand All @@ -162,13 +180,13 @@ class TensorCheckerConfig:

debug_mode(DebugMode, optional): A parameter that determines the type of debugging to be used. Default is DebugMode.CHECK_NAN_INF_AND_ABORT.

output_dir(string, optional): The path to store collected data. If this parameter is set to None, the data will be printed to the terminal. Default is None.
output_dir(string|None, optional): The path to store collected data. If this parameter is set to None, the data will be printed to the terminal. Default is None.

checked_op_list(list|tuple, optional): Specifies a list of operators that need to be checked during program execution, for example, checked_op_list=['elementwise_add', 'conv2d'], indicating that the output results of elementwise_add and conv2d should be checked for nan/inf during program execution. Default is None.
checked_op_list(list|tuple|None, optional): Specifies a list of operators that need to be checked during program execution, for example, checked_op_list=['elementwise_add', 'conv2d'], indicating that the output results of elementwise_add and conv2d should be checked for nan/inf during program execution. Default is None.

skipped_op_list(list|tuple, optional): Specifies a list of operators that do not need to be checked during program execution, for example, skipped_op_list=['elementwise_add', 'conv2d'], indicating that the output results of elementwise_add and conv2d should not be checked for nan/inf during program execution. None is None.
skipped_op_list(list|tuple|None, optional): Specifies a list of operators that do not need to be checked during program execution, for example, skipped_op_list=['elementwise_add', 'conv2d'], indicating that the output results of elementwise_add and conv2d should not be checked for nan/inf during program execution. Default is None.

debug_step(list|tuple, optional): A list or tuple used primarily for nan/inf checking during model training. For example, debug_step=[1,5] indicates that nan/inf checking should only be performed on model training iterations 1 to 5. Default is None.
debug_step(list|tuple|None, optional): A list or tuple used primarily for nan/inf checking during model training. For example, debug_step=[1,5] indicates that nan/inf checking should only be performed on model training iterations 1 to 5. Default is None.

stack_height_limit(int, optional): An integer value specifying the maximum depth of the call stack. This feature supports printing the call stack at the error location. Currently, only enabling or disabling call stack printing is supported. If you want to print the corresponding C++ call stack when NaN is detected in GPU Kernel, set stack_height_limit to 1, otherwise set it to 0. Default is 1.

Expand Down Expand Up @@ -198,17 +216,29 @@ class TensorCheckerConfig:
"""

# For module debugging
current_step_id = 0
current_step_id: int = 0

enable: bool
debug_mode: DebugMode
output_dir: str | None
checked_op_list: Sequence[str] | None
skipped_op_list: Sequence[str] | None
debug_step: Sequence[int] | None
stack_height_limit: int
start_step: int | None
end_step: int | None
seed: int
initial_seed: int

def __init__(
self,
enable,
debug_mode=DebugMode.CHECK_NAN_INF_AND_ABORT,
output_dir=None,
checked_op_list=None,
skipped_op_list=None,
debug_step=None,
stack_height_limit=1,
enable: bool,
debug_mode: DebugMode = DebugMode.CHECK_NAN_INF_AND_ABORT,
output_dir: str | None = None,
checked_op_list: Sequence[str] | None = None,
skipped_op_list: Sequence[str] | None = None,
debug_step: Sequence[int] | None = None,
stack_height_limit: int = 1,
):
self.enable = enable
self.debug_mode = debug_mode
Expand Down Expand Up @@ -264,7 +294,7 @@ def __init__(
if self.enable:
self._set_seed(self.enable)

def _set_seed(self, flag):
def _set_seed(self, flag: int) -> None:
if self.initial_seed != self.seed:
self.seed = self.initial_seed

Expand All @@ -288,7 +318,7 @@ def _set_seed(self, flag):
flag,
)

def _set_env(self, check_flag):
def _set_env(self, check_flag: int) -> None:
paddle.set_flags({"FLAGS_check_nan_inf": check_flag})
if check_flag:
# set debug level
Expand All @@ -308,7 +338,7 @@ def _set_env(self, check_flag):
else:
raise ValueError("stack_height_limit must be int")

def update_and_check_step_id(self):
def update_and_check_step_id(self) -> bool:
if self.enable:
if self.start_step is not None and self.end_step is not None:
if (
Expand All @@ -321,17 +351,20 @@ def update_and_check_step_id(self):
return True
return False

def start_check_nan_inf(self):
def start_check_nan_inf(self) -> None:
if self.enable:
self._set_env(self.enable)

def stop_check_nan_inf(self):
def stop_check_nan_inf(self) -> None:
self._set_env(False)


def check_numerics(
tensor, op_type, var_name, debug_mode=DebugMode.CHECK_NAN_INF_AND_ABORT
):
tensor: Tensor,
op_type: str,
var_name: str,
debug_mode: DebugMode = DebugMode.CHECK_NAN_INF_AND_ABORT,
) -> tuple[Tensor, Tensor]:
"""
This function is used to debug a tensor by finding the number of NaNs, Infs and zeros in the tensor.

Expand Down Expand Up @@ -397,12 +430,12 @@ def check_numerics(
return stats, values


def _get_operator_stats_flag():
def _get_operator_stats_flag() -> Any:
flags = paddle.get_flags(["FLAGS_low_precision_op_list"])
return flags["FLAGS_low_precision_op_list"]


def _print_operator_stats(op_count_dict):
def _print_operator_stats(op_count_dict: dict[str, str | list[int]]) -> None:
"""
Parse and print the stats of operators, mainly including the calls of
dtypes such as different fp32, fp16, bf16 and others.
Expand Down Expand Up @@ -446,7 +479,7 @@ def _print_operator_stats(op_count_dict):
print("<{:-^120}>\n".format(" op count: " + str(total_ops) + " "))


def enable_operator_stats_collection():
def enable_operator_stats_collection() -> None:
"""
Enable collection of the number of operators for different data types.
The statistical data are categorized according to four data types, namely
Expand Down Expand Up @@ -484,7 +517,7 @@ def enable_operator_stats_collection():
paddle.set_flags({'FLAGS_low_precision_op_list': 1})


def disable_operator_stats_collection():
def disable_operator_stats_collection() -> None:
"""
Disable the collection of the number of operators for different data types.
This function is used in pair with the corresponding enable function.
Expand Down Expand Up @@ -525,7 +558,7 @@ def disable_operator_stats_collection():


@contextlib.contextmanager
def collect_operator_stats():
def collect_operator_stats() -> Generator[None, None, None]:
"""
The context manager that enables collection of the number of operators for
different data types. The statistical data are categorized according
Expand Down Expand Up @@ -561,12 +594,12 @@ def collect_operator_stats():


def compare_accuracy(
dump_path,
another_dump_path,
output_filename,
loss_scale=1,
dump_all_tensors=False,
):
dump_path: str,
another_dump_path: str,
output_filename: str,
loss_scale: float = 1,
dump_all_tensors: bool = False,
) -> None:
r"""
This is a precision comparison tool that can be used to compare log data of float16 and float32.

Expand Down Expand Up @@ -619,7 +652,7 @@ def compare_accuracy(
)


def enable_tensor_checker(checker_config):
def enable_tensor_checker(checker_config: TensorCheckerConfig) -> None:
"""
The enable_tensor_checker(checker_config) function enables model-level accuracy checking and is used in combination with disable_tensor_checker() to achieve model-level precision checking by checking the output Tensors of all operators within the specified range.

Expand Down Expand Up @@ -660,7 +693,7 @@ def enable_tensor_checker(checker_config):
checker_config.stop_check_nan_inf()


def disable_tensor_checker():
def disable_tensor_checker() -> None:
"""
disable_tensor_checker() is used to disable accuracy checking, and is used together with enable_tensor_checker(config) to achieve model-level precision checking by checking the output Tensors of all operators within the specified range.

Expand Down