46 changes: 46 additions & 0 deletions docs/zh/user_guide.md
@@ -795,3 +795,49 @@ solver = ppsci.solver.Solver(
!!! info "Impact note"

    Some multi-task learning methods (e.g., weight-based methods) may change how the loss is computed during **training**, but this only affects the training process; the loss computation used during model **evaluation** stays unchanged.

## 3. Profiling with Nsight

Nsight is a developer tool suite from NVIDIA that provides in-depth tracing, debugging, profiling, and analysis for optimizing complex computing applications across NVIDIA GPUs and CPUs. See the official documentation for details: [Nsight Systems Document](https://docs.nvidia.com/nsight-systems/index.html)

PaddleScience has preliminary support for profiling with Nsight. Using a Linux development environment and the laplace2d example, follow the steps below to generate a performance report with the nsight tooling and inspect the results.

1. Install Nsight Systems

Download the Linux Nsight Systems package (nsight-systems/2023.4.1) on the development machine and add nsight to the `PATH` environment variable:

Run `PATH=/path/to/nsight-systems/2023.4.1/bin:$PATH`. Also install the **same version** of Nsight Systems on a Windows machine; it will be used later to view the report. A setup sketch follows below.
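A minimal setup sketch (the install path is an assumption; adjust it to wherever the package was unpacked):

``` sh
# assumed install location
export PATH=/path/to/nsight-systems/2023.4.1/bin:$PATH
nsys --version  # sanity check: should report 2023.4.1
```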
> **Reviewer comment (Collaborator):** Is this not supported yet on macOS, or on Linux machines with a graphical desktop?
>
> **Author reply (Collaborator):** macOS is not supported for now (no machine to test on). For Linux with a graphical desktop, if Nsight can be installed, the steps are the same; the doc mentions Windows only because I happen to use the Linux nsight-systems from a Windows machine.


2. Run the program with the `nsys` command to generate the profiling files

``` sh
{==NVTX=1 nsys profile -t cuda,nvtx --stats=true -o==} {++laplace2d++} python laplace2d.py
```
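For reference, this is the same command with the documentation's `{==...==}` highlight and `{++...++}` insertion markers resolved; `-t cuda,nvtx` selects the trace sources, `--stats=true` prints summary tables when the run ends, and `-o` names the output report:

``` sh
NVTX=1 nsys profile -t cuda,nvtx --stats=true -o laplace2d python laplace2d.py
```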

3. Inspect the results

When the program finishes, the profiling statistics are printed to the terminal (as shown below), and two files, `{++laplace2d++}.nsys-rep` and `{++laplace2d++}.sqlite`, are generated at the relative path given by the `-o` argument above.

Open `laplace2d.nsys-rep` with the NVIDIA Nsight Systems application on Windows to browse the profiling data in a graphical interface.
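If no GUI machine is at hand, the summary tables can also be regenerated from the report on the command line; a minimal example, assuming `nsys` is on `PATH`:

``` sh
nsys stats laplace2d.nsys-rep
```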

``` log
...
...
Only run 25 steps when 'NVTX' is set in environment for nsight analysis. Exit now ......

Generating '/tmp/nsys-report-18e4.qdstrm'
[1/7] [========================100%] laplace2d.nsys-rep
[2/7] [========================100%] laplace2d.sqlite
[3/7] Executing 'nvtx_sum' stats report

Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Style Range
-------- --------------- --------- ------------- ------------- ----------- ----------- ------------- ------- ------------------------------------
15.1 794,212,341 25 31,768,493.6 5,446,410.0 5,328,471 661,841,104 131,265,333.9 PushPop Loss computation
14.5 766,452,142 25 30,658,085.7 4,369,873.0 4,281,927 659,795,434 131,070,475.4 PushPop Constraint EQ
13.0 687,324,359 1,300 528,711.0 32,567.5 21,218 641,625,892 17,794,532.4 PushPop matmul dygraph
12.9 678,475,194 1 678,475,194.0 678,475,194.0 678,475,194 678,475,194 0.0 PushPop Training iteration 1
12.8 673,614,062 1,300 518,164.7 19,802.5 14,499 641,525,121 17,792,027.2 PushPop matmul compute
3.9 203,945,648 25 8,157,825.9 8,029,819.0 7,797,185 9,119,496 359,173.3 PushPop Loss backward
...
...
```
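The named ranges in the `nvtx_sum` table above (`Training iteration 1`, `Loss computation`, `Loss backward`, ...) come from the NVTX push/pop instrumentation this PR adds to `ppsci/solver/train.py`. A minimal, self-contained sketch of that pattern, using the same Paddle profiler bindings the PR relies on (illustrative only, not the PR's full code path):

``` py
import os

from paddle.framework import core

if os.getenv("NVTX") is not None:
    core.nvprof_start()  # begin profiler capture (effective under nsys)
    core.nvprof_enable_record_event()

    core.nvprof_nvtx_push("Training iteration 1")  # opens a PushPop range
    # ... one training step: forward, loss, backward, optimizer update ...
    core.nvprof_nvtx_pop()  # closes the range named above

    core.nvprof_stop()  # stop capture before exiting
```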
92 changes: 46 additions & 46 deletions ppsci/solver/printer.py
``` diff
@@ -30,59 +30,59 @@


 def update_train_loss(
-    trainer: "solver.Solver", loss_dict: Dict[str, float], batch_size: int
+    solver: "solver.Solver", loss_dict: Dict[str, float], batch_size: int
 ):
     for key in loss_dict:
-        if key not in trainer.train_output_info:
-            trainer.train_output_info[key] = misc.AverageMeter(key, "7.5f")
-        trainer.train_output_info[key].update(float(loss_dict[key]), batch_size)
-        if key not in trainer.train_loss_info:
-            trainer.train_loss_info[key] = misc.AverageMeter(key, ".5f")
-        trainer.train_loss_info[key].update(float(loss_dict[key]))
+        if key not in solver.train_output_info:
+            solver.train_output_info[key] = misc.AverageMeter(key, "7.5f")
+        solver.train_output_info[key].update(float(loss_dict[key]), batch_size)
+        if key not in solver.train_loss_info:
+            solver.train_loss_info[key] = misc.AverageMeter(key, ".5f")
+        solver.train_loss_info[key].update(float(loss_dict[key]))


 def update_eval_loss(
-    trainer: "solver.Solver", loss_dict: Dict[str, float], batch_size: int
+    solver: "solver.Solver", loss_dict: Dict[str, float], batch_size: int
 ):
     for key in loss_dict:
-        if key not in trainer.eval_output_info:
-            trainer.eval_output_info[key] = misc.AverageMeter(key, "7.5f")
-        trainer.eval_output_info[key].update(float(loss_dict[key]), batch_size)
+        if key not in solver.eval_output_info:
+            solver.eval_output_info[key] = misc.AverageMeter(key, "7.5f")
+        solver.eval_output_info[key].update(float(loss_dict[key]), batch_size)


 def log_train_info(
-    trainer: "solver.Solver", batch_size: int, epoch_id: int, iter_id: int
+    solver: "solver.Solver", batch_size: int, epoch_id: int, iter_id: int
 ):
-    lr_msg = f"lr: {trainer.optimizer.get_lr():.5f}"
+    lr_msg = f"lr: {solver.optimizer.get_lr():.5f}"

     metric_msg = ", ".join(
         [
-            f"{key}: {trainer.train_output_info[key].avg:.5f}"
-            for key in trainer.train_output_info
+            f"{key}: {solver.train_output_info[key].avg:.5f}"
+            for key in solver.train_output_info
         ]
     )

     time_msg = ", ".join(
-        [trainer.train_time_info[key].mean for key in trainer.train_time_info]
+        [solver.train_time_info[key].mean for key in solver.train_time_info]
     )

-    ips_msg = f"ips: {batch_size / trainer.train_time_info['batch_cost'].avg:.2f}"
-    if trainer.benchmark_flag:
+    ips_msg = f"ips: {batch_size / solver.train_time_info['batch_cost'].avg:.2f}"
+    if solver.benchmark_flag:
         ips_msg += " samples/s"

     eta_sec = (
-        (trainer.epochs - epoch_id + 1) * trainer.iters_per_epoch - iter_id
-    ) * trainer.train_time_info["batch_cost"].avg
+        (solver.epochs - epoch_id + 1) * solver.iters_per_epoch - iter_id
+    ) * solver.train_time_info["batch_cost"].avg
     eta_msg = f"eta: {str(datetime.timedelta(seconds=int(eta_sec)))}"

-    epoch_width = len(str(trainer.epochs))
-    iters_width = len(str(trainer.iters_per_epoch))
+    epoch_width = len(str(solver.epochs))
+    iters_width = len(str(solver.iters_per_epoch))
     log_str = (
-        f"[Train][Epoch {epoch_id:>{epoch_width}}/{trainer.epochs}]"
-        f"[Iter {iter_id:>{iters_width}}/{trainer.iters_per_epoch}] {lr_msg}, "
+        f"[Train][Epoch {epoch_id:>{epoch_width}}/{solver.epochs}]"
+        f"[Iter {iter_id:>{iters_width}}/{solver.iters_per_epoch}] {lr_msg}, "
         f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
     )
-    if trainer.benchmark_flag:
+    if solver.benchmark_flag:
         max_mem_reserved_msg = (
             f"max_mem_reserved: {device.cuda.max_memory_reserved() // (1 << 20)} MB"
         )
@@ -94,57 +94,57 @@ def log_train_info(

     logger.scalar(
         {
-            "train/lr": trainer.optimizer.get_lr(),
+            "train/lr": solver.optimizer.get_lr(),
             **{
-                f"train/{key}": trainer.train_output_info[key].avg
-                for key in trainer.train_output_info
+                f"train/{key}": solver.train_output_info[key].avg
+                for key in solver.train_output_info
             },
         },
-        step=trainer.global_step,
-        vdl_writer=trainer.vdl_writer,
-        wandb_writer=trainer.wandb_writer,
-        tbd_writer=trainer.tbd_writer,
+        step=solver.global_step,
+        vdl_writer=solver.vdl_writer,
+        wandb_writer=solver.wandb_writer,
+        tbd_writer=solver.tbd_writer,
     )


 def log_eval_info(
-    trainer: "solver.Solver",
+    solver: "solver.Solver",
     batch_size: int,
     epoch_id: int,
     iters_per_epoch: int,
     iter_id: int,
 ):
     metric_msg = ", ".join(
         [
-            f"{key}: {trainer.eval_output_info[key].avg:.5f}"
-            for key in trainer.eval_output_info
+            f"{key}: {solver.eval_output_info[key].avg:.5f}"
+            for key in solver.eval_output_info
         ]
     )

     time_msg = ", ".join(
-        [trainer.eval_time_info[key].mean for key in trainer.eval_time_info]
+        [solver.eval_time_info[key].mean for key in solver.eval_time_info]
     )

-    ips_msg = f"ips: {batch_size / trainer.eval_time_info['batch_cost'].avg:.2f}"
+    ips_msg = f"ips: {batch_size / solver.eval_time_info['batch_cost'].avg:.2f}"

-    eta_sec = (iters_per_epoch - iter_id) * trainer.eval_time_info["batch_cost"].avg
+    eta_sec = (iters_per_epoch - iter_id) * solver.eval_time_info["batch_cost"].avg
     eta_msg = f"eta: {str(datetime.timedelta(seconds=int(eta_sec)))}"

-    epoch_width = len(str(trainer.epochs))
+    epoch_width = len(str(solver.epochs))
     iters_width = len(str(iters_per_epoch))
     logger.info(
-        f"[Eval][Epoch {epoch_id:>{epoch_width}}/{trainer.epochs}]"
+        f"[Eval][Epoch {epoch_id:>{epoch_width}}/{solver.epochs}]"
         f"[Iter {iter_id:>{iters_width}}/{iters_per_epoch}] "
         f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
     )

     logger.scalar(
         {
-            f"eval/{key}": trainer.eval_output_info[key].avg
-            for key in trainer.eval_output_info
+            f"eval/{key}": solver.eval_output_info[key].avg
+            for key in solver.eval_output_info
         },
-        step=trainer.global_step,
-        vdl_writer=trainer.vdl_writer,
-        wandb_writer=trainer.wandb_writer,
-        tbd_writer=trainer.tbd_writer,
+        step=solver.global_step,
+        vdl_writer=solver.vdl_writer,
+        wandb_writer=solver.wandb_writer,
+        tbd_writer=solver.tbd_writer,
     )
```
13 changes: 13 additions & 0 deletions ppsci/solver/solver.py
``` diff
@@ -15,6 +15,7 @@
 from __future__ import annotations

 import contextlib
+import functools
 import importlib
 import itertools
 import os
@@ -39,6 +40,7 @@
 from paddle import nn
 from paddle import optimizer as optim
 from paddle.distributed import fleet
+from paddle.framework import core
 from paddle.static import InputSpec
 from typing_extensions import Literal

@@ -444,11 +446,19 @@ def convert_expr(
         # set up benchmark flag, will print memory stat if enabled
         self.benchmark_flag: bool = os.getenv("BENCHMARK_ROOT", None) is not None

+        # set up nvtx flag for nsight analysis
+        self.nvtx_flag: bool = os.getenv("NVTX", None) is not None
+        self.forward_helper.nvtx_flag = self.nvtx_flag
+
     def train(self) -> None:
         """Training."""
         self.global_step = self.best_metric["epoch"] * self.iters_per_epoch
         start_epoch = self.best_metric["epoch"] + 1

+        if self.nvtx_flag:
+            core.nvprof_start()
+            core.nvprof_enable_record_event()
+
         for epoch_id in range(start_epoch, self.epochs + 1):
             self.train_epoch_func(self, epoch_id, self.log_freq)
             self.train_output_info.clear()
@@ -764,6 +774,7 @@ def export(
         )
         logger.message(f"ONNX model has been exported to: {export_path}.onnx")

+    @functools.lru_cache()
     def autocast_context_manager(
         self, enable: bool, level: Literal["O0", "O1", "O2", "OD"] = "O1"
     ) -> contextlib.AbstractContextManager:
@@ -786,6 +797,7 @@ def autocast_context_manager(
         )
         return ctx_manager

+    @functools.lru_cache()
     def no_grad_context_manager(
         self, enable: bool
     ) -> contextlib.AbstractContextManager:
@@ -807,6 +819,7 @@ def no_grad_context_manager(
         )
         return ctx_manager

+    @functools.lru_cache()
     def no_sync_context_manager(
         self,
         enable: bool,
```
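The `@functools.lru_cache()` decorators added above memoize the context-manager factory methods, so the hot training loop receives a cached object instead of constructing a new one on every iteration (on instance methods the cache also keys on `self`). A standalone illustration of the pattern, not PaddleScience's actual class:

``` py
import contextlib
import functools


@functools.lru_cache()
def autocast_context_manager(enable: bool):
    # Built once per distinct argument value; later calls are cache hits.
    print(f"building context manager (enable={enable})")
    return contextlib.nullcontext()


with autocast_context_manager(True):  # prints "building ..." once
    pass
with autocast_context_manager(True):  # cache hit: same object, no print
    pass
```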
45 changes: 45 additions & 0 deletions ppsci/solver/train.py
``` diff
@@ -14,11 +14,13 @@

 from __future__ import annotations

+import sys
 import time
 from typing import TYPE_CHECKING

 import paddle
 from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu
+from paddle.framework import core

 from ppsci.solver import printer
 from ppsci.utils import misc
@@ -38,6 +40,11 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
     batch_tic = time.perf_counter()

     for iter_id in range(1, solver.iters_per_epoch + 1):
+        if solver.nvtx_flag:  # only for nsight analysis
+            core.nvprof_nvtx_push(
+                f"Training iteration {solver.global_step + 1}"
+            )  # Training iteration
+
         total_loss = 0.0
         total_batch_size = 0
         reader_cost = 0.0
@@ -77,6 +84,9 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
         # forward for every constraint, including model and equation expression
         with solver.no_sync_context_manager(solver.world_size > 1, solver.model):
             with solver.autocast_context_manager(solver.use_amp, solver.amp_level):
+                if solver.nvtx_flag:  # only for nsight analysis
+                    core.nvprof_nvtx_push("Loss computation")
+
                 constraint_losses = solver.forward_helper.train_forward(
                     tuple(
                         _constraint.output_expr
@@ -88,17 +98,31 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
                     label_dicts,
                     weight_dicts,
                 )
+
+                if solver.nvtx_flag:  # only for nsight analysis
+                    core.nvprof_nvtx_pop()  # Loss computation
+
                 # accumulate all losses
+                if solver.nvtx_flag:  # only for nsight analysis
+                    core.nvprof_nvtx_push("Loss aggregator")
+
                 for i, _constraint in enumerate(solver.constraint.values()):
                     total_loss += constraint_losses[i]
                     loss_dict[_constraint.name] += (
                         float(constraint_losses[i]) / solver.update_freq
                     )
                 if solver.update_freq > 1:
                     total_loss = total_loss / solver.update_freq
+
+                if solver.nvtx_flag:  # only for nsight analysis
+                    core.nvprof_nvtx_pop()  # Loss aggregator
+
                 loss_dict["loss"] = float(total_loss)

             # backward
+            if solver.nvtx_flag:  # only for nsight analysis
+                core.nvprof_nvtx_push("Loss backward")
+
             if solver.loss_aggregator is None:
                 if solver.use_amp:
                     total_loss_scaled = solver.scaler.scale(total_loss)
@@ -108,8 +132,14 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
             else:
                 solver.loss_aggregator(constraint_losses, solver.global_step).backward()

+            if solver.nvtx_flag:  # only for nsight analysis
+                core.nvprof_nvtx_pop()  # Loss backward
+
         # update parameters
         if iter_id % solver.update_freq == 0 or iter_id == solver.iters_per_epoch:
+            if solver.nvtx_flag:  # only for nsight analysis
+                core.nvprof_nvtx_push("Optimizer update")
+
             if solver.world_size > 1:
                 # fuse + allreduce manually before optimization if use DDP + no_sync
                 # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
@@ -118,6 +148,10 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
                 solver.scaler.minimize(solver.optimizer, total_loss_scaled)
             else:
                 solver.optimizer.step()
+
+            if solver.nvtx_flag:  # only for nsight analysis
+                core.nvprof_nvtx_pop()  # Optimizer update
+
             solver.optimizer.clear_grad()

         # update learning rate by step
@@ -138,6 +172,17 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):

         batch_tic = time.perf_counter()

+        if solver.nvtx_flag:  # only for nsight analysis
+            core.nvprof_nvtx_pop()  # Training iteration
+            NVTX_STOP_ITER = 25
+            if solver.global_step >= NVTX_STOP_ITER:
+                print(
+                    f"Only run {NVTX_STOP_ITER} steps when 'NVTX' is set in environment"
+                    " for nsight analysis. Exit now ......\n"
+                )
+                core.nvprof_stop()
+                sys.exit(0)
+

 def train_LBFGS_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
     """Train function for one epoch with L-BFGS optimizer.
```