[Enh] Refactor sum aggregator #834

Merged
1 change: 1 addition & 0 deletions docs/zh/api/loss/mtl.md
@@ -8,5 +8,6 @@
- LossAggregator
- PCGrad
- Relobralo
- Sum
show_root_heading: true
heading_level: 3
4 changes: 1 addition & 3 deletions docs/zh/examples/viv.md
@@ -11,9 +11,7 @@
=== "模型评估命令"

``` sh
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdeqn
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams
python viv.py mode=eval EVAL.pretrained_model_path=./viv_pretrained
python viv.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams
```

| Pretrained model | Metrics |
8 changes: 4 additions & 4 deletions ppsci/geometry/timedomain.py
@@ -207,7 +207,7 @@ def random_points(

Args:
n (int): The total number of random points to generate.
random (string): Specifies the way to generate random points, default is "pseudo", which means that a pseudo-random number generator is used.
random (str): Specifies the way to generate random points, default is "pseudo", which means that a pseudo-random number generator is used.
criteria (Optional[Callable]): A method that filters the generated random points, default is None.

Returns:
@@ -432,7 +432,7 @@ def random_boundary_points(

Args:
n (int): The total number of spatial-temporal points generated on a given geometry boundary.
random (string): Controls the way to generate random points. Default is "pseudo".
random (str): Controls the way to generate random points. Default is "pseudo".
criteria (Optional[Callable]): Used to filter the generated boundary points; only points that meet certain conditions are retained. Default is None.

Returns:
@@ -650,7 +650,7 @@ def random_initial_points(self, n: int, random: str = "pseudo"):

Args:
n (int): The total number of generated points.
random (string): Controls the way to generate random points. Default is "pseudo".
random (str): Controls the way to generate random points. Default is "pseudo".

Returns:
np.ndarray: A set of point coordinates randomly distributed on the spatial-temporal domain at the initial moment.
@@ -709,7 +709,7 @@ def sample_initial_interior(

Args:
n (int): The total number of interior points generated.
random (string): The method used to generate the initial random points. Default is "pseudo".
random (str): The method used to generate the initial random points. Default is "pseudo".
criteria (Optional[Callable]): Used to filter the generated interior points; only points that meet certain conditions are retained. Default is None.
evenly (bool): Indicates whether the initial points are generated evenly. Default is False.
compute_sdf_derivatives (bool): Indicates whether to compute derivatives of the signed distance function. Default is False.
2 changes: 2 additions & 0 deletions ppsci/loss/mtl/__init__.py
@@ -18,12 +18,14 @@
from ppsci.loss.mtl.base import LossAggregator
from ppsci.loss.mtl.pcgrad import PCGrad
from ppsci.loss.mtl.relobralo import Relobralo
from ppsci.loss.mtl.sum import Sum

__all__ = [
"AGDA",
"LossAggregator",
"PCGrad",
"Relobralo",
"Sum",
]


2 changes: 1 addition & 1 deletion ppsci/loss/mtl/agda.py
@@ -57,7 +57,7 @@ def __init__(self, model: nn.Layer, M: int = 100, gamma: float = 0.999) -> None:
self.Lf_tilde_acc = 0.0
self.Lu_tilde_acc = 0.0

def __call__(self, losses, step: int = 0):
def __call__(self, losses, step: int = 0) -> "AGDA":
if len(losses) != 2:
raise ValueError(
f"Number of losses(tasks) for AGDA shoule be 2, but got {len(losses)}"
2 changes: 1 addition & 1 deletion ppsci/loss/mtl/base.py
@@ -32,7 +32,7 @@ def __init__(self, model: nn.Layer) -> None:
if not param.stop_gradient:
self.param_num += 1

def __call__(self, losses, step: int = 0):
def __call__(self, losses, step: int = 0) -> "LossAggregator":
self.losses = losses
self.loss_num = len(losses)
self.step = step
18 changes: 8 additions & 10 deletions ppsci/loss/mtl/relobralo.py
@@ -69,28 +69,28 @@ def __init__(
self.register_buffer("losses_prev", paddle.zeros([self.num_losses]))
self.register_buffer("lmbda", paddle.ones([self.num_losses]))

def _softmax(self, vec: paddle.Tensor) -> paddle.Tensor:
def _softmax(self, vec: "paddle.Tensor") -> "paddle.Tensor":
max_item = vec.max()
result = paddle.exp(vec - max_item) / paddle.exp(vec - max_item).sum()
return result

def _compute_bal(
self, losses_vec1: paddle.Tensor, losses_vec2: paddle.Tensor
) -> paddle.Tensor:
self, losses_vec1: "paddle.Tensor", losses_vec2: "paddle.Tensor"
) -> "paddle.Tensor":
return self.num_losses * (
self._softmax(losses_vec1 / (self.tau * losses_vec2 + self.eps))
)

def __call__(self, losses: List[paddle.Tensor], step: int = 0) -> "Relobralo":
self.step = step
def __call__(self, losses: List["paddle.Tensor"], step: int = 0) -> "paddle.Tensor":
assert len(losses) == self.num_losses, (
f"Length of given losses({len(losses)}) should be equal to "
f"num_losses({self.num_losses})."
)
self.step = step
losses_stacked = paddle.stack(losses) # [num_losses, ]

if self.step == 0:
self.loss = losses_stacked.sum()
loss = losses_stacked.sum()
with paddle.no_grad():
paddle.assign(losses_stacked.detach(), self.losses_init)
else:
@@ -110,12 +110,10 @@ def __call__(self, losses: List[paddle.Tensor], step: int = 0) -> "Relobralo":
)

# 3. compute reweighted total loss with lambda
self.loss = (losses_stacked * self.lmbda).sum()
loss = (losses_stacked * self.lmbda).sum()

# update losses_prev at the end of each step
with paddle.no_grad():
paddle.assign(losses_stacked.detach(), self.losses_prev)
return self

def backward(self) -> None:
self.loss.backward()
return loss
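
The key behavioral change above: `Relobralo.__call__` now returns the reweighted loss tensor instead of `self`, and the separate `backward()` method is removed, so callers drive the backward pass themselves. A minimal sketch of the new calling convention, assuming the constructor accepts `num_losses` (any other constructor arguments are omitted here):

``` py
import paddle

from ppsci.loss.mtl import Relobralo

# assumed constructor usage; only num_losses is taken from the diff above
aggregator = Relobralo(num_losses=2)

# illustrative scalar losses, marked as requiring gradients
losses = [
    paddle.to_tensor(0.5, stop_gradient=False),
    paddle.to_tensor(1.5, stop_gradient=False),
]

# step 0 returns losses_stacked.sum(); later steps return the
# lambda-reweighted sum -- in both cases a paddle.Tensor, not self
loss = aggregator(losses, step=0)
loss.backward()  # previously: aggregator(losses, step).backward()
```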
50 changes: 50 additions & 0 deletions ppsci/loss/mtl/sum.py
@@ -0,0 +1,50 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Sequence

if TYPE_CHECKING:
import paddle

from ppsci.loss.mtl.base import LossAggregator


class Sum(LossAggregator):
r"""
**Default loss aggregator**, which does a simple summation of the given losses, as shown below.

$$
loss = \sum_{i=1}^{N} losses_i
$$
"""

def __init__(self) -> None:
self.step = 0

def __call__(
self, losses: Sequence["paddle.Tensor"], step: int = 0
) -> paddle.Tensor:
assert (
len(losses) > 0
), f"Number of given losses({len(losses)}) can not be empty."
self.step = step

loss = 0.0
for i in range(len(losses)):
loss += losses[i]

return loss
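
A minimal usage sketch of the new `Sum` aggregator; the tensor values are illustrative, not from the PR:

``` py
import paddle

from ppsci.loss.mtl import Sum

# two scalar losses, as a training loop would produce them
losses = [
    paddle.to_tensor(0.5, stop_gradient=False),
    paddle.to_tensor(1.5, stop_gradient=False),
]

aggregator = Sum()
total = aggregator(losses, step=0)  # plain summation: 0.5 + 1.5

print(float(total))  # 2.0
total.backward()  # the result is a paddle.Tensor, so backward() works directly
```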
19 changes: 15 additions & 4 deletions ppsci/solver/solver.py
@@ -300,6 +300,10 @@ def __init__(

# choosing an appropriate training function for different optimizers
if misc.typename(self.optimizer) == "LBFGS":
if self.use_amp:
raise ValueError(
"Auto Mix Precision is not supported for L-BFGS optimizer."
)
self.train_epoch_func = ppsci.solver.train.train_LBFGS_epoch_func
if self.update_freq != 1:
self.update_freq = 1
@@ -398,8 +402,13 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel:
jit.enable_to_static(to_static)
logger.info(f"Set to_static={to_static} for computational optimization.")

# use loss aggregator, use summation if None
self.loss_aggregator = loss_aggregator
# use loss aggregator, use Sum if None
if isinstance(loss_aggregator, (mtl.AGDA, mtl.PCGrad)) and self.use_amp:
raise ValueError(
"Auto Mix Precision do not support AGDA, PCGrad loss aggregator yet, "
"please set use_amp=False."
)
self.loss_aggregator = loss_aggregator or mtl.Sum()

# convert sympy to callable object if exist
extra_parameters = []
@@ -432,6 +441,10 @@ def convert_expr(
for name in container.output_expr:
if isinstance(container.output_expr[name], sp.Basic):
container.output_expr[name] = funcs[ind]
if self.world_size > 1:
container.output_expr[name] = dist_wrapper(
container.output_expr[name]
)
ind += 1

if self.constraint:
@@ -775,7 +788,6 @@ def export(
)
logger.message(f"ONNX model has been exported to: {export_path}.onnx")

@functools.lru_cache()
def autocast_context_manager(
self, enable: bool, level: Literal["O0", "O1", "O2", "OD"] = "O1"
) -> contextlib.AbstractContextManager:
@@ -820,7 +832,6 @@ def no_grad_context_manager(
)
return ctx_manager

@functools.lru_cache()
def no_sync_context_manager(
self,
enable: bool,
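The net effect of the `Solver` change: `self.loss_aggregator` is never `None` anymore, since a falsy argument falls through to `mtl.Sum()`. The fallback expression in isolation:

``` py
from ppsci.loss import mtl

# the same expression Solver.__init__ now uses: a None (falsy)
# aggregator is replaced by the default Sum instance
loss_aggregator = None
loss_aggregator = loss_aggregator or mtl.Sum()
assert isinstance(loss_aggregator, mtl.Sum)
```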
39 changes: 17 additions & 22 deletions ppsci/solver/train.py
@@ -45,7 +45,6 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
f"Training iteration {solver.global_step + 1}"
) # Training iteration

total_loss = 0.0
total_batch_size = 0
reader_cost = 0.0
batch_cost = 0.0
@@ -106,31 +105,30 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
if solver.nvtx_flag: # only for nsight analysis
core.nvprof_nvtx_push("Loss aggregator")

total_loss = solver.loss_aggregator(
constraint_losses, solver.global_step
)
if solver.update_freq > 1:
total_loss = total_loss / solver.update_freq

for i, _constraint in enumerate(solver.constraint.values()):
total_loss += constraint_losses[i]
loss_dict[_constraint.name] += (
loss_dict[_constraint.name] = (
float(constraint_losses[i]) / solver.update_freq
)
if solver.update_freq > 1:
total_loss = total_loss / solver.update_freq
loss_dict["loss"] = float(total_loss)

if solver.nvtx_flag: # only for nsight analysis
core.nvprof_nvtx_pop() # Loss aggregator

loss_dict["loss"] = float(total_loss)

# backward
if solver.nvtx_flag: # only for nsight analysis
core.nvprof_nvtx_push("Loss backward")

if solver.loss_aggregator is None:
if solver.use_amp:
total_loss_scaled = solver.scaler.scale(total_loss)
total_loss_scaled.backward()
else:
total_loss.backward()
if solver.use_amp:
total_loss_scaled = solver.scaler.scale(total_loss)
total_loss_scaled.backward()
else:
solver.loss_aggregator(constraint_losses, solver.global_step).backward()
total_loss.backward()

if solver.nvtx_flag: # only for nsight analysis
core.nvprof_nvtx_pop() # Loss backward
@@ -233,7 +231,6 @@ def closure() -> paddle.Tensor:
Returns:
paddle.Tensor: Computed loss scalar.
"""
total_loss = 0
with solver.no_sync_context_manager(solver.world_size > 1, solver.model):
with solver.autocast_context_manager(solver.use_amp, solver.amp_level):
# forward for every constraint, including model and equation expression
@@ -248,20 +245,18 @@
label_dicts,
weight_dicts,
)

total_loss = solver.loss_aggregator(
constraint_losses, solver.global_step
)
# accumulate all losses
for i, _constraint in enumerate(solver.constraint.values()):
total_loss += constraint_losses[i]
loss_dict[_constraint.name] = float(constraint_losses[i])
loss_dict["loss"] = float(total_loss)

# backward
solver.optimizer.clear_grad()
if solver.loss_aggregator is None:
total_loss.backward()
else:
solver.loss_aggregator(
constraint_losses, solver.global_step
).backward()
total_loss.backward()

if solver.world_size > 1:
# fuse + allreduce manually before optimization if use DDP model
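Because the solver now guarantees a non-`None` aggregator, both training functions drop the `if solver.loss_aggregator is None` branch: they aggregate first, then back-propagate once. A condensed sketch of the unified flow, with stand-in values that are hypothetical and for illustration only:

``` py
import paddle

from ppsci.loss.mtl import Sum

# stand-ins for solver state
loss_aggregator = Sum()
constraint_losses = [
    paddle.to_tensor(1.0, stop_gradient=False),
    paddle.to_tensor(2.0, stop_gradient=False),
]
global_step, update_freq = 0, 2

# aggregate every constraint loss in one call ...
total_loss = loss_aggregator(constraint_losses, global_step)
# ... scale once for gradient accumulation ...
if update_freq > 1:
    total_loss = total_loss / update_freq
# ... and back-propagate once; with AMP enabled this becomes
# solver.scaler.scale(total_loss).backward()
total_loss.backward()
```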
17 changes: 17 additions & 0 deletions ppsci/utils/save_load.py
@@ -98,8 +98,25 @@ def load_pretrain(
... path="path/to/pretrain_model") # doctest: +SKIP
"""
if path.startswith("http"):
# download from path (URL) and get its physical path
eqn_path = path.replace(".pdparams", ".pdeqn", 1)
path = download.get_weights_path_from_url(path)

# automatically download additional equation weights if available
def is_url_accessible(url: str) -> bool:
try:
import requests

response = requests.head(url, timeout=5)
return response.status_code == requests.codes.ok
except Exception:
# broad on purpose: also covers ImportError when requests is absent
return False

if is_url_accessible(eqn_path):
download.get_weights_path_from_url(eqn_path)

# remove ".pdparams" in suffix of path for convenient
if path.endswith(".pdparams"):
path = path[:-9]
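With this change, `load_pretrain` accepts a URL directly and also fetches the companion `.pdeqn` equation weights when they are reachable, which is what lets the viv example above pass the pretrained URL instead of downloading files by hand. A hedged sketch of the call; the model construction here is illustrative only and may not match the real viv architecture:

``` py
import ppsci
from ppsci.utils import save_load

# illustrative MLP(input_keys, output_keys, num_layers, hidden_size)
model = ppsci.arch.MLP(("t_f",), ("eta",), 5, 50)

# passing a URL: the .pdparams file is downloaded, and the matching
# .pdeqn file is probed with a HEAD request and fetched when available
save_load.load_pretrain(
    model,
    path="https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams",
)
```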